Merge pull request !5516 from yao_yf/auto_parallel_context_collationtags/v1.0.0
| @@ -42,15 +42,12 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() { | |||
| return inst_context_; | |||
| } | |||
| ParallelContext::ParallelContext() { | |||
| communication_backend_ = HCCL_BACKEND; | |||
| Reset(); | |||
| } | |||
| ParallelContext::ParallelContext() { Reset(); } | |||
| void ParallelContext::Reset() { | |||
| mirror_mean_ = false; | |||
| full_batch_ = false; | |||
| cast_before_mirror_ = true; | |||
| gradient_fp32_sync_ = true; | |||
| loss_repeated_mean_ = true; | |||
| device_num_ = 1; | |||
| global_rank_ = 0; | |||
| @@ -81,14 +78,10 @@ void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_ | |||
| void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } | |||
| void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; } | |||
| void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; } | |||
| void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } | |||
| void ParallelContext::set_communication_backend(const std::string &communication_backend) { | |||
| communication_backend_ = communication_backend; | |||
| } | |||
| bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) { | |||
| auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); | |||
| if (iter == PARALLEL_MODE_LIST.end()) { | |||
| @@ -58,8 +58,8 @@ class ParallelContext { | |||
| void set_full_batch(bool full_batch); | |||
| bool full_batch() const { return full_batch_; } | |||
| void set_cast_before_mirror(bool cast_before_mirror); | |||
| bool cast_before_mirror() const { return cast_before_mirror_; } | |||
| void set_gradient_fp32_sync(bool gradient_fp32_sync); | |||
| bool gradient_fp32_sync() const { return gradient_fp32_sync_; } | |||
| void set_loss_repeated_mean(bool loss_repeated_mean); | |||
| bool loss_repeated_mean() const { return loss_repeated_mean_; } | |||
| @@ -70,9 +70,6 @@ class ParallelContext { | |||
| void set_global_rank(int32_t global_rank); | |||
| int32_t global_rank() const { return global_rank_; } | |||
| void set_communication_backend(const std::string &communication_backend); | |||
| std::string communication_backend() const { return communication_backend_; } | |||
| bool set_parallel_mode(const std::string ¶llel_mode); | |||
| std::string parallel_mode() const { return parallel_mode_; } | |||
| @@ -112,11 +109,10 @@ class ParallelContext { | |||
| static std::shared_ptr<ParallelContext> inst_context_; | |||
| bool mirror_mean_; | |||
| bool full_batch_; | |||
| bool cast_before_mirror_; | |||
| bool gradient_fp32_sync_; | |||
| bool loss_repeated_mean_; | |||
| int32_t device_num_; | |||
| int32_t global_rank_; | |||
| std::string communication_backend_; | |||
| std::string parallel_mode_; | |||
| std::string strategy_search_mode_; | |||
| bool parameter_broadcast_; | |||
| @@ -43,6 +43,7 @@ | |||
| #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" | |||
| #include "utils/comm_manager.h" | |||
| #include "utils/symbolic.h" | |||
| #include "utils/ms_context.h" | |||
| using mindspore::tensor::Tensor; | |||
| @@ -869,8 +870,8 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string & | |||
| } | |||
| bool IsCastBeforMirror(const CNodePtr &node, size_t index) { | |||
| // only if cast_before_mirror is true, pre node is cast and type is not float32 return true | |||
| if (!ParallelContext::GetInstance()->cast_before_mirror()) { | |||
| // only if gradient_fp32_sync is true, pre node is cast and type is not float32 return true | |||
| if (!ParallelContext::GetInstance()->gradient_fp32_sync()) { | |||
| return false; | |||
| } | |||
| auto pre_node = node->input(index); | |||
| @@ -2421,13 +2422,17 @@ Status ParallelInit() { | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| int32_t device_num = ParallelContext::GetInstance()->device_num(); | |||
| int32_t global_rank = ParallelContext::GetInstance()->global_rank(); | |||
| std::string backend = ParallelContext::GetInstance()->communication_backend(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| std::string world_group; | |||
| if (backend == HCCL_BACKEND) { | |||
| std::string communication_backend; | |||
| if (backend == kAscendDevice || backend == kDavinciDevice) { | |||
| world_group = HCCL_WORLD_GROUP; | |||
| } else if (backend == NCCL_BACKEND) { | |||
| communication_backend = HCCL_BACKEND; | |||
| } else if (backend == kGPUDevice) { | |||
| world_group = NCCL_WORLD_GROUP; | |||
| communication_backend = NCCL_BACKEND; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend; | |||
| } | |||
| @@ -2450,14 +2455,14 @@ Status ParallelInit() { | |||
| MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank; | |||
| } | |||
| if (!InitDevice(device_num, global_rank, backend)) { | |||
| if (!InitDevice(device_num, global_rank, communication_backend)) { | |||
| MS_LOG(ERROR) << "Init device failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank | |||
| << ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean() | |||
| << ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror(); | |||
| << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync(); | |||
| return SUCCESS; | |||
| } | |||
| @@ -117,12 +117,10 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.") | |||
| .def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.") | |||
| .def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.") | |||
| .def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.") | |||
| .def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.") | |||
| .def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get cast before mirror.") | |||
| .def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set cast before mirror.") | |||
| .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.") | |||
| .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.") | |||
| .def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.") | |||
| .def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.") | |||
| .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.") | |||
| .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") | |||
| .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") | |||
| @@ -15,7 +15,6 @@ | |||
| """Communication management API""" | |||
| import os | |||
| from mindspore import context | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ | |||
| _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ | |||
| _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ | |||
| @@ -86,9 +85,6 @@ def init(backend_name=None): | |||
| else: | |||
| raise RuntimeError("Backend name {} is not supported.".format(backend_name)) | |||
| auto_parallel_context().set_communication_backend(backend_name) | |||
| def release(): | |||
| """ | |||
| Release distributed resource. e.g., hccl/nccl. | |||
| @@ -323,7 +323,7 @@ def _context(): | |||
| return _k_context | |||
| @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str, | |||
| @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool, parallel_mode=str, | |||
| auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, | |||
| strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) | |||
| def set_auto_parallel_context(**kwargs): | |||
| @@ -343,9 +343,9 @@ def set_auto_parallel_context(**kwargs): | |||
| global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. | |||
| mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. | |||
| "stand_alone" do not support mirror_mean. Default: False. | |||
| cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. | |||
| gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.. | |||
| "stand_alone", "data_parallel" and "hybrid_parallel" do not support | |||
| cast_before_mirror. Default: True. | |||
| gradient_fp32_sync. Default: True. | |||
| parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", | |||
| "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone". | |||
| @@ -381,7 +381,7 @@ def set_auto_parallel_context(**kwargs): | |||
| >>> context.set_auto_parallel_context(device_num=8) | |||
| >>> context.set_auto_parallel_context(global_rank=0) | |||
| >>> context.set_auto_parallel_context(mirror_mean=True) | |||
| >>> context.set_auto_parallel_context(cast_before_mirror=False) | |||
| >>> context.set_auto_parallel_context(gradient_fp32_sync=False) | |||
| >>> context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| >>> context.set_auto_parallel_context(parameter_broadcast=False) | |||
| >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") | |||
| @@ -413,7 +413,7 @@ def reset_auto_parallel_context(): | |||
| - device_num: 1. | |||
| - global_rank: 0. | |||
| - mirror_mean: False. | |||
| - cast_before_mirror: True. | |||
| - gradient_fp32_sync: True. | |||
| - parallel_mode: "stand_alone". | |||
| - parameter_broadcast: False. | |||
| - strategy_ckpt_load_file: "". | |||
| @@ -113,24 +113,24 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_mirror_mean() | |||
| def set_cast_before_mirror(self, cast_before_mirror): | |||
| def set_gradient_fp32_sync(self, gradient_fp32_sync): | |||
| """ | |||
| Set cast_before_mirror. | |||
| Set gradient_fp32_sync. | |||
| Note: | |||
| If cast_before_mirror is true, | |||
| If gradient_fp32_sync is true, | |||
| it will convert tensor type from fp16 to fp32 before parameter gradients allreduce. | |||
| Args: | |||
| cast_before_mirror (bool): The cast_before_mirror flag. | |||
| gradient_fp32_sync (bool): The gradient_fp32_sync flag. | |||
| """ | |||
| self.check_context_handle() | |||
| self._context_handle.set_cast_before_mirror(cast_before_mirror) | |||
| self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync) | |||
| def get_cast_before_mirror(self): | |||
| """Get cast_before_mirror flag.""" | |||
| def get_gradient_fp32_sync(self): | |||
| """Get gradient_fp32_sync flag.""" | |||
| self.check_context_handle() | |||
| return self._context_handle.get_cast_before_mirror() | |||
| return self._context_handle.get_gradient_fp32_sync() | |||
| def set_loss_repeated_mean(self, loss_repeated_mean): | |||
| """ | |||
| @@ -152,21 +152,6 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_loss_repeated_mean() | |||
| def set_communication_backend(self, communication_backend): | |||
| """ | |||
| Set communication backend. | |||
| Args: | |||
| communication_backend (str): The communication backend. | |||
| """ | |||
| self.check_context_handle() | |||
| self._context_handle.set_communication_backend(communication_backend) | |||
| def get_communication_backend(self): | |||
| """Get communication backend.""" | |||
| self.check_context_handle() | |||
| return self._context_handle.get_communication_backend() | |||
| def set_parallel_mode(self, parallel_mode): | |||
| """ | |||
| Set parallel mode for auto parallel. | |||
| @@ -469,7 +454,7 @@ _set_auto_parallel_context_func_map = { | |||
| "device_num": auto_parallel_context().set_device_num, | |||
| "global_rank": auto_parallel_context().set_global_rank, | |||
| "mirror_mean": auto_parallel_context().set_mirror_mean, | |||
| "cast_before_mirror": auto_parallel_context().set_cast_before_mirror, | |||
| "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync, | |||
| "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean, | |||
| "parallel_mode": auto_parallel_context().set_parallel_mode, | |||
| "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode, | |||
| @@ -484,7 +469,7 @@ _get_auto_parallel_context_func_map = { | |||
| "device_num": auto_parallel_context().get_device_num, | |||
| "global_rank": auto_parallel_context().get_global_rank, | |||
| "mirror_mean": auto_parallel_context().get_mirror_mean, | |||
| "cast_before_mirror": auto_parallel_context().get_cast_before_mirror, | |||
| "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync, | |||
| "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean, | |||
| "parallel_mode": auto_parallel_context().get_parallel_mode, | |||
| "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode, | |||
| @@ -495,7 +480,7 @@ _get_auto_parallel_context_func_map = { | |||
| "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer} | |||
| @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, | |||
| @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool, | |||
| loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, | |||
| parameter_broadcast=bool, strategy_ckpt_load_file=str, | |||
| strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) | |||
| @@ -512,8 +497,9 @@ def _set_auto_parallel_context(**kwargs): | |||
| global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. | |||
| mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False. | |||
| loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated | |||
| calculations. Default: True. | |||
| cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True. | |||
| calculations. Default: True. | |||
| gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True. | |||
| Default: True. | |||
| parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", | |||
| "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone". | |||
| @@ -577,7 +563,7 @@ def _reset_auto_parallel_context(): | |||
| - device_num: 1. | |||
| - global_rank: 0. | |||
| - mirror_mean: False. | |||
| - cast_before_mirror: True. | |||
| - gradient_fp32_sync: True. | |||
| - parallel_mode: "stand_alone". | |||
| - parameter_broadcast: False. | |||
| - strategy_ckpt_load_file: "" | |||
| @@ -61,7 +61,7 @@ def get_rank_id(group=None): | |||
| def get_rank_size(group=None): | |||
| hccl = Hccl() | |||
| if group is None: | |||
| if group is None or "nccl_world_group" in group: | |||
| return hccl.rank_size | |||
| if isinstance(group, str): | |||
| return int(group.split("-")[0]) | |||
| @@ -830,7 +830,7 @@ def test_matmul_cast(): | |||
| compile_net(net, x, y, b) | |||
| def test_cast_before_mirror(): | |||
| def test_gradient_fp32_sync(): | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1): | |||
| super().__init__() | |||
| @@ -843,7 +843,7 @@ def test_cast_before_mirror(): | |||
| out = self.matmul(out, b) | |||
| return out | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True) | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True) | |||
| strategy1 = ((2, 2), (2, 2)) | |||
| net = GradWrap(NetWithLoss(Net(strategy1))) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| @@ -854,7 +854,7 @@ def test_cast_before_mirror(): | |||
| compile_net(net, x, y, b) | |||
| def test_cast_before_mirror1(): | |||
| def test_gradient_fp32_sync1(): | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1): | |||
| super().__init__() | |||
| @@ -867,7 +867,7 @@ def test_cast_before_mirror1(): | |||
| out = self.matmul(out, b) | |||
| return out | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True) | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True) | |||
| strategy1 = ((2, 2), (2, 2)) | |||
| net = GradWrap(NetWithLoss(Net(strategy1))) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| @@ -878,7 +878,7 @@ def test_cast_before_mirror1(): | |||
| compile_net(net, x, y, b) | |||
| def test_cast_before_mirror2(): | |||
| def test_gradient_fp32_sync2(): | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1): | |||
| super().__init__() | |||
| @@ -891,7 +891,7 @@ def test_cast_before_mirror2(): | |||
| out = self.matmul(out, b) | |||
| return out | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=False) | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=False) | |||
| strategy1 = ((2, 2), (2, 2)) | |||
| net = GradWrap(NetWithLoss(Net(strategy1))) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| @@ -902,7 +902,7 @@ def test_cast_before_mirror2(): | |||
| compile_net(net, x, y, b) | |||
| def test_cast_before_mirror3(): | |||
| def test_gradient_fp32_sync3(): | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1): | |||
| super().__init__() | |||
| @@ -20,25 +20,21 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| def test_set_auto_parallel_context(): | |||
| context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, cast_before_mirror=False, | |||
| context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, gradient_fp32_sync=False, | |||
| parallel_mode="auto_parallel", parameter_broadcast=False) | |||
| device_num = context.get_auto_parallel_context("device_num") | |||
| global_rank = context.get_auto_parallel_context("global_rank") | |||
| mirror_mean = context.get_auto_parallel_context("mirror_mean") | |||
| cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror") | |||
| gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync") | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast") | |||
| assert device_num == 4 | |||
| assert global_rank == 3 | |||
| assert mirror_mean | |||
| assert not cast_before_mirror | |||
| assert not gradient_fp32_sync | |||
| assert parallel_mode == "auto_parallel" | |||
| assert not parameter_broadcast | |||
| auto_parallel_context().set_communication_backend("hccl") | |||
| backend = auto_parallel_context().get_communication_backend() | |||
| assert backend == "hccl" | |||
| auto_parallel_context().set_device_num(4) | |||
| device_num = auto_parallel_context().get_device_num() | |||
| device_num_is_set = auto_parallel_context().get_device_num_is_set() | |||
| @@ -53,9 +49,9 @@ def test_set_auto_parallel_context(): | |||
| mirror_mean = auto_parallel_context().get_mirror_mean() | |||
| assert mirror_mean | |||
| auto_parallel_context().set_cast_before_mirror(False) | |||
| cast_before_mirror = auto_parallel_context().get_cast_before_mirror() | |||
| assert not cast_before_mirror | |||
| auto_parallel_context().set_gradient_fp32_sync(False) | |||
| gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync() | |||
| assert not gradient_fp32_sync | |||
| parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set() | |||
| assert parameter_broadcast_is_set | |||
| @@ -91,7 +87,7 @@ def test_reset_auto_parallel_context(): | |||
| device_num = context.get_auto_parallel_context("device_num") | |||
| global_rank = context.get_auto_parallel_context("global_rank") | |||
| mirror_mean = context.get_auto_parallel_context("mirror_mean") | |||
| cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror") | |||
| gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync") | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast") | |||
| device_num_is_set = auto_parallel_context().get_device_num_is_set() | |||
| @@ -99,7 +95,7 @@ def test_reset_auto_parallel_context(): | |||
| assert device_num == 1 | |||
| assert global_rank == 0 | |||
| assert not mirror_mean | |||
| assert cast_before_mirror | |||
| assert gradient_fp32_sync | |||
| assert parallel_mode == "stand_alone" | |||
| assert not parameter_broadcast | |||
| assert not device_num_is_set | |||