神经网络风格迁移实战(pytorch)

发布时间:2023-11-13 11:30

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms, models


def load_img(path, max_size=400, shape=None):
    img = Image.open(path).convert('RGB')

    if (max(img.size)) > max_size:  # 规定图像的最大尺寸
        size = max_size
    else:
        size = max(img.size)

    if shape is not None:
        size = shape
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))
    ])
    '''删除alpha通道(jpg), 转为png,补足另一个维度-batch'''
    img = transform(img)[:3, :, :].unsqueeze(0)
    return img



#转换为plt可以画出来的形式
def im_convert(tensor):
    img = tensor.clone().detach()
    img = img.numpy().squeeze()
    img = img.transpose(1, 2, 0)
    img = img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    img = img.clip(0, 1)
    return img

# 获取特征层
def get_features(img, model, layers=None):
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content层
            '28': 'conv5_1'
        }

    features = {}
    x = img

    #name是序号,layer是层结构,类似于这样
    # 2, Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    for name, layer in model._modules.items():
        print(name)
        print(layer)
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            print(features)

    print(features)
    return features

#计算格拉姆矩阵
def gram_matrix(tensor):

    _, d, h, w = tensor.size()  # 第一个是batch_size

    tensor = tensor.view(d, h * w)

    gram = torch.mm(tensor, tensor.t())

    return gram



#使用预训练的VGG19,features表示只提取不包括全连接层的部分
vgg = models.vgg19(pretrained=True).features
#此时vgg:
# Sequential(
#   (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (1): ReLU(inplace=True)
#   (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (3): ReLU(inplace=True)
#   (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (6): ReLU(inplace=True)
#   (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (8): ReLU(inplace=True)
#   (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (11): ReLU(inplace=True)
#   (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (13): ReLU(inplace=True)
#   (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (15): ReLU(inplace=True)
#   (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (17): ReLU(inplace=True)
#   (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (20): ReLU(inplace=True)
#   (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (22): ReLU(inplace=True)
#   (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (24): ReLU(inplace=True)
#   (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (26): ReLU(inplace=True)
#   (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (29): ReLU(inplace=True)
#   (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (31): ReLU(inplace=True)
#   (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (33): ReLU(inplace=True)
#   (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (35): ReLU(inplace=True)
#   (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )



for i in vgg.parameters():
    i.requires_grad_(False)		#不要求训练VGG的参数



content  = load_img('turtle.jpg')
style = load_img('wave.jpg', shape=content.shape[-2:])		#让两张图尺寸一样


content_features = get_features(content, vgg)
style_features = get_features(style, vgg)

style_grams = {layer:gram_matrix(style_features[layer]) for layer in style_features}

#target就是生成图
target = content.clone().requires_grad_(True)

#定义不同层的权重
style_weights = {
    'conv1_1': 1,
    'conv2_1': 0.8,
    'conv3_1': 0.5,
    'conv4_1': 0.3,
    'conv5_1': 0.1,
}

#定义2种损失对应的权重
content_weight = 1
style_weight = 1e6

show_every = 2  #400
optimizer = optim.Adam([target], lr=0.003)
steps = 2000

for ii in range(steps):
    print(ii)
    target_features = get_features(target, vgg)

    #内容损失函数
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)   
    #风格损失函数
    style_loss = 0

    #加上每一层的格拉姆矩阵的损失
    for layer in style_weights:  #layer就是conv1_1、conv2_1等等这些

        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        style_loss += layer_style_loss/(d*h*w)     #加到总的style_loss里,除以大小
        
    total_loss = content_weight * content_loss + style_weight * style_loss
    
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    if ii % show_every == 0 :
        print('Total Loss:',total_loss.item())
        plt.imshow(im_convert(target))
        plt.savefig('img_trans/%d.png' % ii)
        #plt.show()

 

 

神经网络风格迁移实战(pytorch)_第1张图片           神经网络风格迁移实战(pytorch)_第2张图片

神经网络风格迁移实战(pytorch)_第3张图片

训练的轮数默认为2000了,有点少,最终结果应该是类似这样的

神经网络风格迁移实战(pytorch)_第4张图片

 

神经网络风格迁移实战(pytorch)_第5张图片

 

因为没用GPU,所以速度略慢,跑了大概一天一晚上才跑完

 

参考

https://www.cnblogs.com/MartinLwx/p/10572466.html

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号