Merge pull request !6483 from gziyan/rm——parameter_broadcasttags/v1.0.0
| @@ -15,12 +15,14 @@ | |||||
| """Utils of auto parallel""" | """Utils of auto parallel""" | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import log as logger | |||||
| from mindspore._c_expression import reset_op_id | from mindspore._c_expression import reset_op_id | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common.dtype import dtype_to_nptype | from mindspore.common.dtype import dtype_to_nptype | ||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore.communication.management import get_group_size, get_rank | from mindspore.communication.management import get_group_size, get_rank | ||||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | from mindspore.parallel._auto_parallel_context import auto_parallel_context | ||||
| from mindspore.common.seed import get_seed | |||||
| def _get_parallel_mode(): | def _get_parallel_mode(): | ||||
| @@ -136,16 +138,11 @@ def _get_global_rank(): | |||||
def _get_parameter_broadcast():
    """Get the parameter-broadcast flag from the auto-parallel context.

    Returns:
        bool: whether parameters should be broadcast from device 0 to the
        other devices before training.

    Side effect:
        Logs a warning when running in "data_parallel" or "hybrid_parallel"
        mode with broadcast disabled and no global seed set, because in that
        situation each device may initialize its parameters differently.
    """
    parallel_mode = auto_parallel_context().get_parallel_mode()
    parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
    # BUG FIX: the original condition used `get_seed is None`, which compares
    # the imported *function object* to None and is therefore always False —
    # the warning below could never fire. Calling get_seed() checks whether a
    # global seed has actually been configured via mindspore.common.set_seed().
    if parallel_mode in ("data_parallel", "hybrid_parallel") \
            and parameter_broadcast is False and get_seed() is None:
        logger.warning("You are suggested to use mindspore.common.set_seed() to share"
                       " parameters among devices.")
    return parameter_broadcast
| @@ -268,7 +268,7 @@ def train(cloud_args=None): | |||||
| loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) | loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) | ||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, | context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, | ||||
| parameter_broadcast=True, gradients_mean=True) | |||||
| gradients_mean=True) | |||||
| model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O3") | model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O3") | ||||
| # checkpoint save | # checkpoint save | ||||
| @@ -54,7 +54,7 @@ if __name__ == '__main__': | |||||
| rank = args_opt.rank_id | rank = args_opt.rank_id | ||||
| device_num = args_opt.device_num | device_num = args_opt.device_num | ||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| gradients_mean=True, parameter_broadcast=True) | |||||
| gradients_mean=True) | |||||
| init() | init() | ||||
| else: | else: | ||||
| rank = 0 | rank = 0 | ||||
| @@ -58,7 +58,7 @@ if __name__ == '__main__': | |||||
| cfg.group_size = get_group_size() | cfg.group_size = get_group_size() | ||||
| parallel_mode = ParallelMode.DATA_PARALLEL | parallel_mode = ParallelMode.DATA_PARALLEL | ||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, | context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, | ||||
| parameter_broadcast=True, gradients_mean=True) | |||||
| gradients_mean=True) | |||||
| else: | else: | ||||
| cfg.rank = 0 | cfg.rank = 0 | ||||
| cfg.group_size = 1 | cfg.group_size = 1 | ||||
| @@ -59,7 +59,7 @@ if __name__ == '__main__': | |||||
| rank = args_opt.rank_id | rank = args_opt.rank_id | ||||
| device_num = args_opt.device_num | device_num = args_opt.device_num | ||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| gradients_mean=True, parameter_broadcast=True) | |||||
| gradients_mean=True) | |||||
| init() | init() | ||||
| else: | else: | ||||
| rank = 0 | rank = 0 | ||||
| @@ -49,7 +49,7 @@ def context_device_init(config): | |||||
| if config.run_distribute: | if config.run_distribute: | ||||
| context.set_auto_parallel_context(device_num=config.rank_size, | context.set_auto_parallel_context(device_num=config.rank_size, | ||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| parameter_broadcast=True, gradients_mean=True, | |||||
| gradients_mean=True, | |||||
| all_reduce_fusion_config=[140]) | all_reduce_fusion_config=[140]) | ||||
| init() | init() | ||||
| else: | else: | ||||
| @@ -76,7 +76,6 @@ def train_on_ascend(): | |||||
| if run_distribute: | if run_distribute: | ||||
| context.set_auto_parallel_context(device_num=rank_size, | context.set_auto_parallel_context(device_num=rank_size, | ||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| parameter_broadcast=True, | |||||
| gradients_mean=True) | gradients_mean=True) | ||||
| init() | init() | ||||
| @@ -74,7 +74,6 @@ if __name__ == '__main__': | |||||
| if run_distribute: | if run_distribute: | ||||
| context.set_auto_parallel_context(device_num=rank_size, | context.set_auto_parallel_context(device_num=rank_size, | ||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| parameter_broadcast=True, | |||||
| gradients_mean=True) | gradients_mean=True) | ||||
| init() | init() | ||||
| context.set_auto_parallel_context(device_num=args_opt.device_num, | context.set_auto_parallel_context(device_num=args_opt.device_num, | ||||
| @@ -178,7 +178,7 @@ def test(cloud_args=None): | |||||
| if args.is_distributed: | if args.is_distributed: | ||||
| parallel_mode = ParallelMode.DATA_PARALLEL | parallel_mode = ParallelMode.DATA_PARALLEL | ||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, | context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, | ||||
| parameter_broadcast=True, gradients_mean=True) | |||||
| gradients_mean=True) | |||||
| args.logger.save_args(args) | args.logger.save_args(args) | ||||
| @@ -200,7 +200,7 @@ def train(cloud_args=None): | |||||
| if args.is_distributed: | if args.is_distributed: | ||||
| parallel_mode = ParallelMode.DATA_PARALLEL | parallel_mode = ParallelMode.DATA_PARALLEL | ||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, | context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, | ||||
| parameter_broadcast=True, gradients_mean=True) | |||||
| gradients_mean=True) | |||||
| # dataloader | # dataloader | ||||
| de_dataset = classification_dataset(args.data_dir, args.image_size, | de_dataset = classification_dataset(args.data_dir, args.image_size, | ||||
| args.per_batch_size, 1, | args.per_batch_size, 1, | ||||
| @@ -51,7 +51,6 @@ def train_net(data_dir, | |||||
| parallel_mode = ParallelMode.DATA_PARALLEL | parallel_mode = ParallelMode.DATA_PARALLEL | ||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, | context.set_auto_parallel_context(parallel_mode=parallel_mode, | ||||
| device_num=group_size, | device_num=group_size, | ||||
| parameter_broadcast=True, | |||||
| gradients_mean=False) | gradients_mean=False) | ||||
| net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) | net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) | ||||
| @@ -139,7 +139,7 @@ if __name__ == '__main__': | |||||
| device_num = args.group_size | device_num = args.group_size | ||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| parameter_broadcast=True, gradients_mean=True) | |||||
| gradients_mean=True) | |||||
| else: | else: | ||||
| context.set_context(device_id=args.device_id) | context.set_context(device_id=args.device_id) | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | ||||
| @@ -254,7 +254,6 @@ def _setup_parallel_env(platform): | |||||
| context.set_auto_parallel_context( | context.set_auto_parallel_context( | ||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| device_num=MultiAscend.get_group_size(), | device_num=MultiAscend.get_group_size(), | ||||
| parameter_broadcast=True, | |||||
| gradients_mean=True | gradients_mean=True | ||||
| ) | ) | ||||
| @@ -123,7 +123,7 @@ def run_transformer_train(): | |||||
| device_num = args.device_num | device_num = args.device_num | ||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | ||||
| parameter_broadcast=True, device_num=device_num) | |||||
| device_num=device_num) | |||||
| D.init() | D.init() | ||||
| rank_id = args.device_id % device_num | rank_id = args.device_id % device_num | ||||
| save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_' + str(get_rank()) + '/') | save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_' + str(get_rank()) + '/') | ||||