数据集类(Data Set)与数据加载器(Data Loader)

📅 2026/6/29 19:20:46
数据集类(Data Set)与数据加载器(Data Loader)
数据集类Data Set是指存储和表示数据的类或接口。它通常用于封装数据以便能够在机器学习任务中使用。数据集可以是任何形式的数据比如图像、文本、音频等。数据集的主要目的是提供对数据的标准访问方法以便可以轻松地将其用于模型训练、验证和测试。​数据加载器Data Loader是一个提供批量加载数据的工具。它通过将数据集分割成小批量并按照一定的顺序加载到内存中以提高训练效率。数据加载器常用于训练过程中的数据预处理、批量化操作和数据并行处理等。​ PyTorch中的torch.utils.data.Dataset和torch.utils.data.DataLoader是数据加载和处理的核心组件。它们将数据读取与模型训练解耦提供高效、灵活的数据迭代方式。下面从基础概念、自定义加载器参数、多进程机制等方面进行详细介绍。1.数据集Data Set1.1 自定义数据集定义实现​Data Set是一个抽象类表示一个数据集。任何自定义数据集都必须继承它自定义DataSet类必须实现它构造函数和两个方法__init__: 在 实例化DataSet 对象运行一次。我们初始化包含图像的目录、注释文件和transform与 target_transform.__len__返回数据集的总样本数。len(dataset)会调用它。__getitem__(self, idx)根据整数索引idx会返回一个样本通常为特征和标签。dataset[idx]会调用它。其作用就是实现通过索引访问对应的数据以及标签。from torch.utils.data import Dataset class MyDataset(Dataset): def __init__(self, data, labels): self.data data self.labels labels def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx], self.labels[idx]使用自定义数据集时可以用将其与torch.utils.data.DataLoader结合使用以便进行数据的批量加载和处理和训练。1.2 两种自定义数据集风格​ 在PyTorch中自定义数据集有两个核心设计模式映射式Map-Style和可迭代式Iterable-style。它们的差异不仅是实现接口不同更反映了“随机访问”与“流式读取”两种数据消费范式的根本区别。下面从设计理念、实现细节、多进程交互、适用场景等方面深入解析。Map-style datasets映射式就是上述需要实现__getitem__和__len__的数据集它通过索引映射到数据样本。适用于所有数据能一次性放入索引结构如列表、文件路径列表的场景。Iterable-style datasets可迭代式当数据集太大无法一次性加载或数据是流式读取时如实时日志、数据库流可以继承IterableDataset实现__iter__方法返回一个迭代器。这种数据集不能使用len()也无法使用随机采样shuffle的 loader需使用Sampler的特定变体。在后续笔记我们将详细介绍。1.3 内置数据集​ PyTorch提供了一些常用数据集类主要在torchvision.datasets、torchtext.datasets、torchaudio.datasets中。例如torchvision.datasets.MNIST、CIFAR10、ImageFolder从文件夹结构加载图片子文件夹为类别torchtext.datasets.IMDB等torchaudio.datasets.LIBRISPEECH等这些内置类都继承自Dataset使用时可自动下载数据并提供标准化访问方式。​ 现在我们来展示一个如何从TorchVision加载了Fahion-MINIST由60000个训练样本和10000个测试样本组成。每个样本包含一个28×28 灰度图像和一个来自10个类别之一的关联标签。下面使用以下参数加载FashionMINIST数据集root是存储路径、测试数据的路径。train指定训练集或测试数据集。downloadTrue如果root路径下没有数据则从网上下载数据。transform和target_transform是指定特征和标签转换。import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt training_data datasets.FashionMNIST( root./data, trainTrue, downloadTrue, transformToTensor() ) test_data datasets.FashionMNIST( root./data, trainFalse, downloadTrue, transformToTensor() )我们可以用索引来访问数据集中的样本用matplotlib可视化图形样本。labels_map { 0: T-Shirt, 1: Trouser, 2: Pullover, 3: Dress, 4: Coat, 5: Sandal, 6: Shirt, 7: Sneaker, 8: Bag, 9: Ankle Boot, } figure plt.figure(figsize(8, 8)) cols, rows 3, 3 for i in range(1, cols * rows 1): sample_idx torch.randint(len(training_data), size(1,)).item() img, label training_data[sample_idx] figure.add_subplot(rows, cols, i) plt.title(labels_map[label]) plt.axis(off) plt.imshow(img.squeeze(), cmapgray) plt.show()其运行结果如下2. 数据加载器Data Loader数据加载器Data Loader将DataSet封装为可迭代对象负责批量加载、打乱数据、多进程并行加载等功能。其功能如下批量加载数据DataLoader可以从数据集中按照指定的批量大小加载数据。每个批次的数据可以作为一个张量或列表返回便于进行后续的处理和训练。数据随机洗牌通过设置shuffleTrueDataLoader可以在每个迭代周期中对数据进行随机洗牌以减少模型对数据顺序的依赖性提高训练效果。多线程数据加载DataLoader支持使用多个线程来并行加载数据加快数据加载的速度提高训练效率。数据批次采样除了按照批量大小加载数据外DataLoader还支持自定义的数据批次采样方式。可以通过设置batch_sampler参数来指定自定义的批次采样器例如按照指定的样本顺序或权重进行采样。数据加载器的API形式与核心参数DataLoader(dataset, batch_size1, shuffleFalse, samplerNone, batch_samplerNone, num_workers0, collate_fnNone, pin_memoryFalse, drop_lastFalse, timeout0, worker_init_fnNone, multiprocessing_contextNone, generatorNone, prefetch_factor2, persistent_workersFalse)dataset要加载的Dataset对象映射式或可迭代式。batch_size每个批次的样本数默认为 1。shuffle是否在每个 epoch 开始时打乱数据顺序仅对映射式有效。打乱基于RandomSampler。sampler自定义采样器继承自torch.utils.data.Sampler。定义数据索引的抽取策略。如果指定shuffle必须为False。batch_sampler类似sampler但每次返回一批索引与batch_size、shuffle、sampler互斥。num_workers用于数据加载的子进程数。0 表示在主进程中加载通常设置大于 0 可以加速数据预处理利用多核。collate_fn函数定义如何将多个样本列表合并为一个批次。默认collate_fn会将所有样本沿第0维堆叠成张量通常对于同型数据有效。如果样本结构不一致如不同长度序列需要自定义。pin_memory若为True数据加载器在返回张量前将其复制到 CUDA 固定内存加速数据传输到 GPU。仅适用于 CUDA。drop_last若为True丢弃最后一个不完整批次当总样本数不能被 batch_size 整除时。在训练时如果要求严格固定批次大小如 BatchNorm应设为Truetimeout从 worker 进程获取一个 batch 的超时时间秒。如果超时会抛异常。worker_init_fn每个 worker 进程的初始化函数参数为 worker id可用于设置随机种子等。generator用于生成随机采样的伪随机数生成器保证可复现性。prefetch_factor每个 worker 预先加载的 batch 数默认 2增加可以让 GPU 更少等待。persistent_workers若为True在数据集被消费一次后不会关闭 worker 进程可保持 worker 存活以加速后续 epoch。数据调用案例Demoimport torch from torch.utils.data import Dataset, DataLoader # 自定义数据集类 class MyDataset(Dataset): def __init__(self, data): self.data data def __len__(self): return len(self.data) def __getitem__(self, index): return self.data[index] # 自定义数据加载器类 class MyDataLoader(DataLoader): def __init__(self, dataset, batch_size1, shuffleFalse, num_workers0): super().__init__(dataset, batch_size, shuffle, num_workersnum_workers) def collate_fn(self, batch): # 自定义的数据预处理、合并等操作 # 这里只是简单地将样本转换为Tensor并进行堆叠 return torch.stack(batch) # 自定义数据集类 data [1, 2, 3, 4, 5] dataset MyDataset(data) # 创建数据加载器实例 dataloader MyDataLoader(dataset, batch_size2, shuffleTrue) # 遍历数据加载器 for batch in dataloader: # batch是一个包含多个样本的张量或列表 # 这里可以对批次数据进行处理 print(batch)3.实战案例import torch from sklearn.datasets import load_iris from torch.utils.data import Dataset, DataLoader # 此函数用于加载鸢尾花数据集 def load_data(shuffleTrue): x torch.tensor(load_iris().data) y torch.tensor(load_iris().target) # 数据归一化 x_min torch.min(x, dim0).values x_max torch.max(x, dim0).values x (x - x_min) / (x_max - x_min) if shuffle: idx torch.randperm(x.shape[0]) x x[idx] y y[idx] return x, y # 自定义鸢尾花数据类 class IrisDataset(Dataset): def __init__(self, modetrain, num_train120, num_dev15): super(IrisDataset, self).__init__() x, y load_data(shuffleTrue) if mode train: self.x, self.y x[:num_train], y[:num_train] elif mode dev: self.x, self.y x[num_train:num_train num_dev], y[num_train:num_train num_dev] else: self.x, self.y x[num_train num_dev:], y[num_train num_dev:] def __getitem__(self, idx): return self.x[idx], self.y[idx] def __len__(self): return len(self.x) batch_size 16 # 分别构建训练集、验证集和测试集 train_dataset IrisDataset(modetrain) dev_dataset IrisDataset(modedev) test_dataset IrisDataset(modetest) train_loader DataLoader(train_dataset, batch_sizebatch_size,shuffleTrue) dev_loader DataLoader(dev_dataset, batch_sizebatch_size) test_loader DataLoader(test_dataset, batch_size1, shuffleTrue)4.总 结ataset定义数据源及其访问方式映射式最常用流式数据用IterableDataset。DataLoader封装采样、批处理、多进程加载和内存固定等功能参数丰富。通过自定义sampler、collate_fn可以灵活处理各种数据形式和不平衡问题。多进程加载是加速训练的关键需注意内存复制和系统兼容性。