发布时间:2023-03-13 19:00
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
import sys
from matplotlib import pyplot as plt
n_train, n_test, num_inputs = 20, 100, 200
#训练数据集越小,越容易过拟合。训练数据集为20,测试数据集为100,特征的纬度选择200.
#数据越小,模型越简单,过拟合越容易发生
true_w, true_b = torch.ones(num_inputs, 1) * 0.01, 0.05
#真实的权重就是0.01*全1的一个向量,偏差b为0.05
"""
读取一个人工的数据集
"""
features = torch.randn((n_train + n_test, num_inputs)) #特征
labels = torch.matmul(features, true_w) + true_b #样本数
labels += torch.tensor(np.random.normal(0, 0.01,size=labels.size()), dtype=torch.float)
train_features, test_features = features[:n_train, :],features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]
"""
定义训练和测试模型
"""
batch_size, num_epochs, lr = 1, 100, 0.003
net, loss = d2l.linreg, d2l.squared_loss
dataset = torch.utils.data.TensorDataset(train_features,train_labels)
train_iter = torch.utils.data.DataLoader(dataset, batch_size,shuffle=True)
def fit_and_plot_pytorch(wd):
#对权重参数衰减,权重名称一般是以weight结尾
net=nn.Linear(num_inputs,1)
nn.init.normal_(net.weight,mean=0,std=1)
nn.init.normal_(net.bias,mean=0,std=1)
optimizer_w=torch.optim.SGD(params=[net.weight],lr=lr,weight_decay=wd)
#对权重参数衰减
optimizer_b=torch.optim.SGD(params=[net.bias],lr=lr)
#不对偏差参数衰减
train_ls,test_ls=[],[]
for _ in range(num_epochs):
for X,y in train_iter:
l = loss(net(X), y).mean()
optimizer_w.zero_grad()
optimizer_b.zero_grad()
l.backward()
optimizer_w.step()
optimizer_b.step()
# 对两个optimizer实例分别调⽤step函数,从⽽分别更新权᯿和偏差
train_ls.append(loss(net(train_features),
train_labels).mean().item())
test_ls.append(loss(net(test_features),
test_labels).mean().item())
d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
range(1, num_epochs + 1), test_ls, ['train', 'test'])
print('L2 norm of w:', net.weight.data.norm().item())
fit_and_plot_pytorch(0)
plt.show()
fit_and_plot_pytorch(3)
plt.show()
详解JavaScript中if语句优化和部分语法糖小技巧推荐
k8s笔记14--初次体验 开源云原生软件交付平台zadig
npm : 无法将“npm”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。请检查名称的拼写,如果包括路径,请确保路径正确,然后再试一次。
java利用接口和抽象类改写求圆的面积和梯形的面积_Java接口和抽象类详解
原始jdbc连接数据读取字符串生成zip文件并读成byte数组
python初学者代码示例_python入门(非常详细的教程)
Wireshark之流量包分析+日志分析 (护网:蓝队)web安全 取证 分析黑客攻击流程(上篇)
python底层与机器底层关系_起底 Python 的底层逻辑
【论文阅读】RepVGG: Making VGG-style ConvNets Great Again(CVPR2021)