发布时间:2023-02-09 13:00
笔者在学习《动手学深度学习-pytorch》一书的时候,发现很多定义的函数后面都会加#@save符号,如下:
def use_svg_display(): #@save
"""使⽤svg格式在Jupyter中显⽰绘图"""
backend_inline.set_matplotlib_formats('svg')
书中的解释为:“有时,为了避免不必要的重复,我们将本书中经常导⼊和引⽤的函数、类等封装在d2l包中。对于要保存到包中的任何代码块,⽐如⼀个函数、⼀个类或者多个导⼊,我们都会标记为#@save。”于是笔者试着先from d2l import torch as d2l,然后在pycharm控制台键入d2l.use,发现可以补全为d2l.use_svg_display(),这说明use_svg_display()函数是封装在d2l库中的。
另外在序言中有这么一句话,“有时,我们想深⼊研究模型的细节,这些的细节通常会被深度学习框架的⾼级抽象隐藏起来。特别是在基础教程中,我们希望你了解在给定层或优化器中发⽣的⼀切。在这些情况下,我们通常会提供两个版本的⽰例:⼀个是我们从零开始实现⼀切,仅依赖于NumPy接⼝和⾃动微分;另⼀个是更实际的⽰例,我们使⽤深度学习框架的⾼级API编写简洁的代码。”因此笔者认为这里#@save的意思应该是指:虽然书中对这些定义的函数展开讲解,但这些函数都是可以直接调用,封装在d2l库中的,此外也和没有封装的,临时定义的函数做出区分,如下:
def evaluate_loss(net, data_iter, loss): #@save
"""评估给定数据集上模型的损失"""
metric = d2l.Accumulator(2) # 损失的总和,样本数量
for X, y in data_iter:
out = net(X)
y = y.reshape(out.shape)
l = loss(out, y)
metric.add(l.sum(), l.numel())
return metric[0] / metric[1]
def train(train_features, test_features, train_labels, test_labels, num_epochs=400):
loss = nn.MSELoss(reduction='none')
input_shape = train_features.shape[-1] # 不设置偏置,因为我们已经在多项式中实现了它
net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
batch_size = min(10, train_labels.shape[0])
train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),
batch_size)
test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),
batch_size, is_train=False)
trainer = torch.optim.SGD(net.parameters(), lr=0.01)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
xlim=[1, num_epochs],
ylim=[1e-3, 1e2],
legend=['train', 'test'])
for epoch in range(num_epochs):
d2l.train_epoch_ch3(net, train_iter, loss, trainer)
if epoch == 0 or (epoch + 1) % 20 == 0:
animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
evaluate_loss(net, test_iter, loss)))
print('weight:', net[0].weight.data.numpy())
这是书中第四章的一部分代码,可以看到,evaluate_loss()后面有#@save,而train()后面没有,也就是说,evaluate_loss()是d2l的内置函数,而train()是为了教学临时定义的。
以上是笔者的个人见解,有不对之处还请大家指出。
如何设计一个漂亮的仪表盘—Jeecg仪表盘轻松实现【数据可视化专题】
EasyNLP开源|中文NLP+大模型落地,EasyNLP is all you need
写在21年初的后端社招面试经历(两年经验): 蚂蚁 头条 PingCAP
electron打包报错permission denied, rmdir '/tmp/electron-packager'
MindSpore报错 StridedSlice这个算子在Ascend硬件上不支持input是uint8的数据类型
HaaS轻应用(Python):基于HaaS-AI的文字识别
阿里熬一个月肝出这份Java面试手册 Zookeeper 面试篇
JavaSE实战——API(上) Eclipse使用、Object、Scanner、String、StringBuffer、StringBuilder、Integer、模拟用户登录案例