PyTorch官方手册中logistic-regression的Demo

发布时间:2022-11-16 22:30

一、环境

OS:Ubuntu 18.04
Environment: PyTorch&Anaconda3
Editor: Spyder

二、代码部分

代码及数据集来自官方,注释部分为我个人学习笔记之用
数据集及源代码下载地址https://github.com/zergtant/pytorch-handbook/blob/master/chapter3/german.data-numeric
单独下载单个文件办法请参阅GitHub如何在下载单个文件 「一键式解决方案」
代码中有一段归一化代码,参考数据归一化常用的两种方法

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 28 17:55:41 2020
"""
import torch
import torch.nn as nn
import numpy as np

data=np.loadtxt("german.data-numeric")

#归一化处理
n,l=data.shape
for j in range(l-1):
    meanVal=np.mean(data[:,j])
    stdVal=np.std(data[:,j])
    data[:,j]=(data[:,j]-meanVal)/stdVal

np.random.shuffle(data)#打乱数据

train_data=data[:900,:l-1]
train_lab=data[:900,l-1]-1
test_data=data[900:,:l-1]
test_lab=data[900:,l-1]-1

#建立神经网络
class LR(nn.Module):
    def __init__(self):
        super(LR,self).__init__()
        self.fc=nn.Linear(24,2) # 由于24个维度已经固定了,所以这里写24
    def forward(self,x):
        out=self.fc(x)
        out=torch.sigmoid(out)#关键步骤 进行非线性的激活
        return out
#建立测试函数,用于计算正确率
def test(pred,lab):
    t=pred.max(-1)[1]==lab
    #这个语法很有意思,居然可以这么解释
    return torch.mean(t.float())



net=LR() 
criterion=nn.CrossEntropyLoss() # 使用CrossEntropyLoss损失
optm=torch.optim.Adam(net.parameters()) # Adam优化
epochs=1000 # 训练1000次


for i in range(epochs):
    net.train() # 指定模型为训练模式,计算梯度
    # 输入值都需要转化成torch的Tensor
    x=torch.from_numpy(train_data).float()
    y=torch.from_numpy(train_lab).long()
    y_hat=net(x)
    loss=criterion(y_hat,y) # 计算损失
    optm.zero_grad() # 前一步的损失清零
    loss.backward() # 反向传播
    optm.step() # 优化
    if (i+1)%100 ==0 : # 这里我们每100次输出相关的信息
        # 指定模型为计算模式
        net.eval()
        test_in=torch.from_numpy(test_data).float()
        test_l=torch.from_numpy(test_lab).long()
        test_out=net(test_in)
        # 使用我们的测试函数计算准确率
        accu=test(test_out,test_l)
        print("Epoch:{},Loss:{:.4f},Accuracy:{:.2f}".format(i+1,loss.item(),accu))


三、示例

Epoch:100,Loss:0.6412,Accuracy:0.61
Epoch:200,Loss:0.6155,Accuracy:0.71
Epoch:300,Loss:0.5985,Accuracy:0.72
Epoch:400,Loss:0.5863,Accuracy:0.73
Epoch:500,Loss:0.5771,Accuracy:0.75
Epoch:600,Loss:0.5698,Accuracy:0.73
Epoch:700,Loss:0.5638,Accuracy:0.73
Epoch:800,Loss:0.5589,Accuracy:0.73
Epoch:900,Loss:0.5546,Accuracy:0.73
Epoch:1000,Loss:0.5510,Accuracy:0.73

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号