pytroch深度学习——torchvision中的数据集使用

发布时间:2023-10-20 15:00

torchvision

在官方文档中,打开CIFAR10这个数据集,可以看到在使用这个数据集的时候,一般都需要几个参数,root是数据集下载之后存放的位置,train是一个bool类型,为true表示这个数据集作为训练集,为false表示这个数据集为测试集,transform和target_transform都是使用的transforms类型,download是bool型变量,为true表示下载数据集,false反之。
如果下载比较慢用迅雷下载(数据集的地址可以在pycharm中按住Ctrl点击数据集查看源码往上翻就能看到),然后自己在项目下弄一个文件夹复制进去就行,再次运行如果download是true会自动解压
\"pytroch深度学习——torchvision中的数据集使用_第1张图片\"

代码示例1

import torchvision

train_set = torchvision.datasets.CIFAR10(root=\"./dataset\", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root=\"./dataset\", train=False, download=True)

print(test_set[0])
print(test_set.classes)

img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()

运行结果:
\"pytroch深度学习——torchvision中的数据集使用_第2张图片\"
代码示例2

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_trans = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root=\"./dataset\", train=True, transform=dataset_trans, download=True)
test_set = torchvision.datasets.CIFAR10(root=\"./dataset\", train=False, transform=dataset_trans, download=True)

#转化完得到tensor数据类型之后就可以用tensorboard显示了
writer = SummaryWriter(\"p10\")
for i in range(10):
    img, target = test_set[i]
    writer.add_image(\"test_set\", img, i)

writer.close()

\"pytroch深度学习——torchvision中的数据集使用_第3张图片\"

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号