Published: 2023-11-25 10:00
Raw data generally consists of images plus annotations. Annotations currently come in two common formats: VOC, which stores the labels in .xml files, and COCO, which stores them in JSON. This post covers converting VOC format to YOLO format; scripts of this kind are plentiful. It involves three parts: splitting the dataset, converting the annotations, and loading the data, which together get everything ready for the next step, network training.
The script below splits a VOC-format dataset into training, validation, and test sets; see the comments for the details.
import os
import random
from tqdm import tqdm
"""
作者:小小博
时间:2022.3.21
环境:打完游戏的夜晚
本脚本有功能:
1.对VOC格式的数据集进行训练集、验证集、测试集划分
"""
# -------------------------------------------------------#
# voc_annotation.py
# Splits a VOC-format dataset into training, validation, and test sets
# Expected directory structure:
# |-VOCdevkit
#     |-VOC2007
#         |-Annotations
#         |-ImageSets
#             |-Main
#                 |-test.txt
#                 |-train.txt
#                 |-trainval.txt
#                 |-val.txt
#         |-JPEGImages
# -------------------------------------------------------#
# -------------------------------------------------------#
# classes_path       txt file listing the VOC class names (20 classes)
# trainval_percent   fraction of the data used for train + val
# train_percent      fraction of train + val used for training
# Example: with 1000 images and trainval_percent = 0.9,
# train + val = 900 images and test = 100 images;
# with train_percent = 0.9, 810 of those 900 go to training
# and 90 to validation (810 + 90 = 900).
# VOCdevkit_path     root directory of the VOC dataset
# -------------------------------------------------------#
classes_path = 'model_data/voc_classes.txt'
trainval_percent = 0.9
train_percent = 0.9
VOCdevkit_path = 'VOCdevkit'
if __name__ == "__main__":
    # -------------------------------------------------------#
    # random.seed(0) fixes the random seed so that the generated
    # train/val/test split is identical on every run
    # -------------------------------------------------------#
random.seed(0)
    # -------------------------------------------------------#
    # xml_path    directory of the .xml annotation files
    # save_path   directory where the split .txt lists are written
    # temp_xml    every file name under Annotations, e.g. ['000001.xml', '000002.xml', ..., '001000.xml']
    #             (if Annotations also holds other file types such as .txt or .jpg, they must be filtered out)
    # total_xml   keeps only the names ending in .xml
    # -------------------------------------------------------#
xml_path = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
save_path = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
temp_xml = os.listdir(xml_path)
total_xml = [xml for xml in temp_xml if xml.endswith(".xml")]
    num = len(total_xml)                   # total number of annotation files
    indices = range(num)                   # index range [0, num)
    tv = int(num * trainval_percent)       # size of train + val
    tr = int(tv * train_percent)           # size of train
    trainval = random.sample(indices, tv)  # e.g. of 1000 samples, pick 900 random indices for train + val
    train = random.sample(trainval, tr)    # of those 900 indices, pick 810 random ones for train
    print("train + val size: ", tv)
    print("train size: ", tr)
    print("val size: ", tv - tr)
    print("test size: ", num - tv)
    ftrainval = open(os.path.join(save_path, 'trainval.txt'), 'w')  # train + val list
    ftrain = open(os.path.join(save_path, 'train.txt'), 'w')        # train list
    fval = open(os.path.join(save_path, 'val.txt'), 'w')            # val list
    ftest = open(os.path.join(save_path, 'test.txt'), 'w')          # test list
    tq_epochs = tqdm(indices)  # progress bar
for i in tq_epochs:
        # -------------------------------------------------------#
        # total_xml[i][:-4] strips the .xml suffix, keeping the bare file name
        # i is the index of the annotation file:
        # if i is in trainval, write to the train+val list, otherwise to the test list;
        # if i is also in train, write to the train list, otherwise to the val list
        # -------------------------------------------------------#
name = total_xml[i][:-4]+'\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
    ftrainval.close()  # close the writers
    ftrain.close()
    fval.close()
    ftest.close()
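With the default ratios (trainval_percent = 0.9 and train_percent = 0.9), a hypothetical 1000-image dataset prints:

train + val size:  900
train size:  810
val size:  90
test size:  100

and each of the four .txt files under ImageSets/Main then lists one file stem per line, without the .xml suffix (stems such as 000005 here are purely illustrative).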
VOC labels live in .xml files, while YOLO expects .txt files with lines like `11 0.344193 0.611 0.416431 0.262`: the first number is the class index, and the remaining four are the normalized (x, y, w, h) of the box. See the comments in the script below for the details.
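As a quick sanity check, here is the arithmetic behind that example line, using the box from the XML snippet shown later in this post (a 353×500 image with a dog box at (48, 240, 195, 371); the class index 11 is taken from the example line itself):

img_w, img_h = 353, 500                             # from <size> in the xml
xmin, ymin, xmax, ymax = 48.0, 240.0, 195.0, 371.0  # from <bndbox>
x_center = round((xmin + xmax) / 2 / img_w, 6)      # 0.344193
y_center = round((ymin + ymax) / 2 / img_h, 6)      # 0.611
w = round((xmax - xmin) / img_w, 6)                 # 0.416431
h = round((ymax - ymin) / img_h, 6)                 # 0.262
print(11, x_center, y_center, w, h)                 # -> 11 0.344193 0.611 0.416431 0.262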
"""
作者:小小博
时间:2022.3.21
本脚本有三个个功能:
1.根据train.txt和val.txt将voc数据集标注信息(.xml)转为yolo标注格式(.txt),生成dataset文件(train+val)
2.统计训练集、验证集和测试集的数据并生成相应train_path.txt和val_path.txt test_path.txt文件
3.创建data.data文件,记录classes个数, train以及val数据集文件(.txt)路径和dataset_classes.names文件路径
"""
import os
from tqdm import tqdm
from lxml import etree
import json
import shutil
from os.path import *
# -------------------------------------------------------#
# Convert VOC-format data and labels to the YOLO format (class_index, x, y, w, h)
# dir_path          project root directory
# images_path       absolute path of the images
# xml_path          absolute path of the annotations
# train_txt_path    absolute path of train.txt
# val_txt_path      absolute path of val.txt
# test_txt_path     absolute path of test.txt
# label_json_path   absolute path of the json file holding the class names
# save_file_root    root of the output folder
# -------------------------------------------------------#
dir_path = dirname(abspath(__file__))
images_path = os.path.join(dir_path, "VOCdevkit/VOC2007", "JPEGImages")
xml_path = os.path.join(dir_path, "VOCdevkit/VOC2007", "Annotations")
train_txt_path = os.path.join(dir_path, "VOCdevkit/VOC2007", "ImageSets/Main", "train.txt")
val_txt_path = os.path.join(dir_path, "VOCdevkit/VOC2007", "ImageSets/Main", "val.txt")
test_txt_path = os.path.join(dir_path, "VOCdevkit/VOC2007", "ImageSets/Main", "test.txt")
label_json_path = os.path.join(dir_path,"data", "pascal_voc_classes.json")
save_file_root = os.path.join(dir_path, "dataset")
# -------------------------------------------------------#
# Save the absolute image paths of the train/val/test splits to .txt
# files, and record those three paths plus the class count in
# dataset.data, which makes the later dataset-loading step easy
# -------------------------------------------------------#
train_annotation_dir = os.path.join(dir_path, "dataset", "train", "labels")
val_annotation_dir = os.path.join(dir_path, "dataset", "val", "labels")
test_annotation_dir = os.path.join(dir_path, "dataset", "test", "labels")
train_path_txt = os.path.join(dir_path, "train_path.txt")
val_path_txt = os.path.join(dir_path, "val_path.txt")
test_path_txt = os.path.join(dir_path, "test_path.txt")
dataset_data = os.path.join(dir_path, "dataset.data")
classes_label = os.path.join(dir_path, "dataset_classes.names")
# -------------------------------------------------------#
# Check that all required files/folders exist
# -------------------------------------------------------#
assert os.path.exists(images_path), "images path not exist..."
assert os.path.exists(xml_path), "xml path not exist..."
assert os.path.exists(train_txt_path), "train txt file not exist..."
assert os.path.exists(val_txt_path), "val txt file not exist..."
assert os.path.exists(test_txt_path), "test txt file not exist..."
assert os.path.exists(label_json_path), "label_json_path does not exist..."
# -------------------------------------------------------#
# Create the dataset folder if it does not exist
# -------------------------------------------------------#
if os.path.exists(save_file_root) is False:
os.makedirs(save_file_root)
# -------------------------------------------------------#
# Parse an xml element tree into a nested dict, e.g.
# {'bndbox': {'xmin': '48', 'ymin': '240', 'xmax': '195', 'ymax': '371'}}
# -------------------------------------------------------#
def parse_xml_to_dict(xml):
    if len(xml) == 0:  # reached a leaf element; return its tag and text
        return {xml.tag: xml.text}
    # -------------------------------------------------------#
    # Recurse into the children. There may be several <object>
    # elements, so those are collected into a list.
    # -------------------------------------------------------#
result = {}
for child in xml:
child_result = parse_xml_to_dict(child)
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result:
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
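# -------------------------------------------------------#
# A minimal illustration of the parser on a toy snippet (reusing the
# example box from this post); repeated <object> elements end up in a list:
#   xml_str = ("<annotation><size><width>353</width><height>500</height>"
#              "<depth>3</depth></size><object><name>dog</name><bndbox>"
#              "<xmin>48</xmin><ymin>240</ymin><xmax>195</xmax>"
#              "<ymax>371</ymax></bndbox></object></annotation>")
#   data = parse_xml_to_dict(etree.fromstring(xml_str))["annotation"]
#   data["size"]["width"]        -> '353'  (values stay strings)
#   data["object"][0]["bndbox"]  -> {'xmin': '48', 'ymin': '240', 'xmax': '195', 'ymax': '371'}
# -------------------------------------------------------#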
# -------------------------------------------------------#
# Output folder layout:
# |-dataset
#     |-train
#         |-images
#         |-labels
#     |-val
#         |-images
#         |-labels
#     |-test
#         |-images
#         |-labels
# Each folder is created first if it does not exist.
# -------------------------------------------------------#
def translate_info(file_names: list, save_root: str, class_dict: dict, train_val='train'):
    '''
    :param file_names: list of file names without the suffix
    :param save_root: root directory for the converted output
    :param class_dict: dict {class_name: class_id, ...} mapping names to ids
    :param train_val: name of the split subfolder to save into
    :return:
    '''
save_txt_path = os.path.join(save_root, train_val, "labels")
if os.path.exists(save_txt_path) is False:
os.makedirs(save_txt_path)
save_images_path = os.path.join(save_root, train_val, "images")
if os.path.exists(save_images_path) is False:
os.makedirs(save_images_path)
    # -------------------------------------------------------#
    # tqdm draws a progress bar over the conversion; worth a look
    # at its API if you have not used it before
    # -------------------------------------------------------#
for file in tqdm(file_names, desc="translate {} file...".format(train_val)):
        # -------------------------------------------------------#
        # Check that the image file exists (adjust the suffix yourself
        # if your images are .png)
        # -------------------------------------------------------#
img_path = os.path.join(images_path, file + ".jpg")
assert os.path.exists(img_path), "file:{} not exist...".format(img_path)
        # -------------------------------------------------------#
        # Check that the xml file exists
        # -------------------------------------------------------#
xml_full_path = os.path.join(xml_path, file + ".xml")
assert os.path.exists(xml_full_path), "file:{} not exist...".format(xml_full_path)
        # -------------------------------------------------------#
        # Read the xml content and parse it; for example the <size>
        # element looks like:
        # <size>
        #     <width>353</width>
        #     <height>500</height>
        #     <depth>3</depth>
        # </size>
        # -------------------------------------------------------#
with open(xml_full_path) as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = parse_xml_to_dict(xml)["annotation"]
img_height = int(data["size"]["height"])
img_width = int(data["size"]["width"])
        # -------------------------------------------------------#
        # Convert the xml content to YOLO format and write it to a .txt file
        # -------------------------------------------------------#
with open(os.path.join(save_txt_path, file + ".txt"), "w") as f:
assert "object" in data.keys(), "file: '{}' lack of object key.".format(xml_full_path)
for index, obj in enumerate(data["object"]):
                # -------------------------------------------------------#
                # Get each object's top-left (x1, y1) and bottom-right
                # (x2, y2) corners; for example:
                # <object>
                #     <name>dog</name>
                #     <pose>Left</pose>
                #     <truncated>1</truncated>
                #     <difficult>0</difficult>
                #     <bndbox>
                #         <xmin>48</xmin>
                #         <ymin>240</ymin>
                #         <xmax>195</xmax>
                #         <ymax>371</ymax>
                #     </bndbox>
                # </object>
                # -------------------------------------------------------#
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
                # -------------------------------------------------------#
                # Look up the class name to get its numeric id
                # -------------------------------------------------------#
                class_index = class_dict[obj["name"]] - 1  # class ids start from 0
                # -------------------------------------------------------#
                # Convert the box to YOLO format
                # -------------------------------------------------------#
x_center = xmin + (xmax - xmin) / 2
y_center = ymin + (ymax - ymin) / 2
w = xmax - xmin
h = ymax - ymin
                # -------------------------------------------------------#
                # YOLO format (class_index, x, y, w, h), normalized so the
                # network trains on relative coordinates; keep 6 decimal places,
                # e.g. info = ['6', '0.13', '0.857357', '0.116', '0.141141']
                # -------------------------------------------------------#
x_center = round(x_center / img_width, 6)
y_center = round(y_center / img_height, 6)
w = round(w / img_width, 6)
h = round(h / img_height, 6)
info = [str(i) for i in [class_index, x_center, y_center, w, h]]
                # -------------------------------------------------------#
                # The first box is written as-is; each following box is
                # prefixed with a newline. Fields are space-separated.
                # -------------------------------------------------------#
if index == 0:
f.write(" ".join(info))
else:
f.write("\n" + " ".join(info))
            # -------------------------------------------------------#
            # Copy the image into save_images_path
            # -------------------------------------------------------#
            shutil.copyfile(img_path, os.path.join(save_images_path, os.path.basename(img_path)))
def create_class_names(class_dict: dict):
    # -------------------------------------------------------#
    # The dict keys are the class names,
    # e.g. ['aeroplane', 'bicycle', 'bird', ..., 'tvmonitor']
    # -------------------------------------------------------#
keys = class_dict.keys()
with open(dir_path+"/dataset_classes.names", "w") as w:
for index, k in enumerate(keys):
if index + 1 == len(keys):
w.write(k)
else:
w.write(k + "\n")
# -------------------------------------------------------#
# Write a .txt file listing the images of one split, e.g.:
# F:\mypytorch\pythonProject1\myYolo2\dataset\test\images\000007.jpg
# F:\mypytorch\pythonProject1\myYolo2\dataset\test\images\000035.jpg
# F:\mypytorch\pythonProject1\myYolo2\dataset\test\images\000036.jpg
# F:\mypytorch\pythonProject1\myYolo2\dataset\test\images\000049.jpg
# F:\mypytorch\pythonProject1\myYolo2\dataset\test\images\000055.jpg
# txt_path      where to save: train_path.txt / val_path.txt / test_path.txt
# dataset_dir   e.g. F:\mypytorch\pythonProject1\myYolo2\dataset\test\labels
# -------------------------------------------------------#
def calculate_data_txt(txt_path, dataset_dir):
with open(txt_path, "w") as w:
for file_name in os.listdir(dataset_dir):
            if file_name == "classes.txt":
                continue
            # -------------------------------------------------------#
            # dataset_dir is e.g. F:\mypytorch\pythonProject1\myYolo2\dataset\test\labels
            # Replace "labels" with "images" in the absolute path, swap the
            # file suffix for .jpg, append a newline, and write the lines one
            # by one into train_path.txt / val_path.txt / test_path.txt
            # -------------------------------------------------------#
img_path = os.path.join(dataset_dir.replace("labels", "images"),
file_name.split(".")[0]) + ".jpg"
line = img_path + "\n"
assert os.path.exists(img_path), "file:{} not exist!".format(img_path)
w.write(line)
# -------------------------------------------------------#
# Create the dataset.data summary file, e.g.:
# classes=20
# train=F:\mypytorch\pythonProject1\myYolo2\train_path.txt
# valid=F:\mypytorch\pythonProject1\myYolo2\val_path.txt
# test=F:\mypytorch\pythonProject1\myYolo2\test_path.txt
# names=F:\mypytorch\pythonProject1\myYolo2\dataset_classes.names
# -------------------------------------------------------#
def create_dataset_data(create_data_path, label_path, train_path, val_path, test_path, classes_info):
    with open(create_data_path, "w") as w:
        w.write("classes={}".format(len(classes_info)) + "\n")  # number of classes
        w.write("train={}".format(train_path) + "\n")           # path of the train list (.txt)
        w.write("valid={}".format(val_path) + "\n")             # path of the val list (.txt)
        w.write("test={}".format(test_path) + "\n")             # path of the test list (.txt)
        w.write("names={}".format(label_path) + "\n")           # path of the .names file
def main():
    # -------------------------------------------------------#
    # Read the json file into a dict of {class_name: class_id, ...}
    # -------------------------------------------------------#
    with open(label_json_path, 'r') as json_file:
        class_dict = json.load(json_file)
    # -------------------------------------------------------#
    # Read all lines of train.txt, dropping empty lines
    # -------------------------------------------------------#
with open(train_txt_path, "r") as r:
train_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
    # -------------------------------------------------------#
    # Convert the training split from VOC to YOLO format
    # -------------------------------------------------------#
translate_info(train_file_names, save_file_root, class_dict, "train")
with open(val_txt_path, "r") as r:
val_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
    # -------------------------------------------------------#
    # Convert the validation split from VOC to YOLO format
    # -------------------------------------------------------#
translate_info(val_file_names, save_file_root, class_dict, "val")
with open(test_txt_path, "r") as r:
test_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
    # -------------------------------------------------------#
    # Convert the test split from VOC to YOLO format
    # -------------------------------------------------------#
translate_info(test_file_names, save_file_root, class_dict, "test")
    # -------------------------------------------------------#
    # Create the dataset_classes.names file
    # -------------------------------------------------------#
create_class_names(class_dict)
    # -------------------------------------------------------#
    # Collect the train/val/test image paths into their .txt files
    # -------------------------------------------------------#
calculate_data_txt(train_path_txt, train_annotation_dir)
calculate_data_txt(val_path_txt, val_annotation_dir)
calculate_data_txt(test_path_txt, test_annotation_dir)
classes_info = [line.strip() for line in open(classes_label, "r").readlines() if len(line.strip()) > 0]
    # -------------------------------------------------------#
    # dataset.data records the number of classes, the paths of the
    # train/val/test list files (.txt), and the path of the .names file
    # -------------------------------------------------------#
    create_dataset_data(dataset_data, classes_label, train_path_txt, val_path_txt, test_path_txt, classes_info)
if __name__ == "__main__":
main()
This part normalizes the input images and pads them to a uniform size, e.g. 416×416 pixels. The same data-loading code can also implement multi-scale training, data augmentation, and similar strategies (design these as you see fit), and the data is then batched with a DataLoader; the training section later will cover this in detail.
import random
import os
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as transforms
"""
作者:小小博
时间:2022.3.20
环境:夜深人静的晚上
本脚本有功能:
1.对应train_path.txt路径的数据进行YOLO需要的数据进行格式化
"""
# -------------------------------------------------------#
# Image-resizing helper 1: pad the image to a square
# -------------------------------------------------------#
def pad_to_square(img, pad_value):
    # -------------------------------------------------------#
    # dim_diff = np.abs(h - w) is the absolute height/width difference.
    # Only three cases are possible:
    # 1. h == w: no padding needed
    # 2. h > w:  pad the left and right of the image
    # 3. h < w:  pad the top and bottom of the image
    # dim_diff may be odd or even; if odd, the bottom (or right) side
    # gets one extra pixel of padding.
    # pad1, pad2 are the per-side padding amounts
    # pad = (left, right, top, bottom)
    # If w >= h, pad the top and bottom; otherwise pad the left and right.
    # "constant" padding fills with a fixed value; see the F.pad docs.
    # -------------------------------------------------------#
c, h, w = img.shape
dim_diff = np.abs(h - w)
pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
img = F.pad(img, pad, "constant", value=pad_value)
return img, pad
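# -------------------------------------------------------#
# A quick check of the behavior (the 3x375x500 shape is just an example):
#   img = torch.zeros(3, 375, 500)   # C, H, W: wider than tall
#   padded, pad = pad_to_square(img, 0)
#   padded.shape  -> torch.Size([3, 500, 500])
#   pad           -> (0, 0, 62, 63)  # dim_diff = 125: 62 px on top, 63 on the bottom
# -------------------------------------------------------#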
# -------------------------------------------------------#
# Image-resizing helper 2
# Resample the image to the required size.
# F.interpolate supports 'nearest', 'linear', 'bilinear', 'bicubic',
# 'trilinear', and 'area'; the default is 'nearest'.
# -------------------------------------------------------#
def img_resize(image, size):
image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
return image
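# -------------------------------------------------------#
# Example: img_resize(torch.zeros(3, 500, 500), 416).shape
# -> torch.Size([3, 416, 416])
# -------------------------------------------------------#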
# -------------------------------------------------------#
# ListDataset: the data-preprocessing Dataset class
# -------------------------------------------------------#
class ListDataset(Dataset):
def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
        # -------------------------------------------------------#
        # Read the absolute paths of every image under dataset\train\images;
        # self.img_path holds lines like
        # D:\mypro\myYolo2\dataset\train\images\000001.jpg
        # -------------------------------------------------------#
with open(list_path, "r") as file:
self.img_path = file.readlines()
        # -------------------------------------------------------#
        # Derive the absolute label paths under dataset\train\labels,
        # e.g. self.label_path holds lines like
        # D:\mypro\myYolo2\dataset\train\labels\000001.txt
        # path is the absolute image path; replacing "images" with "labels"
        # and the .png/.jpg suffix with .txt yields the matching label path.
        # -------------------------------------------------------#
        self.label_path = [path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt") for path in self.img_path]
self.img_size = img_size
self.max_objects = 100
        self.augment = augment                       # whether to apply data augmentation
        self.multiscale = multiscale                 # whether to use multi-scale training
        self.normalized_labels = normalized_labels   # whether the labels are normalized
        self.min_size = self.img_size - 3 * 32       # 320 for a 416 input; used in multi-scale training
        self.max_size = self.img_size + 3 * 32       # 512 for a 416 input; used in multi-scale training
        self.batch_count = 0                         # batch counter, used to schedule multi-scale resizing
    # -------------------------------------------------------#
    # __getitem__ is called when an element is accessed by index:
    # for an instance datas, the expression datas[0] invokes __getitem__
    # -------------------------------------------------------#
def __getitem__(self, index):
        # -------------------------------------------------------#
        # Get the absolute path of the indexed image; the modulo makes
        # out-of-range indices wrap around safely, and rstrip() strips
        # the trailing whitespace/newline
        # -------------------------------------------------------#
img_path = self.img_path[index % len(self.img_path)].rstrip()
        # -------------------------------------------------------#
        # PIL reads images as [h, w, c]; ToTensor converts them to
        # [c, h, w] tensors
        # -------------------------------------------------------#
img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))
        # -------------------------------------------------------#
        # e.g. img.shape = [3, 416, 416], so len(img.shape) = 3.
        # If the image is grayscale (a single channel), expand it to three:
        # unsqueeze(0) raises [h, w] to [1, h, w]
        # expand(3, h, w) turns [1, h, w] into [3, h, w]
        # -------------------------------------------------------#
if len(img.shape) != 3:
img = img.unsqueeze(0)
img = img.expand(3, img.shape[1],img.shape[2])
        # -------------------------------------------------------#
        # _, h, w are the image's channel count, height, and width.
        # If self.normalized_labels is True, the labels are normalized,
        # so h_factor, w_factor = (h, w);
        # otherwise h_factor, w_factor = (1, 1)
        # -------------------------------------------------------#
_, h, w = img.shape
h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
        # -------------------------------------------------------#
        # The network expects square inputs, so if height != width the
        # shorter side is padded; see pad_to_square above for the details.
        # Most VOC images have either a height or a width of 500 pixels,
        # so they typically end up as 500x500 after padding.
        # -------------------------------------------------------#
img, pad = pad_to_square(img, 0)
        _, padded_h, padded_w = img.shape  # e.g. padded_h = 500, padded_w = 500
        # -------------------------------------------------------#
        # Label handling
        # label_path is the absolute path of this image's label file
        # -------------------------------------------------------#
label_path = self.label_path[index % len(self.img_path)].rstrip()
targets = None
if os.path.exists(label_path):
            # -------------------------------------------------------#
            # Read the YOLO-format labels [class, x, y, w, h]: class id,
            # normalized center (x, y), and normalized width/height.
            # boxes.shape is (num_objects, 5), where num_objects is the
            # number of boxes annotated in this image.
            # -------------------------------------------------------#
boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
            # -------------------------------------------------------#
            # Convert (x, y, w, h) to (x1, y1, x2, y2): top-left and
            # bottom-right corners, restored to the original image's pixels
            # -------------------------------------------------------#
x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)
y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
            # -------------------------------------------------------#
            # The labels are relative to the original image, e.g. (500, 336),
            # but the padded image is (500, 500), so the box corners must be
            # shifted by the padding offsets: pad = (left, right, top, bottom).
            # The easiest way to convince yourself is to sketch it on paper.
            # -------------------------------------------------------#
            x1 += pad[0]
            y1 += pad[2]
            x2 += pad[0]
            y2 += pad[2]
            # -------------------------------------------------------#
            # Convert (x1, y1, x2, y2) back to normalized (x, y, w, h):
            # (x1 + x2) / 2 is the center x, divided by the padded width;
            # boxes[:, 3] *= w_factor / padded_w first restores the pixel
            # width, then normalizes by the padded width
            # -------------------------------------------------------#
boxes[:, 1] = ((x1 + x2) / 2) / padded_w
boxes[:, 2] = ((y1 + y2) / 2) / padded_h
boxes[:, 3] *= w_factor / padded_w
boxes[:, 4] *= h_factor / padded_h
            # -------------------------------------------------------#
            # boxes.shape is (num_objects, 5): [class, x, y, w, h]
            # targets.shape is (num_objects, 6): [batch_index, class, x, y, w, h]
            # The extra first column will record which image of the batch
            # each box belongs to; it is filled in by collate_fn below.
            # -------------------------------------------------------#
targets = torch.zeros((len(boxes), 6))
targets[:, 1:] = boxes
        # -------------------------------------------------------#
        # Data augmentation goes here; add whatever your task needs, e.g.:
        # if self.augment:
        #     if np.random.random() < 0.5:
        #         # random horizontal flip
        #         ...
        # -------------------------------------------------------#
return img_path, img, targets
    # -------------------------------------------------------#
    # The DataLoader passes each batch to collate_fn, e.g.
    # DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0, pin_memory=True, collate_fn=dataset.collate_fn)
    # and what the loader yields is this function's return value
    # -------------------------------------------------------#
def collate_fn(self, batch):
        # -------------------------------------------------------#
        # Unpack the batch into paths, imgs, targets
        # -------------------------------------------------------#
paths, imgs, targets = list(zip(*batch))
        # -------------------------------------------------------#
        # Drop None targets: some images may have no annotation file
        # -------------------------------------------------------#
targets = [boxes for boxes in targets if boxes is not None]
        # -------------------------------------------------------#
        # The first column was zero-filled earlier; now that we have a
        # batch, assign each image's index within it
        # -------------------------------------------------------#
for i, boxes in enumerate(targets):
boxes[:, 0] = i
        # -------------------------------------------------------#
        # Every 10 batches, randomly resize the inputs to a size in
        # [320, 512] that is a multiple of 32. The Darknet backbone has
        # no fully connected layers, so varying the input size helps the
        # model generalize.
        # -------------------------------------------------------#
if self.multiscale and self.batch_count % 10 == 0:
self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
        # torch.stack concatenates a sequence of tensors along a new dimension
imgs = torch.stack([img_resize(img, self.img_size) for img in imgs])
self.batch_count += 1
return paths, imgs, targets
def __len__(self):
        # -------------------------------------------------------#
        # __len__ returns the number of elements in the container,
        # which is what makes len(dataset) work
        # -------------------------------------------------------#
return len(self.img_path)
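Finally, a minimal sketch of how this class plugs into a DataLoader (the train_path.txt location and the batch size here are assumptions; the training loop itself is covered later):

from torch.utils.data import DataLoader

dataset = ListDataset("train_path.txt", img_size=416, augment=False, multiscale=True)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0,
                    pin_memory=True, collate_fn=dataset.collate_fn)
for paths, imgs, targets in loader:
    # imgs: [batch, 3, S, S], where S is a multiple of 32 in [320, 512]
    # targets: one [n_i, 6] tensor per labeled image, with rows of
    # (batch_index, class, x, y, w, h)
    print(imgs.shape, len(targets))
    break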