基于ResNet的猫十二分类

发布时间:2023-08-20 11:00

        在这次实战训练中,首先对下载的猫十二数据集进行预处理,然后使用TensorFlow构建ResNet模型;在学习率调度上使用1周期(1cycle)调度,优化器采用带动量和Nesterov加速梯度的SGD。

 

1.导包

from tensorflow import keras
import tensorflow as tf
from keras.preprocessing import image
import random
from matplotlib import pyplot as plt
import cv2
from tqdm import tqdm
import numpy as np
import math

2.数据预处理

cat_12数据集包含3个部分:训练集cat_12_train、测试集cat_12_test,以及存储图片名称及标签的train_list.txt

基于ResNet的猫十二分类_第1张图片

(1)定义prepare_image函数从文件中分离路径和标签

def prepare_image(file_path, image_root='./image/cat_12/'):
    """Read an annotation list and return shuffled image paths and labels.

    Every non-blank line of ``file_path`` is expected to contain a relative
    image path and a label separated by a tab character.

    Args:
        file_path: Path of the tab-separated annotation file.
        image_root: Prefix prepended to each relative image path.

    Returns:
        A ``(paths, labels)`` tuple of two equally long lists. ``labels``
        holds the label strings exactly as written in the file. The records
        are shuffled with ``random.shuffle`` before being split, so the two
        lists stay aligned with each other.
    """
    with open(file_path) as f:
        # Skip blank lines (for example a trailing empty line at the end of
        # the file); they have no label column and would raise IndexError.
        records = [line.rstrip('\n') for line in f if line.strip()]
    random.shuffle(records)

    X_train = []
    y_train = []
    for record in records:
        fields = record.split('\t')
        X_train.append(image_root + fields[0])
        y_train.append(fields[1])

    return X_train, y_train

(2)定义preprocess_image函数进行图像的归一化

def preprocess_image(img):
    """Load an image file as a 224x224 array scaled by 1/255.

    Args:
        img: Path of the image file to load.

    Returns:
        The array produced by ``image.img_to_array`` for the resized image,
        with every value divided by 255.0.
    """
    resized = image.load_img(img, target_size=(224, 224))
    pixels = image.img_to_array(resized)
    return pixels / 255.0

(3)定义plot_image函数显示图像

def plot_image(images,classes):
    """Show up to twelve images in a 4x3 grid, each captioned with its label.

    Args:
        images: Sequence of image file paths readable by ``cv2.imread``.
        classes: Sequence of labels; ``classes[i]`` captions ``images[i]``.

    Raises:
        FileNotFoundError: If one of the images cannot be read.
    """
    fig, axes = plt.subplots(4, 3, figsize=(60, 60), sharex=True)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)
    # zip() stops at the shortest input, so passing fewer than twelve images
    # leaves the remaining axes empty instead of raising IndexError.
    for ax, path, label in zip(axes.flat, images, classes):
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a clear message instead of inside cv2.resize.
        picture = cv2.imread(path)
        if picture is None:
            raise FileNotFoundError("cannot read image: {}".format(path))
        picture = cv2.resize(picture, (224, 224))
        # OpenCV decodes to BGR while matplotlib expects RGB. No cmap is
        # passed because colormaps are ignored for RGB input.
        ax.imshow(cv2.cvtColor(picture, cv2.COLOR_BGR2RGB))
        ax.set_xlabel("Breed:{}".format(label))
        ax.xaxis.label.set_size(60)
        ax.set_yticks([])
    plt.show()

3.展示数据集

# Load the shuffled image paths and their labels from the annotation list.
X_train, y_train = prepare_image('./image/cat_12/train_list.txt')
print(X_train)

基于ResNet的猫十二分类_第2张图片

# NumPy copies of the path/label lists so they can be filtered with boolean masks.
x_array = np.array(X_train)
y_array = np.array(y_train)
# Distinct label values. Labels are strings, so they sort lexicographically
# ('10' comes before '2').
y_unique = np.unique(y_array)
print(y_unique)
['0' '1' '10' '11' '2' '3' '4' '5' '6' '7' '8' '9']
# Pick one random sample image for every class and display the picks.
imgs = []
classes = []
for i in y_unique:
    # All image paths whose label equals the current class.
    sort = x_array[y_array==i]
    # randint's upper bound is exclusive, so len(sort) covers every index.
    # The previous len(sort)-1 could never select the last image and raised
    # ValueError for a class with a single image.
    idx = np.random.randint(len(sort))
    imgs.append(sort[idx])
    classes.append(i)
print(imgs)
plot_image(imgs,classes)

 

4.准备数据集

# Decode every training image (with a progress bar) and stack the results
# into a single array.
train_images = np.array([preprocess_image(path) for path in tqdm(X_train)])
# One-hot encode the labels over the 12 classes.
y_train = keras.utils.to_categorical(y_train, 12)
print(train_images.shape)

5. 基于tensorflow构建ResNet

# Build a convolution unit: Conv2D -> BatchNormalization -> ReLU
def conv_2d(x,filters,kernel_size,strides,padding="same"):
    """Apply a convolution, batch normalization and ReLU to ``x``.

    Args:
        x: Input tensor.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.
        padding: Padding mode passed to ``Conv2D`` (default ``"same"``).

    Returns:
        The activated output tensor.
    """
    convolution = keras.layers.Conv2D(
        filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
    )
    normalization = keras.layers.BatchNormalization()
    return keras.activations.relu(normalization(convolution(x)))
# Build a bottleneck residual block
def _conv_bn(x, filters, strides):
    """Apply a 1x1 convolution and batch normalization with no activation."""
    x = keras.layers.Conv2D(filters, kernel_size=1, strides=strides, padding="same")(x)
    return keras.layers.BatchNormalization()(x)


def resual_block(inputs,filters,strides):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 plus a projection shortcut.

    The main branch reduces to ``filters`` channels, applies a 3x3
    convolution, then expands to ``4 * filters`` channels. The shortcut is a
    1x1 projection of ``inputs`` to the same shape. The two are added and a
    single ReLU follows the addition.

    Args:
        inputs: Input tensor.
        filters: Channel count of the bottleneck; the output has
            ``4 * filters`` channels.
        strides: Stride of the first main-branch convolution and of the
            shortcut projection (2 downsamples, 1 keeps the resolution).

    Returns:
        The output tensor of the block.
    """
    x = conv_2d(inputs, filters=filters, kernel_size=1, strides=strides)
    x = conv_2d(x, filters=filters, kernel_size=3, strides=1)
    # The tensors fed into Add must not be ReLU-activated. If both were, the
    # residual could only be non-negative and the ReLU after the addition
    # would be a no-op.
    x = _conv_bn(x, 4 * filters, 1)

    x_short = _conv_bn(inputs, 4 * filters, strides)

    x = keras.layers.Add()([x, x_short])
    x = keras.activations.relu(x)
    return x
# Build the ResNet model (ResNet-152 layout by default)
def resnet(input_shape,n_classes=1000,blocks_per_stage=(3,8,36,3)):
    """Build a bottleneck ResNet classifier.

    The stem is a 7x7 stride-2 convolution followed by a 3x3 stride-2 max
    pooling. Each stage stacks ``resual_block`` units; the bottleneck width
    starts at 64 and doubles from one stage to the next. Global average
    pooling and a softmax dense layer form the head.

    Args:
        input_shape: Shape of one input image, e.g. ``[224, 224, 3]``.
        n_classes: Number of output classes.
        blocks_per_stage: Number of residual blocks in each stage. The
            default ``(3, 8, 36, 3)`` is the ResNet-152 layout.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    x_input = keras.layers.Input(shape=input_shape)
    x = conv_2d(x_input, filters=64, kernel_size=7, strides=2)
    x = keras.layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding="same")(x)

    filters = 64
    for stage, n_blocks in enumerate(blocks_per_stage):
        for block in range(n_blocks):
            # The first block of every stage except the first halves the
            # spatial resolution; all other blocks keep it.
            strides = 2 if stage > 0 and block == 0 else 1
            x = resual_block(x, filters, strides=strides)
        filters *= 2

    # Global average pooling followed by the classification layer
    x = keras.layers.GlobalAveragePooling2D()(x)
    output = keras.layers.Dense(n_classes, activation="softmax")(x)
    model = keras.models.Model(inputs=[x_input], outputs=[output])
    return model
# Instantiate the network for 224x224 three-channel inputs and 12 output classes.
model = resnet([224,224,3],12)

6. 1周期(1cycle)调度

k = keras.backend
class One_Cycle(keras.callbacks.Callback):
    """1cycle learning-rate schedule, updated before every batch.

    The rate rises linearly from ``start_rate`` to ``max_rate`` during the
    first ``half_interations`` batches, falls linearly back to ``start_rate``
    during the next ``half_interations`` batches, and is then annealed
    linearly to ``min_rate`` over the remaining batches. After
    ``interations`` batches the rate is held at ``min_rate``.

    The rate, loss and batch counter seen at each batch end are recorded in
    ``learning_rate``, ``loss`` and ``numbers`` for later plotting.

    Args:
        interations: Total number of training batches the schedule spans.
        max_rate: Peak learning rate.
        min_rate: Final learning rate; defaults to ``max_rate / 10000`` when
            omitted (or falsy).
        start_rate: Initial learning rate; defaults to ``max_rate / 100``
            when omitted (or falsy).
        last_interations: Number of batches in the final annealing phase;
            defaults to ``interations // 10 + 1`` when omitted (or falsy).
    """

    def __init__(self,interations,max_rate,min_rate=None,start_rate=None,last_interations=None):
        # Let the Keras base class initialise its own state first.
        super().__init__()
        self.interations = interations
        self.max_rate = max_rate
        self.min_rate = min_rate or self.max_rate/10000
        self.start_rate = start_rate or self.max_rate/100
        self.last_interations = last_interations or self.interations//10+1
        self.half_interations = (self.interations - self.last_interations)//2
        self.interation = 0
        self.loss = []
        self.learning_rate = []
        self.numbers = []

    def _interpolate(self,iter1,iter2,rate1,rate2):
        """Linearly interpolate between two rates at the current batch counter."""
        return ((rate2-rate1)*(self.interation-iter1)/(iter2-iter1)+rate1)

    def on_batch_begin(self,batch,logs=None):
        """Compute the rate for the upcoming batch and set it on the optimizer."""
        if self.interation < self.half_interations:
            rate = self._interpolate(0,self.half_interations,self.start_rate,self.max_rate)
        elif self.interation < 2*self.half_interations:
            rate = self._interpolate(self.half_interations,2*self.half_interations,self.max_rate,self.start_rate)
        elif self.interation < self.interations:
            # The descent phase ends at start_rate, so annealing must begin
            # there too. Starting from max_rate made the rate jump back to
            # its peak.
            rate = self._interpolate(2*self.half_interations,self.interations,self.start_rate,self.min_rate)
        else:
            # Training ran longer than planned: hold the final rate instead
            # of extrapolating past it (which could go negative).
            rate = self.min_rate
        self.interation += 1
        k.set_value(self.model.optimizer.learning_rate,rate)

    def on_batch_end(self,batch,logs=None):
        """Record the learning rate, loss and batch counter of the finished batch."""
        self.learning_rate.append(k.get_value(self.model.optimizer.learning_rate))
        # logs defaults to None in the signature; tolerate that.
        self.loss.append((logs or {}).get("loss"))
        self.numbers.append(self.interation)

7.训练模型

n_epochs = 100
batch_size = 16
validation_split = 0.2
# fit() holds out the validation_split fraction of the data, so only the rest
# produces training batches. The schedule length has to match the number of
# batches actually run; counting all of X_train over-plans it and the final
# annealing phase is never reached.
n_train_samples = math.floor(len(X_train) * (1.0 - validation_split))
steps_per_epoch = math.ceil(n_train_samples / batch_size)
one_cycle = One_Cycle(steps_per_epoch*n_epochs,0.1,min_rate=1e-5)
# SGD with momentum and Nesterov acceleration; the callback overrides the
# initial learning rate before every batch.
optimizer = keras.optimizers.SGD(learning_rate=0.001,momentum=0.9,nesterov=True)
model.compile(loss="categorical_crossentropy",optimizer=optimizer,metrics=["accuracy"])
history = model.fit(train_images,y_train,epochs=n_epochs,validation_split=validation_split,batch_size=batch_size,callbacks=[one_cycle])

8.绘制损失变化情况(调用示例:show_training_history(history.history, 'loss', 'val_loss'))

def show_training_history(train_history, train, val):
    """Plot a training metric and its validation counterpart per epoch.

    Args:
        train_history: Mapping from metric name to a per-epoch list of
            values, or an object exposing such a mapping as ``.history``
            (for example the value returned by ``model.fit``).
        train: Key of the training metric, e.g. ``'loss'``.
        val: Key of the validation metric, e.g. ``'val_loss'``.
    """
    # Accept both the metrics dict and an object that carries it as .history.
    metrics = getattr(train_history, 'history', train_history)
    plt.plot(metrics[train], linestyle='-', color='b')
    plt.plot(metrics[val], linestyle='--', color='r')
    plt.xlabel('Epoch', fontsize=12)
    # Label the axis with the plotted metric, not the literal word "train";
    # the validation curve shares this axis.
    plt.ylabel(train, fontsize=12)
    plt.legend(['train', 'validation'], loc='lower right')
    plt.show()

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号