发布时间:2022-12-01 22:00
数据集:
(train_image,train_label), (test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
多分类的损失函数使用多分类的交叉熵,对应categorical_crossentropy和sparse_categorical_crossentropy,前者用于标签为独热编码,后者用于label为自然数编码形式
总结:
flatten作为输入的扁平化层,使用add添加到模型
二维图像数据需要扁平化处理,训练集和测试集的 数据特征部分 均需要归一化
predict = model2.predict(test_image):模型预测参数输入test特征矩阵,返回所有样本的分类结果
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import numpy as np
%matplotlib inline
(train_image,train_label), (test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
plt.imshow(train_image[0])
train_label[0]
train_image.shape
train_image.shape
train_label.shape
# turn the train and test data to 0-1
train_image, test_image = train_image/255, test_image/255
# Dense take one dimention of input, so has to flatten the image shape
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy', metrics='acc')
model.fit(train_image, train_label, epochs=5)
model.evaluate(test_image,test_label)
# use categorical_crossentropy
train_label_onehot=tf.keras.utils.to_categorical(train_label)
test_label_onehot=tf.keras.utils.to_categorical(test_label)
model2 = tf.keras.Sequential()
model2.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model2.add(tf.keras.layers.Dense(128, activation='relu'))
model2.add(tf.keras.layers.Dense(10, activation='softmax'))
model2.compile(optimizer='adam',loss='categorical_crossentropy', metrics='acc')
model2.fit(train_image,train_label_onehot)
model2.evaluate(test_image, test_label_onehot)
predict = model2.predict(test_image)
predict[0]
np.argmax(predict[0])
test_label[0]
微信小程序:王者改名微信小程序源码下载另一版本支持流量主收益
OpenSSL SSL_read: Connection was reset, errno 10054
pytorch学习笔记(二)——加载数据Dataset以及Dataloader的使用
使用python采集某二手房源数据并做数据可视化展示(含完整源代码)
SpringBoot整合Redis使用@Cacheable和RedisTemplate
基于VS2019的C++人脸识别libfacedetection-master数据库的配置