PyTorch实现手写数字的识别入门小白教程

发布时间:2025-02-21 19:01

目录
  • 手写数字识别(小白入门)
    • 1.数据预处理
    • 2.训练模型
    • 3.测试模型,保存
    • 4.调用模型
    • 5.完整代码

手写数字识别(小白入门)

今早刚刚上了节实验课,关于逻辑回归,所以手有点刺挠就想发个博客,作为刚刚入门的小白,看到代码运行成功就有点小激动,这个实验没啥含金量,所以路过的大牛不要停留,我怕你们吐槽哈哈。

实验结果:

\"PyTorch实现手写数字的识别入门小白教程_第1张图片\"

\"PyTorch实现手写数字的识别入门小白教程_第2张图片\"

 

 

 

1.数据预处理

其实呢,原理很简单,就是使用多变量逻辑回归,将训练28*28图片的灰度值转换成一维矩阵,这就变成了求784个特征向量1个标签的逻辑回归问题。代码如下:

#数据预处理
trainData = np.loadtxt(open(\'digits_training.csv\', \'r\'), delimiter=\",\",skiprows=1)#装载数据
MTrain, NTrain = np.shape(trainData)  #行列数
print(\"训练集:\",MTrain,NTrain)
xTrain = trainData[:,1:NTrain]
xTrain_col_avg = np.mean(xTrain, axis=0) #对各列求均值
xTrain =(xTrain- xTrain_col_avg)/255  #归一化
yTrain = trainData[:,0]

2.训练模型

对于数学差的一批的我来说,学习算法真的是太太太扎心了,好在具体算法封装在了sklearn库中。简单两行代码即可完成。具体参数的含义随随便便一搜到处都是,我就不班门弄斧了,每次看见算法除了头晕啥感觉没有。

model = LogisticRegression(solver=\'lbfgs\', multi_class=\'multinomial\', max_iter=500)
model.fit(xTrain, yTrain)

3.测试模型,保存

接下来测试一下模型,准确率能达到百分之90,也不算太高,训练数据集本来也不是很多。
为了方便,所以把模型保存下来,不至于运行一次就得训练一次。

#测试模型
testData = np.loadtxt(open(\'digits_testing.csv\', \'r\'), delimiter=\",\",skiprows=1)
MTest,NTest = np.shape(testData)
print(\"测试集:\",MTest,NTest)
xTest = testData[:,1:NTest]
xTest = (xTest-xTrain_col_avg) /255   # 使用训练数据的列均值进行处理
yTest = testData[:,0]
yPredict = model.predict(xTest)
errors = np.count_nonzero(yTest - yPredict) #返回非零项个数
print(\"预测完毕。错误:\", errors, \"条\")
print(\"测试数据正确率:\", (MTest - errors) / MTest)

\'\'\'=================================\'\'\'
#保存模型

# 创建文件目录
dirs = \'testModel\'
if not os.path.exists(dirs):
    os.makedirs(dirs)
joblib.dump(model, dirs+\'/model.pkl\')
print(\"模型已保存\")

https://download.csdn.net/download/qq_45874897/12427896 需要的可以自行下载

4.调用模型

既然模型训练好了,就来放几张图片调用模型试一下看看怎么样
导入要测试的图片,然后更改大小为28*28,将图片二值化减小误差。
为了让结果看起来有逼格,所以最后把图片和识别数字同实显示出来。

import  cv2
import numpy as np
from sklearn.externals import joblib

map=cv2.imread(r\"C:\\Users\\lenovo\\Desktop\\[DX6@[C$%@2RS0R2KPE[W@V.png\")
GrayImage = cv2.cvtColor(map, cv2.COLOR_BGR2GRAY)
ret,thresh2=cv2.threshold(GrayImage,127,255,cv2.THRESH_BINARY_INV)
Image=cv2.resize(thresh2,(28,28))
img_array = np.asarray(Image)
z=img_array.reshape(1,-1)

\'\'\'================================================\'\'\'

model = joblib.load(\'testModel\'+\'/model.pkl\')
yPredict = model.predict(z)
print(yPredict)
y=str(yPredict)
cv2.putText(map,y, (10,20), cv2.FONT_HERSHEY_SIMPLEX,0.7,(0,0,255), 2, cv2.LINE_AA)
cv2.imshow(\"map\",map)
cv2.waitKey(0)

5.完整代码

test1.py

import numpy as np
from sklearn.linear_model import LogisticRegression
import os
from sklearn.externals import joblib

#数据预处理
trainData = np.loadtxt(open(\'digits_training.csv\', \'r\'), delimiter=\",\",skiprows=1)#装载数据
MTrain, NTrain = np.shape(trainData)  #行列数
print(\"训练集:\",MTrain,NTrain)
xTrain = trainData[:,1:NTrain]
xTrain_col_avg = np.mean(xTrain, axis=0) #对各列求均值
xTrain =(xTrain- xTrain_col_avg)/255  #归一化
yTrain = trainData[:,0]

\'\'\'=================================\'\'\'
#训练模型
model = LogisticRegression(solver=\'lbfgs\', multi_class=\'multinomial\', max_iter=500)
model.fit(xTrain, yTrain)
print(\"训练完毕\")

\'\'\'=================================\'\'\'
#测试模型
testData = np.loadtxt(open(\'digits_testing.csv\', \'r\'), delimiter=\",\",skiprows=1)
MTest,NTest = np.shape(testData)
print(\"测试集:\",MTest,NTest)
xTest = testData[:,1:NTest]
xTest = (xTest-xTrain_col_avg) /255   # 使用训练数据的列均值进行处理
yTest = testData[:,0]
yPredict = model.predict(xTest)
errors = np.count_nonzero(yTest - yPredict) #返回非零项个数
print(\"预测完毕。错误:\", errors, \"条\")
print(\"测试数据正确率:\", (MTest - errors) / MTest)

\'\'\'=================================\'\'\'
#保存模型

# 创建文件目录
dirs = \'testModel\'
if not os.path.exists(dirs):
    os.makedirs(dirs)
joblib.dump(model, dirs+\'/model.pkl\')
print(\"模型已保存\")

运行结果

\"PyTorch实现手写数字的识别入门小白教程_第3张图片\"

test2.py

import  cv2
import numpy as np
from sklearn.externals import joblib

map=cv2.imread(r\"C:\\Users\\lenovo\\Desktop\\[DX6@[C$%@2RS0R2KPE[W@V.png\")
GrayImage = cv2.cvtColor(map, cv2.COLOR_BGR2GRAY)
ret,thresh2=cv2.threshold(GrayImage,127,255,cv2.THRESH_BINARY_INV)
Image=cv2.resize(thresh2,(28,28))
img_array = np.asarray(Image)
z=img_array.reshape(1,-1)

\'\'\'================================================\'\'\'

model = joblib.load(\'testModel\'+\'/model.pkl\')
yPredict = model.predict(z)
print(yPredict)
y=str(yPredict)
cv2.putText(map,y, (10,20), cv2.FONT_HERSHEY_SIMPLEX,0.7,(0,0,255), 2, cv2.LINE_AA)
cv2.imshow(\"map\",map)
cv2.waitKey(0)

提供几张样本用来测试:

\"PyTorch实现手写数字的识别入门小白教程_第4张图片\"

\"PyTorch实现手写数字的识别入门小白教程_第5张图片\"

\"PyTorch实现手写数字的识别入门小白教程_第6张图片\"

\"PyTorch实现手写数字的识别入门小白教程_第7张图片\"

\"PyTorch实现手写数字的识别入门小白教程_第8张图片\"

\"PyTorch实现手写数字的识别入门小白教程_第9张图片\"

实验中还有很多地方需要优化,比如数据集太少,泛化能力太差,用样本的数据测试正确率挺高,但是用我自己手写的字正确率就太低了,可能我字写的太丑,哎,还是自己太菜了,以后得多学学算法了。

到此这篇关于PyTorch实现手写数字的识别入门小白教程的文章就介绍到这了,更多相关PyTorch手写数字识别内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

你可能感兴趣的

相关推荐

前端如何做单元测试? 看这篇就入门了

俩万搭建安装SpringBoot+VUE【视频+文档+源码】

人脸识别精度提升 | 基于Transformer的人脸识别(附源码)

教你一文解决 js 数字精度丢失问题

C语言常见排序算法之插入排序(直接插入排序,希尔排序)

效率低?响应慢?报表工具痛点及其解决方案

svd在matlab中的使用,matlab - 使用SVD在MATLAB中压缩图像 - 堆栈内存溢出

void* data 数据类型参数以及void *data[ ]解释

成功解决error: Microsoft Visual C++ 14.0 or greater is required. Get it with “Microsoft C++ Build Tools“

Python语言之面向对象

React中编写CSS的常见方案

STM32控制TFTLCD显示屏(理论)

润乾报表 dql分析模块报表实现隔行异色效果

Spring MVC 统一异常处理总结

JavaScript前端实现小说分页功能示例

剖析CocosCreator新资源管理系统

MATLAB安装产品选择,如何选择需要安装的产品

java && 类似%E4%B8%AD%E5%9B%BD这种字符转换问题

Unity 学习笔记 版本 4.6.8

关于pyqtSignal的基本使用

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号