发布时间:2022-11-16 22:30
OS:Ubuntu 18.04
Environment: PyTorch&Anaconda3
Editor: Spyder
代码及数据集来自官方,注释部分为我个人学习笔记之用
数据集及源代码下载地址https://github.com/zergtant/pytorch-handbook/blob/master/chapter3/german.data-numeric
单独下载单个文件办法请参阅GitHub如何在下载单个文件 「一键式解决方案」
代码中有一段归一化代码,参考数据归一化常用的两种方法
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 28 17:55:41 2020
"""
import torch
import torch.nn as nn
import numpy as np
data=np.loadtxt("german.data-numeric")
#归一化处理
n,l=data.shape
for j in range(l-1):
meanVal=np.mean(data[:,j])
stdVal=np.std(data[:,j])
data[:,j]=(data[:,j]-meanVal)/stdVal
np.random.shuffle(data)#打乱数据
train_data=data[:900,:l-1]
train_lab=data[:900,l-1]-1
test_data=data[900:,:l-1]
test_lab=data[900:,l-1]-1
#建立神经网络
class LR(nn.Module):
def __init__(self):
super(LR,self).__init__()
self.fc=nn.Linear(24,2) # 由于24个维度已经固定了,所以这里写24
def forward(self,x):
out=self.fc(x)
out=torch.sigmoid(out)#关键步骤 进行非线性的激活
return out
#建立测试函数,用于计算正确率
def test(pred,lab):
t=pred.max(-1)[1]==lab
#这个语法很有意思,居然可以这么解释
return torch.mean(t.float())
net=LR()
criterion=nn.CrossEntropyLoss() # 使用CrossEntropyLoss损失
optm=torch.optim.Adam(net.parameters()) # Adam优化
epochs=1000 # 训练1000次
for i in range(epochs):
net.train() # 指定模型为训练模式,计算梯度
# 输入值都需要转化成torch的Tensor
x=torch.from_numpy(train_data).float()
y=torch.from_numpy(train_lab).long()
y_hat=net(x)
loss=criterion(y_hat,y) # 计算损失
optm.zero_grad() # 前一步的损失清零
loss.backward() # 反向传播
optm.step() # 优化
if (i+1)%100 ==0 : # 这里我们每100次输出相关的信息
# 指定模型为计算模式
net.eval()
test_in=torch.from_numpy(test_data).float()
test_l=torch.from_numpy(test_lab).long()
test_out=net(test_in)
# 使用我们的测试函数计算准确率
accu=test(test_out,test_l)
print("Epoch:{},Loss:{:.4f},Accuracy:{:.2f}".format(i+1,loss.item(),accu))
Epoch:100,Loss:0.6412,Accuracy:0.61
Epoch:200,Loss:0.6155,Accuracy:0.71
Epoch:300,Loss:0.5985,Accuracy:0.72
Epoch:400,Loss:0.5863,Accuracy:0.73
Epoch:500,Loss:0.5771,Accuracy:0.75
Epoch:600,Loss:0.5698,Accuracy:0.73
Epoch:700,Loss:0.5638,Accuracy:0.73
Epoch:800,Loss:0.5589,Accuracy:0.73
Epoch:900,Loss:0.5546,Accuracy:0.73
Epoch:1000,Loss:0.5510,Accuracy:0.73