《动手学深度学习-pytorch》书中定义函数后加#@save的含义

发布时间:2023-02-09 13:00

笔者在学习《动手学深度学习-pytorch》一书的时候,发现很多定义的函数后面都会加#@save符号,如下:

def use_svg_display():  #@save
    """使⽤svg格式在Jupyter中显⽰绘图"""
    backend_inline.set_matplotlib_formats('svg')

书中的解释为:“有时,为了避免不必要的重复,我们将本书中经常导⼊和引⽤的函数、类等封装在d2l包中。对于要保存到包中的任何代码块,⽐如⼀个函数、⼀个类或者多个导⼊,我们都会标记为#@save。”于是笔者试着先from d2l import torch as d2l,然后在pycharm控制台键入d2l.use,发现可以补全为d2l.use_svg_display(),这说明use_svg_display()函数是封装在d2l库中的。

另外在序言中有这么一句话,“有时,我们想深⼊研究模型的细节,这些的细节通常会被深度学习框架的⾼级抽象隐藏起来。特别是在基础教程中,我们希望你了解在给定层或优化器中发⽣的⼀切。在这些情况下,我们通常会提供两个版本的⽰例:⼀个是我们从零开始实现⼀切,仅依赖于NumPy接⼝和⾃动微分;另⼀个是更实际的⽰例,我们使⽤深度学习框架的⾼级API编写简洁的代码。”因此笔者认为这里#@save的意思应该是指:虽然书中对这些定义的函数展开讲解,但这些函数都是可以直接调用,封装在d2l库中的,此外也和没有封装的,临时定义的函数做出区分,如下:

def evaluate_loss(net, data_iter, loss): #@save
	"""评估给定数据集上模型的损失"""
	metric = d2l.Accumulator(2) # 损失的总和,样本数量
	for X, y in data_iter:
		out = net(X)
		y = y.reshape(out.shape)
		l = loss(out, y)
		metric.add(l.sum(), l.numel())
	return metric[0] / metric[1]


def train(train_features, test_features, train_labels, test_labels, num_epochs=400):
    loss = nn.MSELoss(reduction='none')
    input_shape = train_features.shape[-1] # 不设置偏置,因为我们已经在多项式中实现了它
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs],
                            ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    for epoch in range(num_epochs):
        d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
                                     evaluate_loss(net, test_iter, loss)))
    print('weight:', net[0].weight.data.numpy())

这是书中第四章的一部分代码,可以看到,evaluate_loss()后面有#@save,而train()后面没有,也就是说,evaluate_loss()是d2l的内置函数,而train()是为了教学临时定义的。

以上是笔者的个人见解,有不对之处还请大家指出。

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号