Dataset实战 利用pytorch进行图片数据的读取 和数据集的分类
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
self.img_path = os.listdir(self.path) #将这个路径下的文件变成一个列表的形式
def __getitem__(self, idx): #想要获取每一个图片
img_name = self.img_path[idx] #在这个列表下 用idx看是第几个图
#'0013035.jpg' str类型的 img的name
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
#将这个名字和路径进行拼接 就能得到这个图片的相对路径
img = Image.open(img_item_path)
#这样就能得到这个img
label = self.label_dir
#标签 label
return img,label
def __len__(self):
return len(self.img_path)
#dir_path = "dataset/train/ants"
#img_path_list = os.listdir(dir_path)
#将这个路径下的文件变成一个列表的形式
#对于self的理解,相当于将一个函数中的变量变成了全局变量
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)
train_dataset = ants_dataset + bees_dataset
1.对于self的理解,相当于将一个函数中的变量变成了全局变量
root_dir:根目录
label_dir:子目录,在这个案例中子目录就是ants和bees,就是图片的分两类
self.root_dir = root_dir #root_dir = "dataset/train"
self.label_dir = label_dir #ants_label_dir = "ants"
先将两个路径进行合并可以得到目标图片的真正路径
然后用listdir将这个路径下的文件转成一个列表的形式
self.path = os.path.join(self.root_dir,self.label_dir)
self.img_path = os.listdir(self.path) #将这个路径下的文件变成一个列表的形式
列表的形式:
getitem是得到每个图片的方法,idx是第几个图片的索引
def __getitem__(self, idx): #想要获取每一个图片
img_name = self.img_path[idx] #在这个列表下 用idx看是第几个图
将这个名字和路径进行拼接 就能得到这个图片的相对路径
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
可以得到图片和标签
img = Image.open(img_item_path)
#这样就能得到这个img
label = self.label_dir
#标签 label
return img,label
再定义一个可以得到长度的方法:
def __len__(self):
return len(self.img_path)
img , label = bees_dataset[1]
img.show()
得到下面的图片