发布时间:2023-09-19 11:00
DataLoader 的使用。相关参数的用法可参考官方文档：
torch.utils.data.DataLoader
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Demo: batch CIFAR-10 with a DataLoader and log every batch to TensorBoard.

# Prepare the test split (train=False), matching the names test_data/test_loader.
# NOTE(review): no download=True, so the dataset must already exist under
# ./dataset or this raises; add download=True if it should be fetched.
test_data = torchvision.datasets.CIFAR10(
    "./dataset",
    train=False,
    transform=torchvision.transforms.ToTensor(),
)

# DataLoader parameters used here (see the official DataLoader docs):
#   batch_size  - how many samples per batch to load (default: 1).
#   shuffle     - reshuffle the data at every epoch (default: False).
#   num_workers - number of loader subprocesses; 0 means data is loaded in the
#                 main process (default: 0).
#   drop_last   - drop the last incomplete batch when the dataset size is not
#                 divisible by batch_size; if False, the last batch is smaller
#                 (default: False).
test_loader = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

# Inspect the first sample of the dataset: an (image tensor, class index) pair.
img, target = test_data[0]
print(img.shape)
print(target)

# Log every batch for two epochs ("two passes over the data"). Iterating
# test_loader rather than test_data is the point of the demo: each item is a
# batch of batch_size samples, and with shuffle=True the batches differ
# between epochs. The context manager closes the writer even if logging fails.
with SummaryWriter("dataloader") as writer:
    for epoch in range(2):
        for step, (imgs, _targets) in enumerate(test_loader):
            # imgs is a stacked batch, so use add_images (NCHW by default);
            # add_image expects a single CHW image.
            writer.add_images("Epoch:{}".format(epoch), imgs, step)