Faster rcnn 训练 安全帽数据集搭载服务器运行

发布时间:2024-11-24 16:01

文章目录

  • 前言
  • 一、faster_rcnn
  • 二、数据处理
    • 1.坐标跟标签
    • 2.模型所需要的格式
    • 3.读取图像进行处理
    • 4.加载torchvision模型
    • 5.优化器
    • 6.获取xml文件列表
    • 7.训练网络
    • 8.保存模型
  • 三.服务器运行
  • 总结


前言

在目标检测领域,有两种方式,一种是two_stage 比如faster_rcnn mask_rcnn 还有一种是one_stage 比如 yolo 这两种的优缺点很容易看出来one_stage 速度非常快,适合做实时检测,但是精度不是很高,two_stage速度慢,效果好,本文使用torchvision中的 faster rcnn 训练 安全帽数据集。


一、faster_rcnn

\"Faster
这是faster_rcnn网络结构,在这里简单说一下,输入图像经过主干网络提取特征,经过区域建议网络(RPN)得到候选框,候选框经过(ROI)resize到一个固定的尺寸,最后经过fc做分类与回归

二、数据处理

1.坐标跟标签

我们通过labelimg对图片就行标注,得到xml文件,我们要对xml文件进行处理,得到我们想要的结果

def parse_objects(xml_root):
    objects=xml_root.findall(\'object\')

    size=xml_root.find(\'size\')
    height=size.find(\'height\').text
    width=size.find(\'width\').text

    bndboxes=[]
    labels=[]

    for object in objects:
        bndbox=object.find(\'bndbox\')
        if object.find(\'name\').text==\'hat\':
            labels.append(1)
        else :
            labels.append(0)
        x_min=bndbox.find(\'xmin\').text
        y_min = bndbox.find(\'ymin\').text
        x_max = bndbox.find(\'xmax\').text
        y_max = bndbox.find(\'ymax\').text

        bndboxes.append((float(x_min)/float(width),float(y_min)/float(height),
                         float(x_max)/float(width),float(y_max)/float(height)))
    return bndboxes,labels

2.模型所需要的格式

我们得到的坐标跟标签不能直接输入到模型进行训练,要进行一些处理,首先存入到字典,最后存入到列表。

def parse_batch_xmls(batch_xml_files):
    targets=[]
    for xml_file in batch_xml_files:
        file_bndbox_info={}

        xmlparse=ET.parse(xml_file)
        xml_root=xmlparse.getroot()
        bnd_boxes,labels=parse_objects(xml_root)
        file_bndbox_info[\'boxes\']=torch.tensor(bnd_boxes,dtype=torch.float32,device=gpu)
        file_bndbox_info[\'labels\']=torch.tensor(labels,dtype=torch.int64,device=gpu)
        targets.append(file_bndbox_info)
    return targets

3.读取图像进行处理

这里到注意的是opencv不能直接读取中文路径下的图片,可以PIL进行读取,转换一下就可以了。

def read_batch_images(batch_files):
    batch_images=[]
    for file in batch_files:
        image=Image.open(file)
        image=image.convert(\'RGB\')
        image=image.resize((512,512))
        image=np.transpose(image,(2,0,1)).astype(np.float32)/255.
        image=torch.tensor(image,device=gpu)
        batch_images.append(image)

    return batch_images
def conver_xml2jpg_file(jpg_dir,batch_xml_files):
    batch_files=[]
    for xml_file in batch_xml_files:
        basename=os.path.basename(xml_file)

        basename_jpg=basename.replace(\'.xml\',\'.jpg\')
        basename_png=basename.replace(\'.xml\',\'.png\')
        jpg_path=os.path.join(jpg_dir,basename_jpg)
        png_path=os.path.join(jpg_dir,basename_png)

        img_path=jpg_path if os.path.exists(jpg_path) else png_path
        assert  os.path.exists(img_path),\"{0} not exist\".format(img_path)
        batch_files.append(img_path)
    return batch_files

4.加载torchvision模型

model=torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False).to(device=gpu)

5.优化器

opt=torch.optim.Adam(model.parameters(),lr=1e-4,weight_decay=0.001)

6.获取xml文件列表

xmls_list=glob.glob(os.path.join(xml_path,\'*.xml\'))

7.训练网络

for epoch in range(epoches):

    for index in range(0,len(xmls_list),batchsize):

        batch_xml_files=xmls_list[index:index+batchsize]
        targets=parse_batch_xmls(batch_xml_files)
        batch_xml_files=conver_xml2jpg_file(jpg_path,batch_xml_files)
        images=read_batch_images(batch_xml_files)
        opt.zero_grad()
        torch.cuda.empty_cache()
        output=model(images,targets)
        loss=0
        for value in output.values():
            loss+=value
        loss.backward()
        opt.step()
        print(\'loss\',loss.item())

8.保存模型

torch.save(model.state_dict(),\'faster_RCNN_train.pth\')

到这里代码部分就全部结束了

三.服务器运行

1.我们的电脑无法支持海量数据的深度学习的训练,所以这里要使用到服务器,本文使用的服务器是矩池云,附链接:https://www.matpay.net/
2.我们需要配置服务器环境,这个平台有一个好处就是有配置好的深度学习环境,可以直接进行使用,可以说很方便。
\"Faster
3.pycharm连接服务器
选择ssh解释器,输入主机名,端口,用户名
\"Faster
下一步,输入密码
之后配置服务器环境路径
下图,是配置好的结果
\"请添加图片描述\"
到这里,服务器就配置完成了。
代码会自动上传,这里要注意更改代码数据集路径
下图是服务器数据集路径
\"Faster
最后直接运行就可以了。
\"Faster
可以查看服务器的使用情况

\"Faster

总结

本文只是简单介绍了通过torchvision加载faster_rcnn 模型进行自己数据集的训练,然后搭载服务器上运行,得到我们想要的模型,进行测试。觉得本文对自己有帮助的,可以三连支持一波!

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号