Pytorch模型保存与提取
Pytorch模型保存与提取
本篇笔记主要对应于莫凡Pytorch中的3.4节。主要讲了如何使用Pytorch保存和提取我们的神经网络。
在Pytorch中,网络的存储主要使用torch.save函数来完成。
我们将通过两种方式展示模型的保存和提取。
第一种保存方式是保存整个模型,在重新提取时直接加载整个模型。第二种保存方法是只保存模型的参数,这种方式只保存了参数,而不会保存模型的结构等信息。
两种方式各有优缺点。
保存完整模型不需要知道网络的结构,一次性保存一次性读入。缺点是模型比较大时耗时较长,保存的文件也大。
而只保存参数的方式存储快捷,保存的文件也小一些,但缺点是丢失了网络的结构信息,恢复模型时需要提前建立一个特定结构的网络再读入参数。
以下使用代码展示。
数据生成与展示
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
复制代码
这里还是生成一组带有噪声的y=x2y=x^{2}y=x2数据进行回归拟合。
# torch.manual_seed(1) # reproducible
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
复制代码
基本网络搭建与保存
我们使用nn.Sequential模块来快速搭建一个网络完成回归操作,网络由两层Linear层和中间的激活层ReLU组成。我们设置输入输出的维度为1,中间隐藏层变量的维度为10,以加快训练。
这里使用两种方式进行保存。
def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
for step in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title(\'Net1\')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), \'r-\', lw=5)
plt.savefig(\"./img/05_save.png\")
torch.save(net1, \'net.pkl\') # entire network
torch.save(net1.state_dict(), \'net_params.pkl\') # parameters
复制代码
在这个save函数中,我们首先使用nn.Sequential模块构建了一个基础的两层神经网络。然后对其进行训练,展示训练结果。之后使用两种方式进行保存。
第一种方式直接保存整个网络,代码为
torch.save(net1, \'net.pkl\') # entire network
复制代码
第二种方式只保存网络参数,代码为
torch.save(net1.state_dict(), \'net_params.pkl\') # parameters
复制代码
对保存的模型进行提取恢复
这里我们为两种不同存储方式保存的模型分别定义恢复提取的函数
首先是对整个网络的提取。直接使用torch.load就可以,无需其他额外操作。
def restore_net():
# 提取神经网络
net2 = torch.load(\'net.pkl\')
prediction = net2(x)
# plot result
plt.subplot(132)
plt.title(\'Net2\')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), \'r-\', lw=5)
plt.savefig(\"./img/05_res_net.png\")
复制代码
而对于参数的读取,我们首先需要先搭建好一个与之前保存的模型相同架构的网络,然后使用这个网络的load_state_dict方法进行参数读取和恢复。以下展示了使用参数方式读取网络的示例:
def restore_params():
# 提取神经网络
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
net3.load_state_dict(torch.load(\'net_params.pkl\'))
prediction = net3(x)
# plot result
plt.subplot(133)
plt.title(\'Net3\')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), \'r-\', lw=5)
plt.savefig(\"./img/05_res_para.png\")
plt.show()
复制代码
对比不同提取方法的效果
接下来我们对比一下这两种方法的提取效果
# save net1
save()
# restore entire net (may slow)
restore_net()
# restore only the net parameters
restore_params()