深度学习框架PyTorch笔记(四)数据转换 Data Transformation

📅 2026/6/29 23:23:13
深度学习框架PyTorch笔记(四)数据转换 Data Transformation
数据预处理将原始图片PIL Image、numpy array转换为Tensor并做归一化以匹配模式输入要求。例如输入的样本图像需要调整为固定大小张量格式并归一化到[0,1]。数据增强在训练时随机裁剪、翻转、改变颜色等增加数据多样性提升模型泛化能力。例如通过随机旋转、裁剪和裁剪增加数据样本的变种避免过拟合。统一处理把多种变换组合成一个pipeline与DataLoader无缝衔接。可以动态地对数据进行处理简化数据加载的复杂度。环境准备pip install torch torchvision导入常用模块import torch from torch.utils.data import Dataset, DataLoader import torchvision.transforms.v2 as transforms # 推荐使用 v2 from PIL import Image如果你习惯旧 API可以通过from torchvision import transforms导入但 v2 的功能更强大且兼容旧用法。1.基础变换操作变换函数名称描述实例transforms.ToTensor()将PIL图像或NumPy数组转换为PyTorch张量并自动将像素值归一化到 [0, 1]。transform transforms.ToTensor()transforms.Normalize(mean, std)对图像进行标准化使数据符合零均值和单位方差。transform transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225])transforms.Resize(size)调整图像尺寸确保输入到网络的图像大小一致。transform transforms.Resize((256, 256))transforms.CenterCrop(size)从图像中心裁剪指定大小的区域。transform transforms.CenterCrop(224)1.1 ToTensor将PIL图形或Numpy数组转换为PyTorch张量。同时将像素值从[0,255] 归一化为[0,1]。from torchvision import transforms transform transforms.ToTensor()1.2 Normalize​ 对数据进行标准化使其符合特定的均值和标准差。通常用于图像数据将其像素值归一化为零均值和单位方差。transform transforms.Normalize(mean[0.5], std[0.5]) # 归一化到 [-1, 1]1.3 Resize调整样本图形的大小确保输入到网络图像大小一致。transformtransforms.Resize((128, 128)) # 将图像调整为 128x1281.4 CenterCrop从图像中心裁剪指定大小的区域。transform transforms.CenterCrop(128) # 裁剪 128x128 的区域2. 数据增强操作Data Augmentation变换函数名称描述实例transforms.RandomHorizontalFlip(p)随机水平翻转图像。transform transforms.RandomHorizontalFlip(p0.5)transforms.RandomRotation(degrees)随机旋转图像。transform transforms.RandomRotation(degrees45)transforms.ColorJitter(brightness, contrast, saturation, hue)调整图像的亮度、对比度、饱和度和色调。transform transforms.ColorJitter(brightness0.2, contrast0.2, saturation0.2, hue0.1)transforms.RandomCrop(size)随机裁剪指定大小的区域。transform transforms.RandomCrop(224)transforms.RandomResizedCrop(size)随机裁剪图像并调整到指定大小。transform transforms.RandomResizedCrop(224)2.1 RandomCrop从图形中随机裁剪指定大小。transform transforms.RandomCrop(128)2.2 RandomHorizontalFlip以一定概率水平翻转图形。transform transforms.RandomHorizontalFlip(p0.5) # 50% 概率翻转2.3 RandomRotation随机将图像旋转一定角度。transform transforms.RandomRotation(degrees30) # 随机旋转 -30 到 30 度2.4 ColorJitter随机改变图像的亮度、对比度、饱和度或色调。transform transforms.ColorJitter(brightness0.5, contrast0.5)3.组合变换​ 这些转换可以通过Compose组合在一起以便对图像进行一系列的转换。Compose类允许你创建一个包含多个转换操作的列表这些操作将按照定义的顺序应用于输入数据中。变换函数名称描述实例transforms.Compose()将多个变换组合在一起按照顺序依次应用。transform transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]), transforms.Resize((256, 256))])通过transforms.Compose将多个变换组合起来。transform transforms.Compose([ transforms.Resize((128, 128)), transforms.RandomHorizontalFlip(p0.5), transforms.ToTensor(), transforms.Normalize(mean[0.5], std[0.5]) ])4.与Dataset和DataLoader配合​ PyTorch的Data set类负责加载数据transform作为参数参入在__getitem__中被调用。class ImageDataset(Dataset): def __init__(self, image_paths, labels, transformNone): self.image_paths image_paths self.labels labels self.transform transform def __len__(self): return len(self.image_paths) def __getitem__(self, idx): img Image.open(self.image_paths[idx]).convert(RGB) label self.labels[idx] if self.transform: img self.transform(img) return img, label创建DataLoadertrain_dataset ImageDataset(train_paths, train_labels, transformtrain_transform) train_loader DataLoader(train_dataset, batch_size32, shuffleTrue)