Pytorch模型保存与提取

发布时间:2023-04-07 14:30

Pytorch模型保存与提取

Pytorch模型保存与提取

本篇笔记主要对应于莫凡Pytorch中的3.4节。主要讲了如何使用Pytorch保存和提取我们的神经网络。
在Pytorch中,网络的存储主要使用torch.save函数来完成。
我们将通过两种方式展示模型的保存和提取。
第一种保存方式是保存整个模型,在重新提取时直接加载整个模型。第二种保存方法是只保存模型的参数,这种方式只保存了参数,而不会保存模型的结构等信息。
两种方式各有优缺点。

保存完整模型不需要知道网络的结构,一次性保存一次性读入。缺点是模型比较大时耗时较长,保存的文件也大。
而只保存参数的方式存储快捷,保存的文件也小一些,但缺点是丢失了网络的结构信息,恢复模型时需要提前建立一个特定结构的网络再读入参数。
以下使用代码展示。

数据生成与展示

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

复制代码
这里还是生成一组带有噪声的y=x2y=x^{2}y=x2数据进行回归拟合。

# torch.manual_seed(1)    # reproducible

# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)

复制代码

基本网络搭建与保存

我们使用nn.Sequential模块来快速搭建一个网络完成回归操作,网络由两层Linear层和中间的激活层ReLU组成。我们设置输入输出的维度为1,中间隐藏层变量的维度为10,以加快训练。
这里使用两种方式进行保存。

def save():
    # save net1
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    
    for step in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title(\'Net1\')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), \'r-\', lw=5)
    plt.savefig(\"./img/05_save.png\")
        
    torch.save(net1, \'net.pkl\')                        # entire network
    torch.save(net1.state_dict(), \'net_params.pkl\')    # parameters

复制代码
在这个save函数中,我们首先使用nn.Sequential模块构建了一个基础的两层神经网络。然后对其进行训练,展示训练结果。之后使用两种方式进行保存。
第一种方式直接保存整个网络,代码为

torch.save(net1, \'net.pkl\')                        # entire network
复制代码
第二种方式只保存网络参数,代码为
torch.save(net1.state_dict(), \'net_params.pkl\')    # parameters

复制代码

对保存的模型进行提取恢复

这里我们为两种不同存储方式保存的模型分别定义恢复提取的函数
首先是对整个网络的提取。直接使用torch.load就可以,无需其他额外操作。

def restore_net():
    # 提取神经网络
    net2 = torch.load(\'net.pkl\')
    prediction = net2(x)
    
    # plot result
    plt.subplot(132)
    plt.title(\'Net2\')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), \'r-\', lw=5)
    plt.savefig(\"./img/05_res_net.png\")

复制代码
而对于参数的读取,我们首先需要先搭建好一个与之前保存的模型相同架构的网络,然后使用这个网络的load_state_dict方法进行参数读取和恢复。以下展示了使用参数方式读取网络的示例:

def restore_params():
    # 提取神经网络
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    net3.load_state_dict(torch.load(\'net_params.pkl\'))
    prediction = net3(x)
    
    # plot result
    plt.subplot(133)
    plt.title(\'Net3\')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), \'r-\', lw=5)
    plt.savefig(\"./img/05_res_para.png\")
    plt.show()

复制代码

对比不同提取方法的效果

接下来我们对比一下这两种方法的提取效果

# save net1
save()

# restore entire net (may slow)
restore_net()

# restore only the net parameters
restore_params()

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号