发布时间:2023-04-05 09:00
转自AI Studio,原文链接:百度网盘AI大赛-图像处理挑战赛:文档检测优化赛 Baseline - 飞桨AI Studio
使用Resnet152回归图像中文档的拐角坐标完成百度网盘AI大赛-图像处理挑战赛:文档检测优化赛。
比赛链接
生活中人们使用手机进行文档扫描逐渐成为一件普遍的事情,为了提高人们的使用体验,我们期望通过算法技术去除杂乱的拍摄背景并精准框取文档边缘,选手需要通过深度学习技术训练模型,对给定的真实场景下采集得到的带有拍摄背景的文件图片进行边缘智能识别,并最终输出处理后的扫描结果图片。
本次比赛要求选手设计算法在给定图片中划定一块四边形区域,以尽可能与图片中的文档部分重合。
因此,本次任务可以同时看作回归问题和分割问题。
本项目将本次任务看作回归问题来处理,使用Resnet152+Linear的网络结构回归四个角的坐标。
In [ ]
! wget https://staticsns.cdn.bcebos.com/amis/2022-4/1649731549425/train_datasets_document_detection_0411.zip
! unzip -oq /home/aistudio/train_datasets_document_detection_0411.zip
! rm -rf __MACOSX
! rm -rf /home/aistudio/train_datasets_document_detection_0411.zip
通过paddle.io.dataset构造读取器,便于读取数据。
数据预处理包括:
In [ ]
import paddle
import numpy as np
import pandas as pd
import cv2
class MyDateset(paddle.io.Dataset):
def __init__(self, mode = \'train\', train_imgs_dir = \'/home/aistudio/train_datasets_document_detection_0411/images/\', train_txt = \'/home/aistudio/train_datasets_document_detection_0411/data_info.txt\'):
super(MyDateset, self).__init__()
self.mode = mode
self.train_imgs_dir = train_imgs_dir
with open(train_txt,\'r\') as f:
self.train_infor = f.readlines()
def __getitem__(self, index):
item = self.train_infor[index][:-1]
splited = item.split(\',\')
img_name = splited[0]
img = cv2.imread(self.train_imgs_dir+img_name+\'.jpg\')
h, w, c = img.shape
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# 对图片进行resize,调整明暗对比度等参数
img = paddle.vision.transforms.resize(img, (512,512), interpolation=\'bilinear\')
if np.random.rand()<1/3:
img = paddle.vision.transforms.adjust_brightness(img, np.random.rand()*2)
else:
if np.random.rand()<1/2:
img = paddle.vision.transforms.adjust_contrast(img, np.random.rand()*2)
else:
img = paddle.vision.transforms.adjust_hue(img, np.random.rand()-0.5)
img = img.transpose((2,0,1))
img = img/255
sites = []
for i in range(1,len(splited),2):
sites.append([float(splited[i])/w,float(splited[i+1])/h])
label = []
for i in range(4):
x, y = self.get_corner(sites, i+1)
label.append(x)
label.append(y)
img = paddle.to_tensor(img).astype(\'float32\')
label = paddle.to_tensor(label).astype(\'float32\')
return img, label
def get_corner(self, sites, corner_flag):
# corner_flag 1:top_left 2:top_right 3:bottom_right 4:bottom_left
if corner_flag == 1:
target_sites = [0,0]
elif corner_flag == 2 :
target_sites = [1,0]
elif corner_flag == 3 :
target_sites = [1,1]
elif corner_flag == 4 :
target_sites = [0,1]
min_dis = 3
best_x = 0
best_y = 0
for site in sites:
if abs(site[0]-target_sites[0])+abs(site[1]-target_sites[1])
In [ ]
class MyNet(paddle.nn.Layer):
def __init__(self):
super(MyNet,self).__init__()
self.resnet = paddle.vision.models.resnet152(pretrained=True, num_classes=0)
self.flatten = paddle.nn.Flatten()
self.linear = paddle.nn.Linear(2048, 8)
def forward(self, img):
y = self.resnet(img)
y = self.flatten(y)
out = self.linear(y)
return out
第一次训练后参数为0.66左右,重复训练+调整学习率可以达到0.89左右。
In [ ]
model = MyNet()
model.train()
train_dataset=MyDateset()
# 需要接续之前的模型重复训练可以取消注释
# param_dict = paddle.load(\'./model.pdparams\')
# model.load_dict(param_dict)
train_dataloader = paddle.io.DataLoader(
train_dataset,
batch_size=16,
shuffle=True,
drop_last=False)
max_epoch=10
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.0001, T_max=max_epoch)
opt = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
now_step=0
for epoch in range(max_epoch):
for step, data in enumerate(train_dataloader):
now_step+=1
img, label = data
pre = model(img)
loss = paddle.nn.functional.square_error_cost(pre,label).mean()
loss.backward()
opt.step()
opt.clear_gradients()
if now_step%100==0:
print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch, step, loss.mean().numpy()))
paddle.save(model.state_dict(), \'model.pdparams\')
本题目提交需要提交对应的模型和预测文件。predict.py需要读取同目录下的模型信息,并预测坐标点-保存为json或预测分割后的图片-保存为图片形式。
想要自定义训练模型,只需要将predict.py中的模型和process函数中的do something 替换为自己的模型内容即可。
提交分割模型时,取消predict中52行的注释部分即可保存分割后的图片信息
如果不想自己反复训练模型可以直接从fork后就有的model.pdparams文件开始训练,这个模型精度为0.88~
In [ ]
# 压缩可提交文件
! zip submit.zip model.pdparams predict.py
本项目使用极简的方式完成了百度网盘AI大赛-图像处理挑战赛:文档检测优化赛,但仍有改进的空间。比如:
最后,祝大家都能有好成绩!
请点击此处查看本环境基本用法.
Please click here for more detailed instructions.