diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index 69b2526f47..b820bace54 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -27,6 +27,7 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.communication.management import init, get_rank, get_group_size from mindspore.common import set_seed +from mindspore.parallel import set_algo_parameters import mindspore.nn as nn import mindspore.common.initializer as weight_init from src.lr_generator import get_lr, warmup_cosine_annealing_lr @@ -82,6 +83,7 @@ if __name__ == '__main__': context.set_context(device_id=device_id, enable_auto_mixed_precision=True) context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) + set_algo_parameters(elementwise_op_strategy_follow=True) if args_opt.net == "resnet50" or args_opt.net == "se-resnet50": context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160]) else: