数据通常不会直接是机器学习算法可以使用的“最终格式”。我们使用转换(transforms)来对数据进行处理,使其适合训练。
所有的 TorchVision 数据集都提供了两个参数:transform
用于修改特征,target_transform
用于修改标签,它们都接收包含转换逻辑的可调用对象(函数或类)。
torchvision.transforms
模块提供了几种常用的转换,开箱即用。
数据通常不会直接是机器学习算法可以使用的“最终格式”。我们使用**转换(transforms)**来对数据进行处理,使其适合训练。
FashionMNIST 的特征是 PIL 图像格式,标签是整数。为了训练,我们需要把特征转换成归一化的张量,标签转换成one-hot 编码的张量。为了实现这些转换,我们使用 ToTensor
和 Lambda
。
数据转换(Transforms) 是用来对原始数据进行处理的步骤,常见的处理包括:
-
转换图像格式
-
缩放、裁剪、归一化等
-
转换标签格式
transform
和 target_transform
:
-
transform
是对图像(或特征)做的预处理(如转张量、归一化)。 -
target_transform
是对标签做的预处理(如 one-hot 编码)。
常用转换:
torchvision.transforms
模块提供了很多常用的转换操作,比如:
-
ToTensor()
:把图像转换为张量。 -
Normalize()
:归一化操作。 -
Resize()
:调整图片尺寸。 -
RandomCrop()
:随机裁剪图片。 -
FashionMNIST:
-
图片格式是 PIL Image,我们需要把它转成 Tensor。
-
标签是整数形式,我们通常需要把它们转换成 one-hot 编码,这样模型可以更好地进行分类。
-
import torch
from torchvision import datasets # 导入数据集模块
from torchvision.transforms import ToTensor, Lambda # 导入 ToTensor 转换和自定义 Lambda 转换# 加载 FashionMNIST 数据集
ds = datasets.FashionMNIST(root="data", # 数据集保存的根目录,数据将下载到这个目录下train=True, # 加载训练集(如果为 False,则加载测试集)download=True, # 如果本地没有数据集,设置为 True 会自动从网络下载transform=ToTensor(), # 对图片进行转换:将PIL图像转换成Tensor格式,且像素值归一化到[0, 1]target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) # 将标签转换成one-hot编码格式
)
-
加载 FashionMNIST 数据集的训练集。
-
把每张图像转换为张量,并归一化到
[0, 1]
范围。 -
把每个标签转换为 one-hot 编码格式。
-
Lambda
:定义了一个自定义的转换逻辑,在这里将标签转换为 one-hot 编码。 -
one-hot 编码:假设标签是数字
3
,它会变成一个长度为 10 的张量,只有索引 3 的位置为 1,其他位置为 0,比如:[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]。
ToTensor()
ToTensor()
将 PIL 图像或 NumPy ndarray 转换成一个 FloatTensor,并将图像的像素值缩放到 [0., 1.] 范围。
-
PIL 图像:就是 Python 图像库(Pillow)里的图片格式(比如
.jpg
、.png
格式的图片)。 -
NumPy ndarray:NumPy 中的多维数组,一般用于存储数字数据。
-
FloatTensor:PyTorch 中的数据结构,表示一个浮点型的张量,适合进行数学计算。
-
缩放到 [0, 1]:图像像素值通常在 [0, 255] 范围内,
ToTensor()
会自动把它们缩放到 [0, 1],这有助于模型训练,因为大部分模型训练时需要输入值在 0 到 1 的范围
Lambda 转换
Lambda
转换应用任何用户自定义的 lambda 函数。在这里,我们定义了一个函数,将整数标签转换成 one-hot 编码的张量。首先它创建了一个大小为 10 的零张量(我们的数据集中有 10 个标签),然后调用 scatter_
方法,将标签 y
所在的索引位置赋值为 1。
target_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))