|
|
|
@@ -81,11 +81,11 @@ if __name__ == '__main__': |
|
|
|
init() |
|
|
|
# GPU target |
|
|
|
else: |
|
|
|
init("nccl") |
|
|
|
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
mirror_mean=True) |
|
|
|
if args_opt.net == "resnet50": |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160]) |
|
|
|
init("nccl") |
|
|
|
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" |
|
|
|
|
|
|
|
# create dataset |
|
|
|
|