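The OpenFoldDataModule class in the AlphaFold3 data_modules module is one of the core classes in AlphaFold/OpenFold for organizing the data-loading pipeline. It is built on PyTorch Lightning's pl.LightningDataModule and overrides methods such as __init__, setup, train_dataloader, val_dataloader, and test_dataloader. Its main job is to turn the different data sources (training, distillation, validation, prediction) into PyTorch datasets and to produce the corresponding DataLoader for each.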
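The division of labor described above (the constructor only records paths, setup builds one dataset per configured source, and the *_dataloader hooks hand out batches) can be illustrated with a minimal sketch. This is not OpenFold's implementation: ToyDataModule, its batch_size argument, and the fake sample names are hypothetical, and the class is written in plain Python rather than subclassing pl.LightningDataModule so that it runs without PyTorch installed.

```python
from typing import Dict, List, Optional


class ToyDataModule:
    """Hypothetical plain-Python stand-in for a LightningDataModule subclass."""

    def __init__(
        self,
        train_data_dir: Optional[str] = None,
        val_data_dir: Optional[str] = None,
        predict_data_dir: Optional[str] = None,
        batch_size: int = 2,  # assumption: not a parameter of the real class
    ):
        # The constructor only stores its arguments; no data is touched here.
        self.train_data_dir = train_data_dir
        self.val_data_dir = val_data_dir
        self.predict_data_dir = predict_data_dir
        self.batch_size = batch_size
        self.datasets: Dict[str, List[str]] = {}

    def setup(self, stage: Optional[str] = None) -> None:
        # Build one dataset per data source that was actually configured.
        sources = {
            "train": self.train_data_dir,
            "val": self.val_data_dir,
            "predict": self.predict_data_dir,
        }
        for name, path in sources.items():
            if path is not None:
                # Fake four samples instead of reading structure files from disk.
                self.datasets[name] = [f"{path}/sample_{i}" for i in range(4)]

    def _batches(self, name: str) -> List[List[str]]:
        # Stand-in for a DataLoader: split the dataset into fixed-size batches.
        samples = self.datasets[name]
        step = self.batch_size
        return [samples[i:i + step] for i in range(0, len(samples), step)]

    def train_dataloader(self) -> List[List[str]]:
        return self._batches("train")

    def val_dataloader(self) -> List[List[str]]:
        return self._batches("val")


dm = ToyDataModule(train_data_dir="train_dir", val_data_dir="val_dir")
dm.setup()
print(sorted(dm.datasets))       # ['train', 'val']
print(dm.train_dataloader()[0])  # ['train_dir/sample_0', 'train_dir/sample_1']
```

Because predict_data_dir is left as None, no "predict" dataset is built. This mirrors why most path arguments in the real constructor are Optional: a given run only needs the sources relevant to it.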
Source code:
class OpenFoldDataModule(pl.LightningDataModule):
    def __init__(
        self,
        config: mlc.ConfigDict,
        template_mmcif_dir: str,
        max_template_date: str,
        train_data_dir: Optional[str] = None,
        train_alignment_dir: Optional[str] = None,
        train_chain_data_cache_path: Optional[str] = None,
        distillation_data_dir: Optional[str] = None,
        distillation_alignment_dir: Optional[str] = None,
        distillation_chain_data_cache_path: Optional[str] = None,
        val_data_dir: Optional[str] = None,
        val_alignment_dir: Optional[str] = None,
        predict_data_dir: Optional[str] = None,
        predict_alignment_dir: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        train_filter_path: Optional[str] = None,
        distillation_filter_path: Optional[str] = None,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        batch_seed: Optional[int] = None,
        train_epoch_len: int = 50000,
        _distillation_structure_index_path: Optional[str] = None,
        alignment_index_path: Optional[str] = None,
        distillation_alignment_index_path: Optional[str] = None,
        **kwargs
    ):
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.template_mmcif_dir = template_mmcif_dir
        self.max_template_date = max_template_date
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.train_chain_data_cache_path = train_chain_data_cache_path
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.distillation_chain_data_cache_path = (
            distillation_chain_data_cache_path
        )
        self.val_data_dir = val_data_dir
        self.val_alignment_dir = val_alignment_dir
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_pat