当前位置: 首页> 教育> 锐评 > PyTorch重写DataSet类

PyTorch重写DataSet类

时间:2025/7/13 3:57:48来源:https://blog.csdn.net/m0_52420323/article/details/140956227 浏览次数:0次

PyTorch重写DataSet类


文章目录

  • PyTorch重写DataSet类
  • 前言
  • 一、如何重写?
  • 二、具体代码
    • 1.数据集格式
    • 2.获取标签
    • 3.重写dataset
    • 4.调用
  • 总结


前言

在之前沐神的Cifar-10分类 课程学习中,沐神是用的将每一类创建一个文件夹去完成图片的导入。此外我们还可以通过重写DataSet类来完成!

一、如何重写?

在这里插入图片描述
通过查看官方文档我们可知。
需要去重写__getitem__这个方法,去以一种特定的方法拿到一个数据。并且选择性的重写__len__这个方法,去返回整个数据集的大小。

二、具体代码

1.数据集格式

在这里插入图片描述
这个数据集是沐神课程上讲过的cifar-10数据集。
train和test文件夹分别为要进行训练和测试的图片。而训练数据的标签以csv文件存在trainLabels.csv文件中。

2.获取标签

def read_csv_labels(fname):with open(fname,'r') as f:lines = f.readlines()[1:]tokens = [l.rstrip().split(',') for l in lines]return dict(((name,label) for name,label in tokens))

这里通过一个read_csv_labels的方法 将图片名字和标签以一个字典的方式返回

3.重写dataset

class MyDateset(Dataset):def __init__(self,root_dir,state,label_dict=None):self.root_dir = root_dirself.state = stateif label_dict is not None:self.label_dict = label_dictself.img_path = os.listdir(os.path.join(root_dir,state))# os.listdir 将当前文件夹下的图片名称按列表返回def __getitem__(self, idx):img = Image.open(os.path.join(self.root_dir,self.state,self.img_path[idx]))if self.state == 'train':img_num =self.img_path[idx].split('.')[0]# 这个取出来是数字.jpg 所以需要将.jpg舍去label = self.label_dict[img_num]return img,labelelse:return imgdef __len__(self):return len(self.img_path)

state参数表示此时是训练数据集还是测试数据集。

4.调用

root_dir = "D:\\PytorchLearn\\cifar-10"
label_dict = read_csv_labels(os.path.join(root_dir,"trainLabels.csv"))train_dataset = MyDateset(root_dir,'train',label_dict)test_dataset = MyDateset(root_dir,'test')train_iter = torch.utils.data.DataLoader(train_dataset,batch_size=8,shuffle=True)

总结

以上就是重写DataSet的方法,有不足之处还望各位指出。

关键字:PyTorch重写DataSet类

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: