diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index a434c2c21f..365b084c5b 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -111,7 +111,7 @@ if __name__ == '__main__': if args_opt.net == "resnet50" or args_opt.net == "se-resnet50": context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160]) elif args_opt.net == "resnet101": - context.set_auto_parallel_context(all_reduce_fusion_config=[180, 313]) + context.set_auto_parallel_context(all_reduce_fusion_config=[80, 210, 313]) init() # GPU target else: