AlphaFold3 data_modules 模块的 OpenFoldDataset
类的 reroll
方法动态确定当前 epoch 需要使用的数据点,确保每次训练时的样本 不是固定的,而是 基于概率重新采样。改变 self.datapoints
,即 改变 __getitem__
方法取数据的方式,从而实现 动态数据增强。在 训练过程中可以多次调用 reroll()
来 重新采样,确保每个 epoch 训练数据的多样性。
源代码:
def reroll(self):dataset_choices = torch.multinomial(torch.tensor(self.probabilities),num_samples=self.epoch_len,replacement=True,generator=self.generator,)self.datapoints = []for dataset_idx in dataset_choices:samples = self._samples[dataset_idx]datapoint_idx = next(samples)self.datapoints.append((dataset_idx, datapoint_idx))
<