Published: 2023-11-13 11:30
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms, models


def load_img(path, max_size=400, shape=None):
    img = Image.open(path).convert('RGB')
    if max(img.size) > max_size:  # cap the longer edge of the image
        size = max_size
    else:
        size = max(img.size)
    if shape is not None:
        size = shape
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # keep only the RGB channels (drops any alpha channel) and add a batch dimension
    img = transform(img)[:3, :, :].unsqueeze(0)
    return img


# convert a tensor back into something plt can display
def im_convert(tensor):
    img = tensor.clone().detach()
    img = img.numpy().squeeze()
    img = img.transpose(1, 2, 0)
    img = img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    img = img.clip(0, 1)
    return img


# collect the feature maps of the selected layers
def get_features(img, model, layers=None):
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content layer
            '28': 'conv5_1'
        }
    features = {}
    x = img
    # name is the index and layer is the module itself, e.g.
    # 2, Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features


# compute the Gram matrix
def gram_matrix(tensor):
    _, d, h, w = tensor.size()  # the first dimension is the batch size
    tensor = tensor.view(d, h * w)
    gram = torch.mm(tensor, tensor.t())
    return gram


# use a pretrained VGG19; .features keeps only the convolutional part (no fully connected layers)
vgg = models.vgg19(pretrained=True).features
# at this point vgg is:
# Sequential(
#   (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (1): ReLU(inplace=True)
#   (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (3): ReLU(inplace=True)
#   (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (6): ReLU(inplace=True)
#   (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (8): ReLU(inplace=True)
#   (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (11): ReLU(inplace=True)
#   (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (13): ReLU(inplace=True)
#   (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (15): ReLU(inplace=True)
#   (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (17): ReLU(inplace=True)
#   (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (20): ReLU(inplace=True)
#   (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (22): ReLU(inplace=True)
#   (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (24): ReLU(inplace=True)
#   (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (26): ReLU(inplace=True)
#   (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (29): ReLU(inplace=True)
#   (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (31): ReLU(inplace=True)
#   (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (33): ReLU(inplace=True)
#   (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (35): ReLU(inplace=True)
#   (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )

for i in vgg.parameters():
    i.requires_grad_(False)  # VGG's own parameters are frozen, not trained

content = load_img('turtle.jpg')
style = load_img('wave.jpg', shape=content.shape[-2:])  # make both images the same size

content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

# target is the generated image
target = content.clone().requires_grad_(True)

# weights for the individual style layers
style_weights = {
    'conv1_1': 1,
    'conv2_1': 0.8,
    'conv3_1': 0.5,
    'conv4_1': 0.3,
    'conv5_1': 0.1,
}
# weights for the two loss terms
content_weight = 1
style_weight = 1e6

show_every = 2  # 400
optimizer = optim.Adam([target], lr=0.003)
steps = 2000

for ii in range(steps):
    print(ii)  # simple progress indicator
    target_features = get_features(target, vgg)

    # content loss
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

    # style loss: accumulate the Gram-matrix loss of every style layer
    style_loss = 0
    for layer in style_weights:  # layer is conv1_1, conv2_1, and so on
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram) ** 2)
        style_loss += layer_style_loss / (d * h * w)  # add to the total style loss, normalized by the feature-map size

    total_loss = content_weight * content_loss + style_weight * style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if ii % show_every == 0:
        print('Total Loss:', total_loss.item())
        plt.imshow(im_convert(target))
        plt.savefig('img_trans/%d.png' % ii)
        # plt.show()
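One practical note about running the script above: the loop saves intermediate results with plt.savefig into an img_trans/ folder, which matplotlib will not create on its own. A minimal guard placed before the training loop covers this; the final-image filename below is just an illustrative choice, not something from the original script.

import os

os.makedirs('img_trans', exist_ok=True)  # output folder used by plt.savefig in the loop

# ... and once the loop has finished, the final result can be written out the same way:
plt.imsave('img_trans/final.png', im_convert(target))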
The number of training steps is set to 2000 here, which is a bit on the low side; the final result should look something like this.
Since no GPU was used, training is rather slow; the full run took about a day and a night to finish.
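If a CUDA-capable GPU is available, the same script runs much faster with only a few changes. A minimal sketch of the device handling is shown below, assuming torch.cuda is usable on the machine; everything else stays as above, and the extra .cpu() call in im_convert is the only other adjustment needed.

# Sketch: run the style transfer on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vgg = models.vgg19(pretrained=True).features.to(device).eval()
content = load_img('turtle.jpg').to(device)
style = load_img('wave.jpg', shape=content.shape[-2:]).to(device)
target = content.clone().requires_grad_(True)  # lives on the same device as content

def im_convert(tensor):
    # pull the tensor back to the CPU before converting it to a numpy array
    img = tensor.detach().cpu().numpy().squeeze()
    img = img.transpose(1, 2, 0)
    img = img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    return img.clip(0, 1)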
Reference
https://www.cnblogs.com/MartinLwx/p/10572466.html