PyTorch学习笔记08——加载数据集

发布时间:2022-08-19 14:25

PyTorch学习笔记08——加载数据集

PyTorch学习笔记08——加载数据集_第1张图片
在上一次的糖尿病数据集中,我们是使用整个数据集input计算的。这次考虑mini_batch的输入方式。

三个概念:

epoch:所有训练样本全部轮一遍叫做一个epoch

Batch-Size:批量训练时,每批量包含的样本个数

iteration:每批量轮一遍叫做一个iteration

比如一个数据集有200个样本,把他分成40块,每块就有5个样本。
那么batch = 40, batch_size = 5。
训练的时候,按每块训练,把一块的5个样本轮一遍,叫做1itearion。
40块都轮一遍,就是200个样本都训练了一遍,叫做1个epoch。

DataLoader:一种数据集加载方式
他能帮我们做什么?我们要做小批量训练,为了提高训练的随机性,我们可以对数据集进行shuffle。
当把一个支持索引和长度可知的数据集送到dataloader里,就可以自动对dataset进行小批量生成。

dataset -> Shuffle -> PyTorch学习笔记08——加载数据集_第2张图片
Loader

PyTorch学习笔记08——加载数据集_第3张图片
如何定义你的数据集Dataset?
提供一个概念性代码:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class DiabetesDataset(Dataset):
    def __init__(self):
        pass
    def __getitem__(self, index):
        pass
    def __len__(self):
        pass
dataset = DiabetesDataset()
train_loader = DataLoader(dataset = dataset,
                          batch_size = 32,
                          shuffle = True,
                          num_workers = 2)

Pytorch提供了一种Dataset类,这是一种抽象类,我们知道抽象类不能被实例化,但可以被继承。

  • 上面的DiabetesDataset就是我们自己写的一个继承Dataset的类。表达式getitem、len都是魔法函数,分别返回值和数据集的长度。
  • 实例化DiabetesDataset后,通过Dataloader来自动创建小批量数据集。 这里用batch_size, shuffle,
    process number来初始化。

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号