发布时间:2022-08-19 14:25
在上一次的糖尿病数据集中,我们是使用整个数据集input计算的。这次考虑mini_batch的输入方式。
三个概念:
epoch:所有训练样本全部轮一遍叫做一个epoch
Batch-Size:批量训练时,每批量包含的样本个数
iteration:每批量轮一遍叫做一个iteration
比如一个数据集有200个样本,把他分成40块,每块就有5个样本。
那么batch = 40, batch_size = 5。
训练的时候,按每块训练,把一块的5个样本轮一遍,叫做1个itearion。
这40块都轮一遍,就是200个样本都训练了一遍,叫做1个epoch。
DataLoader:一种数据集加载方式
他能帮我们做什么?我们要做小批量训练,为了提高训练的随机性,我们可以对数据集进行shuffle。
当把一个支持索引和长度可知的数据集送到dataloader里,就可以自动对dataset进行小批量生成。
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class DiabetesDataset(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
pass
def __len__(self):
pass
dataset = DiabetesDataset()
train_loader = DataLoader(dataset = dataset,
batch_size = 32,
shuffle = True,
num_workers = 2)
Pytorch提供了一种Dataset类,这是一种抽象类,我们知道抽象类不能被实例化,但可以被继承。