torch.randperm
是 PyTorch 中用于生成随机排列的整数序列的函数,广泛应用于数据随机化、样本抽样等场景。
功能与用途
torch.randperm(n)
会生成一个从 0
到 n-1
的随机排列的整数序列,返回一个一维张量。这个函数在以下场景中非常有用:
-
随机打乱数据:在机器学习中,常用于打乱数据集的顺序,以提高模型的泛化能力。
-
样本抽样:可以用于从数据集中随机抽取样本。
参数
-
n
(int):生成的随机排列的长度,范围是[0, n-1]
。 -
generator
(torch.Generator, 可选):用于生成随机数的伪随机数生成器。 -
device
(torch.device, 可选):指定生成的张量所在的设备(如 CPU 或 GPU)。默认为当前默认设备。 -
dtype
(torch.dtype, 可选):返回张量的数据类型,默认为torch.int64
。 -
requires_grad
(bool, 可选):是否需要记录梯度,默认为False
。
返回值
返回一个一维张量,包含从 0
到 n-1
的随机排列。
使用示例
生成随机排列
import torch# 生成一个长度为 10 的随机排列
random_perm = torch.randperm(10)
print(random_perm) # 输出类似:tensor([3, 8, 1, 5, 9, 0, 4, 7, 2, 6])
打乱数据集
data = torch.randn(10, 3, 224, 224) # 假设是一个图像数据集
labels = torch.randint(0, 2, (10,)) # 随机生成标签# 生成随机索引
indices = torch.randperm(data.size(0))# 打乱数据和标签
shuffled_data = data[indices]
shuffled_labels = labels[indices]print(shuffled_data.shape) # 输出:torch.Size([10, 3, 224, 224])
print(shuffled_labels)
在 GPU 上生成随机排列
random_perm = torch.randperm(10, device='cuda')
print(random_perm)
高级用法
可以通过指定 generator
参数来控制随机数生成器的状态,确保结果的可复现性:
generator = torch.Generator().manual_seed(42)
random_perm = torch.randperm(10, generator=generator)
print(random_perm) # 输出固定结果
torch.randperm
是一个简单而强大的工具,适用于需要随机排列或打乱数据的场景。