发布时间:2023-08-15 16:30
安装环境部分:
1、Anaconda的安装(自行参考其他博客)
2、使用pip安装pytorch
代码部分:
import os
from PIL import Image
from torch.utils.data import Dataset
# dataset有两个作用:1、加载每一个数据,并获取其label;2、用len()查看数据集的长度
class MyData(Dataset):
def __init__(self, root_dir, label_dir): # 初始化,为这个函数用来设置在类中的全局变量
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir) # 单纯的连接起来而已,背下来怎么用就好了,因为在win下和linux下的斜线方向不一样,所以用这个函数来连接路径
self.img_path = os.listdir(self.path) # img_path 的返回值,就已经是一个列表了
def __getitem__(self, idx): # 获取数据对应的 label
img_name = self.img_path[idx] # img_name 在上一个函数的最后,返回就是一个列表了
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 这行的返回,是一个图片的路径,加上图片的名称了,能够直接定位到某一张图片了
img = Image.open(img_item_path) # 这个步骤看来是不可缺少的,要想 show 或者 操作图片之前,必须要把图片打开(读取),也就是 Image.open()一下,这可能是 PIL 这个类型图片的特有操作
label = self.label_dir # 这个例子中,比较特殊,因为图片的 label 值,就是图片所在上一级的目录
return img, label # img 是每一张图片的名称,根据这个名称,就可以使用查看(直接img)、print、size等功能
# label 是这个图片的标签,在当前这个类中,标签,就是只文件夹名称,因为我们就是这样定义的
def __len__(self):
return len(self.img_path) # img_path,已经是一个列表了,len()就是在对这个列表进行一些操作
if __name__ == \'__main__\':
root_dir = \"dataset/train\"
ants_label_dir = \"ants\"
bees_label_dir = \"bees\"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
# runfile(\'E:/pythonProject/learn_pytorch/read_data.py\')