完整小测试:
#python的内置函数,getitemclass Animal():def __init__(self,name):self.name = namedef __str__(self):return f"This is {self.name}"class Zoo():def __init__(self,animal_list):self.animal_list = animal_listdef __getitem__(self,index):return self.animal_list[index]def __len__(self):return len(self.animal_list)#self类对象,key就是传入的索引,value就是key这个索引需要修改成什么值def __setitem__(self, key, value):self.animal_list[key]=value#zoo里面有个itemlist,我们如何遍历呢?animal_list = [Animal('lion'),Animal('monkey'),Animal('cat')]
zoo = Zoo(animal_list)
#不能直接遍历,需要实现itemlist才可以
print(zoo[0])
print("*" * 20) #打印20个星星
for animal in zoo:print(animal)#既然可以通过下标找到animal,是否可以用下标来做循环呢,需要实现len()方法
print("*" * 20) #打印20个星星
for index in range(len(zoo)):print(zoo[index])
说明:如果要遍历列一个类里面列表元素,这类需要实现getitem方法,通过下标返回去列表元素。如果要通过下标的长度遍历这个列表,需要实现len()这个内置方法
pytorch里面自定义的DataSet类:需要实现getitem和len()方法
from torch.utils.data import Dataset
from PIL import Image
import osclass MyDataset(Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.transform = transform# 假设我们数据集每个子文件夹里存储一个pair对:{image_in:image_gt}self.file_list = os.listdir(root_dir)def __len__(self):# 需要知道数据集的大小,方便后续划分等return len(self.file_list)def __getitem__(self, idx):# 返回给定索引处的数据样本, 调用该接口训练时就可以循环读取所有的数据对了img_path = os.path.join(self.root_dir, self.file_list[idx])image_in = Image.open(img_path+"image_in.jpg")image_gt = Image.open(img_path+"image_gt.jpg")if self.transform:# 数据增广image_in = self.transform(image_in)return image_in, image_gt
dataLoader里面:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import ostransform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),
])dataset = MyDataset(root_dir='path/to/images', transform=transform)# 创建 DataLoader 实例
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)# 迭代数据
for batch in dataloader:# 处理每一批数据images = batch# 进行训练或其他操作