import os

from mpi4py import MPI
# Ranks are mapped onto GPU indices modulo this value.
# NOTE(review): hard-coded in the original script (8 GPUs per node) — adjust
# to the target machine if it differs.
GPUS_PER_NODE = 8


def _shard_bounds(total, rank, world_size):
    """Return the half-open ``(start, end)`` range of items owned by ``rank``.

    The ``total`` items are split into ``world_size`` contiguous shards whose
    sizes differ by at most one, so every item belongs to exactly one rank and
    no rank is idle unless ``total < world_size``. For example, 68750 items
    over 24 ranks gives 14 shards of 2865 items and 10 shards of 2864.
    """
    start = total * rank // world_size
    end = total * (rank + 1) // world_size
    return start, end


def main(args):
    """Run the model over this MPI rank's shard of the dataset and save outputs.

    Every process started by ``mpirun`` executes this function. Each one
    builds its own model on GPU index ``rank % GPUS_PER_NODE`` and handles
    only its contiguous slice of the dataset.

    Args:
        args: Parsed command-line arguments. Not read by this function
            itself in the visible code.
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    proc_num = comm.Get_size()
    # With more ranks than GPUs, several ranks share one device.
    gpu_rank = rank % GPUS_PER_NODE

    # Instantiate the model, then move it to this rank's GPU.
    # NOTE(review): assumes build_model() returns a torch-style module with
    # a .to(device) method — confirm against its definition.
    model = build_model().to(gpu_rank)

    # The dataset object must support len() and slicing.
    datasets = get_datasets()
    start, end = _shard_bounds(len(datasets), rank, proc_num)
    cur_sub_datasets = datasets[start:end]

    # Iterate only this rank's shard; looping over the full ``datasets``
    # would make every rank redo (and re-save) the whole dataset.
    for data in cur_sub_datasets:
        out = model(data)
        save(out)


if __name__ == '__main__':
    # Total dataset: 68750.
    args = parse_args()
    # Every rank runs this concurrently. With check-then-mkdir, two ranks can
    # both see the directory missing and one os.mkdir then raises
    # FileExistsError; exist_ok=True makes the creation idempotent.
    os.makedirs(args.output_root_path, exist_ok=True)
    main(args)
mpirun -n 24 python main.py  # start 24 MPI processes (ranks 0-23), each running main.py
Reference
- Semantic-Segment-Anything