发布时间:2023-10-20 15:00
在官方文档中,打开CIFAR10这个数据集,可以看到在使用这个数据集的时候,一般都需要几个参数,root是数据集下载之后存放的位置,train是一个bool类型,为true表示这个数据集作为训练集,为false表示这个数据集为测试集,transform和target_transform都是使用的transforms类型,download是bool型变量,为true表示下载数据集,false反之。
如果下载比较慢用迅雷下载(数据集的地址可以在pycharm中按住Ctrl点击数据集查看源码往上翻就能看到),然后自己在项目下弄一个文件夹复制进去就行,再次运行如果download是true会自动解压
代码示例1
import torchvision
train_set = torchvision.datasets.CIFAR10(root=\"./dataset\", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root=\"./dataset\", train=False, download=True)
print(test_set[0])
print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_trans = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root=\"./dataset\", train=True, transform=dataset_trans, download=True)
test_set = torchvision.datasets.CIFAR10(root=\"./dataset\", train=False, transform=dataset_trans, download=True)
#转化完得到tensor数据类型之后就可以用tensorboard显示了
writer = SummaryWriter(\"p10\")
for i in range(10):
img, target = test_set[i]
writer.add_image(\"test_set\", img, i)
writer.close()