Published: 2023-04-02 12:00
This post is the second hands-on installment of the 《手把手实战教学!语义分割从0到1》 (Hands-On Semantic Segmentation from 0 to 1) series and covers model training. It first introduces some common models, then focuses on how to train a semantic segmentation model on your own dataset. The training code is based mainly on this open-source repository: GitHub - qq995431104/pytorch_segmentation: Semantic segmentation models, datasets and losses implemented in PyTorch.
For the overall introduction to this series and an index of the other installments, see the opening post: 手把手实战教学!语义分割从0到1:开篇 (AI数据工厂, CSDN).
Contents
1. Common semantic segmentation networks
2. Training your own semantic segmentation model
2.1 Data preparation
2.2 Code preparation
2.3 Modifying the configuration file
2.4 Dataset and DataLoader
2.5 Starting training
2.6 Monitoring training status
3. Preview of the next post
1. Common semantic segmentation networks
With FCN, semantic segmentation formally entered the deep learning era; since then, networks such as U-Net, SegNet, PSPNet, and the DeepLab series have appeared one after another. For a closer look at each of these networks, see my column: https://blog.csdn.net/oyezhou/category_10704356.html, which covers several semantic segmentation networks along with related background.
In this post we will put DeepLabV3+ into practice.
2. Training your own semantic segmentation model
2.1 Data preparation
The previous post in this series, 《手把手实战教学!语义分割从0到1:一、数据集制作》, explained how to create a semantic segmentation dataset. If you followed it step by step, you should by now have a set of annotated images converted into VOC format. Place the finished dataset in a directory of your choice; it should follow the layout sketched below.
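Based on the _set_files and _load_data code shown in section 2.4, the dataset root (the data_dir in the config) is expected to look roughly like this, where train.txt and val.txt list one image ID per line, without directory or extension (my_dataset is just a placeholder name):
my_dataset/
├── JPEGImages/          # input images, <image_id>.jpg
├── SegmentationClass/   # label masks, <image_id>.png
├── train.txt            # image IDs of the training split
└── val.txt              # image IDs of the validation split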
2.2 Code preparation
First, git clone the open-source library we will use: GitHub - qq995431104/pytorch_segmentation: Semantic segmentation models, datasets and losses implemented in PyTorch. Then install the dependencies it requires.
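Assuming the standard GitHub URL for that repository and the usual requirements.txt file name (check both against your copy), the setup commands are:
git clone https://github.com/qq995431104/pytorch_segmentation.git
cd pytorch_segmentation
pip install -r requirements.txt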
This codebase implements several segmentation networks, including FCN, U-Net, PSPNet, and DeepLabV3+, each of which can be selected through the configuration parameters.
2.3 Modifying the configuration file
To keep the training parameters in one place, the codebase provides a configuration file (pytorch_segmentation/config.json) where you can set the network backbone, the segmentation model, the dataset, the optimizer, the loss, and other hyperparameters.
Our dataset here is VOC-style and the model is DeepLabV3+; the full configuration is as follows:
{
    "name": "DeepLabv3_plus",
    "n_gpu": 1,
    "use_synch_bn": true,

    "arch": {
        "type": "DeepLab",
        "args": {
            "backbone": "resnet101",
            "freeze_bn": false,
            "freeze_backbone": false
        }
    },

    "train_loader": {
        "type": "MyVOC",
        "args": {
            "data_dir": "D:/dataset/my_dataset",
            "batch_size": 4,
            "base_size": 718,
            "crop_size": 718,
            "augment": true,
            "shuffle": true,
            "scale": true,
            "flip": true,
            "rotate": true,
            "blur": false,
            "split": "train",
            "num_workers": 0
        }
    },

    "val_loader": {
        "type": "MyVOC",
        "args": {
            "data_dir": "D:/dataset/my_dataset",
            "batch_size": 4,
            "crop_size": 718,
            "val": true,
            "split": "val",
            "num_workers": 0
        }
    },

    "optimizer": {
        "type": "SGD",
        "differential_lr": true,
        "args": {
            "lr": 0.005,
            "weight_decay": 1e-4,
            "momentum": 0.99
        }
    },

    "loss": "CrossEntropyLoss2d",
    "ignore_index": 255,
    "lr_scheduler": {
        "type": "Poly",
        "args": {}
    },

    "trainer": {
        "epochs": 120,
        "save_dir": "saved/",
        "save_period": 10,
        "monitor": "max Mean_IoU",
        "early_stop": 20,
        "tensorboard": true,
        "log_dir": "saved/runs",
        "log_per_iter": 10,
        "val": true,
        "val_per_epochs": 5
    }
}
You can use my configuration above as a reference to define your own training setup.
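One easy mistake when hand-editing config.json is breaking the JSON syntax (for example, leaving a trailing comma). A minimal sanity check you can run from the repository root before launching training:

import json

# json.load raises json.JSONDecodeError if config.json is malformed
with open("config.json") as f:
    cfg = json.load(f)

# Spot-check the fields you are most likely to have changed
print(cfg["train_loader"]["args"]["data_dir"])
print(cfg["arch"]["type"], cfg["trainer"]["epochs"])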
2.4 Dataset and DataLoader
The “pytorch_segmentation/dataloaders” directory defines DataLoaders for several common datasets, such as VOC and COCO. Since our dataset is in VOC format, we can start from “voc.py” and modify it. Here are my modified Dataset and DataLoader definitions:
# Originally written by Kazuto Nakashima
# https://github.com/kazuto1011/deeplab-pytorch
from base import BaseDataSet, BaseDataLoader
from utils import palette
import numpy as np
import os
from PIL import Image


class VOCDataset(BaseDataSet):
    """
    My VOC-like dataset.
    """
    def __init__(self, **kwargs):
        # Set this to the number of classes in your dataset (background included)
        self.num_classes = 2
        self.palette = palette.get_voc_palette(self.num_classes)
        super(VOCDataset, self).__init__(**kwargs)

    def _set_files(self):
        self.image_dir = os.path.join(self.root, 'JPEGImages')
        self.label_dir = os.path.join(self.root, 'SegmentationClass')
        # train.txt / val.txt hold one image ID per line
        file_list = os.path.join(self.root, self.split + ".txt")
        with open(file_list, "r") as f:
            self.files = [line.rstrip() for line in f]

    def _load_data(self, index):
        image_id = self.files[index]
        image_path = os.path.join(self.image_dir, image_id + '.jpg')
        label_path = os.path.join(self.label_dir, image_id + '.png')
        image = np.asarray(Image.open(image_path).convert('RGB'), dtype=np.float32)
        label = np.asarray(Image.open(label_path), dtype=np.int32)
        # Strip any directory prefix and extension from the ID
        image_id = self.files[index].split("/")[-1].split(".")[0]
        return image, label, image_id


class MyVOC(BaseDataLoader):
    def __init__(self, data_dir, batch_size, split, crop_size=None, base_size=None,
                 scale=True, num_workers=1, val=False, shuffle=False, flip=False,
                 rotate=False, blur=False, augment=False, val_split=None, return_id=False):
        # Updated 2021-01-29: per-channel mean/std computed on my own dataset
        self.MEAN = [0.4935838, 0.48873937, 0.45739236]
        self.STD = [0.22273207, 0.22567303, 0.22986929]

        kwargs = {
            'root': data_dir,
            'split': split,
            'mean': self.MEAN,
            'std': self.STD,
            'augment': augment,
            'crop_size': crop_size,
            'base_size': base_size,
            'scale': scale,
            'flip': flip,
            'blur': blur,
            'rotate': rotate,
            'return_id': return_id,
            'val': val
        }

        self.dataset = VOCDataset(**kwargs)
        super(MyVOC, self).__init__(self.dataset, batch_size, shuffle, num_workers, val_split)
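The MEAN and STD in MyVOC are per-channel statistics I computed over my own training images, so you should recompute them for your dataset. A minimal sketch, assuming the VOC-style layout from section 2.1 (the paths are placeholders for your own):

import os
import numpy as np
from PIL import Image

image_dir = "D:/dataset/my_dataset/JPEGImages"   # adjust to your dataset
with open("D:/dataset/my_dataset/train.txt") as f:
    ids = [line.rstrip() for line in f]

# Accumulate per-channel pixel sums and squared sums over all training images
pixel_sum = np.zeros(3)
pixel_sq_sum = np.zeros(3)
n_pixels = 0
for image_id in ids:
    path = os.path.join(image_dir, image_id + ".jpg")
    img = np.asarray(Image.open(path).convert("RGB"), dtype=np.float64) / 255.0
    pixels = img.reshape(-1, 3)
    pixel_sum += pixels.sum(axis=0)
    pixel_sq_sum += (pixels ** 2).sum(axis=0)
    n_pixels += pixels.shape[0]

mean = pixel_sum / n_pixels
std = np.sqrt(pixel_sq_sum / n_pixels - mean ** 2)   # Var[x] = E[x^2] - E[x]^2
print("MEAN:", mean, "STD:", std)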
Here, palette defines the display color of each class. I used the VOC definition, but you can also assign your own color to each class:
palette.py:
def get_voc_palette(num_classes):
    n = num_classes
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        # Spread the bits of the class index across the R/G/B channels,
        # reproducing the standard VOC colormap
        while lab > 0:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i = i + 1
            lab >>= 3
    return palette
ADE20K_palette = [0,0,0,120,120,120,180,120,120,6,230,230,80,50,50,4,200,
3,120,120,80,140,140,140,204,5,255,230,230,230,4,250,7,224,
5,255,235,255,7,150,5,61,120,120,70,8,255,51,255,6,82,143,
255,140,204,255,4,255,51,7,204,70,3,0,102,200,61,230,250,255,
6,51,11,102,255,255,7,71,255,9,224,9,7,230,220,220,220,255,9,
92,112,9,255,8,255,214,7,255,224,255,184,6,10,255,71,255,41,
10,7,255,255,224,255,8,102,8,255,255,61,6,255,194,7,255,122,8,
0,255,20,255,8,41,255,5,153,6,51,255,235,12,255,160,150,20,0,
163,255,140,140,140,250,10,15,20,255,0,31,255,0,255,31,0,255,224
,0,153,255,0,0,0,255,255,71,0,0,235,255,0,173,255,31,0,255,11,200,
200,255,82,0,0,255,245,0,61,255,0,255,112,0,255,133,255,0,0,255,
163,0,255,102,0,194,255,0,0,143,255,51,255,0,0,82,255,0,255,41,0,
255,173,10,0,255,173,255,0,0,255,153,255,92,0,255,0,255,255,0,245,
255,0,102,255,173,0,255,0,20,255,184,184,0,31,255,0,255,61,0,71,255,
255,0,204,0,255,194,0,255,82,0,10,255,0,112,255,51,0,255,0,194,255,0,
122,255,0,255,163,255,153,0,0,255,10,255,112,0,143,255,0,82,0,255,163,
255,0,255,235,0,8,184,170,133,0,255,0,255,92,184,0,255,255,0,31,0,184,
255,0,214,255,255,0,112,92,255,0,0,224,255,112,224,255,70,184,160,163,
0,255,153,0,255,71,255,0,255,0,163,255,204,0,255,0,143,0,255,235,133,255,
0,255,0,235,245,0,255,255,0,122,255,245,0,10,190,212,214,255,0,0,204,255,
20,0,255,255,255,0,0,153,255,0,41,255,0,255,204,41,0,255,41,255,0,173,0,
255,0,245,255,71,0,255,122,0,255,0,255,184,0,92,255,184,255,0,0,133,255,
255,214,0,25,194,194,102,255,0,92,0,255]
CityScpates_palette = [128,64,128,244,35,232,70,70,70,102,102,156,190,153,153,153,153,153,
250,170,30,220,220,0,107,142,35,152,251,152,70,130,180,220,20,60,255,0,0,0,0,142,
0,0,70,0,60,100,0,80,100,0,0,230,119,11,32,128,192,0,0,64,128,128,64,128,0,192,
128,128,192,128,64,64,0,192,64,0,64,192,0,192,192,0,64,64,128,192,64,128,64,192,
128,192,192,128,0,0,64,128,0,64,0,128,64,128,128,64,0,0,192,128,0,192,0,128,192,
128,128,192,64,0,64,192,0,64,64,128,64,192,128,64,64,0,192,192,0,192,64,128,192,
192,128,192,0,64,64,128,64,64,0,192,64,128,192,64,0,64,192,128,64,192,0,192,192,
128,192,192,64,64,64,192,64,64,64,192,64,192,192,64,64,64,192,192,64,192,64,192,
192,192,192,192,32,0,0,160,0,0,32,128,0,160,128,0,32,0,128,160,0,128,32,128,128,
160,128,128,96,0,0,224,0,0,96,128,0,224,128,0,96,0,128,224,0,128,96,128,128,224,
128,128,32,64,0,160,64,0,32,192,0,160,192,0,32,64,128,160,64,128,32,192,128,160,
192,128,96,64,0,224,64,0,96,192,0,224,192,0,96,64,128,224,64,128,96,192,128,224,
192,128,32,0,64,160,0,64,32,128,64,160,128,64,32,0,192,160,0,192,32,128,192,160,
128,192,96,0,64,224,0,64,96,128,64,224,128,64,96,0,192,224,0,192,96,128,192,224,
128,192,32,64,64,160,64,64,32,192,64,160,192,64,32,64,192,160,64,192,32,192,192,
160,192,192,96,64,64,224,64,64,96,192,64,224,192,64,96,64,192,224,64,192,96,192,
192,224,192,192,0,32,0,128,32,0,0,160,0,128,160,0,0,32,128,128,32,128,0,160,128,
128,160,128,64,32,0,192,32,0,64,160,0,192,160,0,64,32,128,192,32,128,64,160,128,
192,160,128,0,96,0,128,96,0,0,224,0,128,224,0,0,96,128,128,96,128,0,224,128,128,
224,128,64,96,0,192,96,0,64,224,0,192,224,0,64,96,128,192,96,128,64,224,128,192,
224,128,0,32,64,128,32,64,0,160,64,128,160,64,0,32,192,128,32,192,0,160,192,128,
160,192,64,32,64,192,32,64,64,160,64,192,160,64,64,32,192,192,32,192,64,160,192,
192,160,192,0,96,64,128,96,64,0,224,64,128,224,64,0,96,192,128,96,192,0,224,192,
128,224,192,64,96,64,192,96,64,64,224,64,192,224,64,64,96,192,192,96,192,64,224,
192,192,224,192,32,32,0,160,32,0,32,160,0,160,160,0,32,32,128,160,32,128,32,160,
128,160,160,128,96,32,0,224,32,0,96,160,0,224,160,0,96,32,128,224,32,128,96,160,
128,224,160,128,32,96,0,160,96,0,32,224,0,160,224,0,32,96,128,160,96,128,32,224,
128,160,224,128,96,96,0,224,96,0,96,224,0,224,224,0,96,96,128,224,96,128,96,224,
128,224,224,128,32,32,64,160,32,64,32,160,64,160,160,64,32,32,192,160,32,192,32,
160,192,160,160,192,96,32,64,224,32,64,96,160,64,224,160,64,96,32,192,224,32,192,
96,160,192,224,160,192,32,96,64,160,96,64,32,224,64,160,224,64,32,96,192,160,96,
192,32,224,192,160,224,192,96,96,64,224,96,64,96,224,64,224,224,64,96,96,192,224,
96,192,96,224,192,0,0,0]
COCO_palette = [31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227,
119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44,
214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207,
31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75,
227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44,
214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189,
34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75,
227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127,
14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189,
34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103,
189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127,
14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127
, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103,
189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14,
44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127,
127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189,
140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44,
160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190,
207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194,
127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148,
103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127,
14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34,
23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227,
119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39,
40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119,
180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127,
127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14]
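As a quick check that the palette behaves as expected, you can apply it to a synthetic mask with PIL's putpalette (a minimal sketch; run it from the repository root so that utils/palette.py is importable):

import numpy as np
from PIL import Image
from utils.palette import get_voc_palette

# Synthetic 2-class mask: class 1 in the center, background (class 0) elsewhere
mask = np.zeros((100, 100), dtype=np.uint8)
mask[25:75, 25:75] = 1

img = Image.fromarray(mask, mode='P')   # palettized (indexed-color) image
img.putpalette(get_voc_palette(2))      # map class indices to palette colors
img.save('mask_vis.png')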
2.5 Starting training
Once the steps above are complete, you can start training. Launch it from “train.py”:
python train.py --config config.json
Or simply run it in PyCharm. According to the trainer section of the config, checkpoints are written to saved/ every 10 epochs, and the best model is tracked by Mean_IoU.
2.6 Monitoring training status
During training, you can use TensorBoard to monitor progress (the config above enables it, with logs written to saved/runs):
tensorboard --logdir saved
Then open http://localhost:6006 in a browser to view the loss curves and metrics.
3. Preview of the next post
This post covered model training; to actually use the trained model we still need inference code. The repository does provide inference code, but our goal is to extract the inference part into a small, self-contained project that handles only inference and visualization. The next post, the final one in this series, will focus on pulling this inference pipeline out into such a lean project, ready for use in real applications.