pytorch学习笔记(二)——加载数据Dataset以及Dataloader的使用

发布时间:2024-10-14 11:01

1. pytorch加载数据涉及两个类

  • Dataset
    功能:
    1. 获取每一个数据及其label
    2. 告诉我们总共有多少数据
  • Dataloader:
    为后面的网络提供不同的数据形式。可以从dataset中取数据,把数据加载到神经网络中,理解成数据加载器

2. Dataset的使用

# 注意 本例用的数据集形式是 把label放在文件夹名字上。
# 但是还有其他类型的数据集 比如说地址等类型的数据 需要对于每一个数据存放一个txt存标签

from torch.utils.data import Dataset
from PIL import Image
import os


# PIL python image library 图像处理库
# os python提供的一个os模块,包含很多操作文件和目录的函数

# 继承Dataset类
class MyData(Dataset):

    # 类的初始化函数,创建实例时运行,为class类提供全局变量
    def __init__(self, root_dir, label_dir):
        # self.xx的变量理解为类的全局变量
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path_list = os.listdir(self.path)

    def __getitem__(self, item):
        img_name = self.img_path_list[item]
        # 获取每个图片的路径
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        # 读取图片 和标签
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path_list)


# 创建类实例
root_dir = \"dataset/train\"
ants_label_dir = \"ants\"
bees_label_dir = \"bees\"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

# 把两个数据集拼接
train_dataset = ants_dataset + bees_dataset

# 测试一下 看获取的图片
img, label = ants_dataset[0]
img.show()
print(len(train_dataset))
print(len(ants_dataset))
print(len(bees_dataset))

3. Dataloader的使用

import torchvision
from torch.utils.data import DataLoader

# 1.准备的测试据集
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10(\"./datasets\", train=False, transform=torchvision.transforms.ToTensor())

# 2.DtaLoader加载数据集
# 参数理解:dataset数据集,batch_size每次加载的数据量(把64个图片信息当成一组打包成一个作为dataloader的一个返回),
# shuffle每次加载数据之前是否重新洗牌,num_workers线程数, drop_last最后余数余下的数据集是否丢掉
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# 测试数据集
img, target = test_data[0]
print(img.shape)
# 输出:torch.Size([3, 32, 32]) rgb3通道,32×32大小的图
print(target)

writer = SummaryWriter(\"dataloader\")
step = 0
for data in test_loader:
    imgs, targets = data
    # print(img.shape)
    # print(target)
    # 注意这里用的是add_images() 而不是add_image()
    writer.add_images(\"test_data\", imgs, step)
    step = step + 1
writer.close()

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号