发布时间:2024-02-02 13:30
PyTorch的数据读取由DataLoader模块完成,DataLoader又可以分为Dataset与Sampler两部分。Sampler的功能是生成索引(样本序号);Dataset则依据索引读取图像(Img)和标签(Label)。我们主要学习DataLoader与Dataset。
torch.utils.data.DataLoader()
DataLoader(dataset,
batch_size=1,shuffle=False,sampler=None,
batch_sampler=None,num_workers=0,
collate_fn=None,pin_memory=False,drop_last=False,timeout=0,
worker_init_fn=None,
multiprocessing_context=None)
功能:构建可迭代的数据装载器
Epoch:所有训练样本都已输入到模型中,称为一个Epoch
Iteration:一批样本输入到模型中,称之为一个Iteration
Batchsize:批大小,决定一个Epoch有多少个Iteration
样本总数:80,Batchsize : 8
1 Epoch = 10 Iteration
样本总数:87, Batchsize: 8
87 = 10 × 8 + 7,最后剩余7个样本不足一个batch:
1 Epoch = 10 Iteration(drop_last = True,丢弃最后7个样本)
1 Epoch = 11 Iteration(drop_last = False,最后一个batch只有7个样本)
torch.utils.data.Dataset()
class Dataset(object):
    """Abstract dataset base class.

    Custom datasets inherit from this class and override ``__getitem__``
    so that an index maps to one sample.
    """

    def __getitem__(self, index):
        """Return the sample at ``index``; subclasses must override this."""
        raise NotImplementedError

    def __add__(self, other):
        """Return ``ConcatDataset([self, other])`` so datasets combine with ``+``."""
        # NOTE(review): ConcatDataset is not defined in this snippet; it is
        # expected to be in scope (torch.utils.data provides a class of that name).
        return ConcatDataset([self, other])
功能: Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
__getitem__:接收一个索引,返回一个样本
数据读取流程如下:
for i, data in enumerate(train_loader):
==>
# 判断是单进程还是多进程
def __iter__(self):
    """Return the loader iterator that matches ``self.num_workers``."""
    # num_workers == 0 selects the single-process iterator.
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    # Any other value selects the multi-process iterator.
    return _MultiProcessingDataLoaderIter(self)
==>
# 以单进程为例
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    """Loader iterator used for the single-process case (``num_workers == 0``)."""

    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        # This iterator is only valid without a timeout and without workers.
        assert self.timeout == 0
        assert self.num_workers == 0

        # The fetcher turns an index (or batch of indices) into loaded data.
        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collation,
            self.collate_fn, self.drop_last)

    def __next__(self):
        """Fetch and return the data for the next index from the sampler.

        This is the step that decides which samples are read in each
        iteration.
        """
        index = self._next_index()  # may raise StopIteration
        data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
        if self.pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

    next = __next__  # Python 2 compatibility
==>
def _next_index(self):
    """Return the next item produced by ``self.sampler_iter``."""
    return next(self.sampler_iter)  # may raise StopIteration
==>
# 利用sampler输出的index来进行采样
def __iter__(self):
    """Yield lists of ``self.batch_size`` indices taken from ``self.sampler``.

    A final, shorter batch is yielded only when ``self.drop_last`` is false.
    """
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    # Leftover indices that did not fill a whole batch.
    if len(batch) > 0 and not self.drop_last:
        yield batch
==>
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that reads samples by indexing into ``self.dataset``."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(
            dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        """Read the requested sample(s) and return ``self.collate_fn(data)``."""
        if self.auto_collation:
            # This is where the data is actually read: one dataset lookup
            # per index in the batch.
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
==>
class RMBDataset(Dataset):
    """Dataset for the RMB banknote denomination classification task."""

    # Sub-directory name (denomination) -> integer class label. Kept on the
    # class so the static get_img_info can read it without a module global.
    LABEL_BY_FOLDER = {"1": 0, "100": 1}

    def __init__(self, data_dir, transform=None):
        """
        :param data_dir: str, path of the dataset root directory
        :param transform: torch.transform, preprocessing applied to each image
        """
        self.label_name = dict(self.LABEL_BY_FOLDER)
        # data_info stores every (image path, label) pair; samples are read
        # from it by index in __getitem__.
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        """Load the image for ``index`` and return ``(img, label)``."""
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')  # 0~255

        if self.transform is not None:
            # Preprocessing (conversion to tensor, etc.) happens here.
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Return the number of samples found under ``data_dir``."""
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        """Collect ``(image path, label)`` pairs for all ``.jpg`` files.

        :param data_dir: str, root directory; each sub-directory name is
            looked up in ``LABEL_BY_FOLDER`` to get the label
        :return: list of ``(path, int label)`` tuples
        :raises KeyError: if a sub-directory containing ``.jpg`` files has a
            name that is not in ``LABEL_BY_FOLDER``
        """
        data_info = []
        # Walk the directory tree; each sub-directory is one class.
        for root, dirs, _ in os.walk(data_dir):
            for sub_dir in dirs:
                class_dir = os.path.join(root, sub_dir)
                for img_name in os.listdir(class_dir):
                    if not img_name.endswith('.jpg'):
                        continue
                    # Label lookup stays per image so a folder without any
                    # .jpg files never triggers a KeyError.
                    label = RMBDataset.LABEL_BY_FOLDER[sub_dir]
                    data_info.append(
                        (os.path.join(root, sub_dir, img_name), int(label)))

        return data_info
==>
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that reads samples by indexing into ``self.dataset``."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(
            dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        """Read the requested sample(s) and return ``self.collate_fn(data)``."""
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        # collate_fn is the collator: it arranges the samples that were
        # read into batch form.
        return self.collate_fn(data)
==>
for i, data in enumerate(train_loader):
    # forward
    # data consists of two Tensors: the inputs and their labels
    inputs, labels = data
如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!