Merge pull request !5516 from yao_yf/auto_parallel_context_collation
@@ -42,15 +42,12 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
   return inst_context_;
 }
 
-ParallelContext::ParallelContext() {
-  communication_backend_ = HCCL_BACKEND;
-  Reset();
-}
+ParallelContext::ParallelContext() { Reset(); }
 
 void ParallelContext::Reset() {
   mirror_mean_ = false;
   full_batch_ = false;
-  cast_before_mirror_ = true;
+  gradient_fp32_sync_ = true;
   loss_repeated_mean_ = true;
   device_num_ = 1;
   global_rank_ = 0;
@@ -81,14 +78,10 @@ void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_
 void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
 
-void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; }
+void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
 
 void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
 
-void ParallelContext::set_communication_backend(const std::string &communication_backend) {
-  communication_backend_ = communication_backend;
-}
-
 bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
   auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
   if (iter == PARALLEL_MODE_LIST.end()) {
@@ -58,8 +58,8 @@ class ParallelContext {
   void set_full_batch(bool full_batch);
   bool full_batch() const { return full_batch_; }
-  void set_cast_before_mirror(bool cast_before_mirror);
-  bool cast_before_mirror() const { return cast_before_mirror_; }
+  void set_gradient_fp32_sync(bool gradient_fp32_sync);
+  bool gradient_fp32_sync() const { return gradient_fp32_sync_; }
   void set_loss_repeated_mean(bool loss_repeated_mean);
   bool loss_repeated_mean() const { return loss_repeated_mean_; }
@@ -70,9 +70,6 @@ class ParallelContext {
   void set_global_rank(int32_t global_rank);
   int32_t global_rank() const { return global_rank_; }
-  void set_communication_backend(const std::string &communication_backend);
-  std::string communication_backend() const { return communication_backend_; }
   bool set_parallel_mode(const std::string &parallel_mode);
   std::string parallel_mode() const { return parallel_mode_; }
@@ -112,11 +109,10 @@ class ParallelContext {
   static std::shared_ptr<ParallelContext> inst_context_;
   bool mirror_mean_;
   bool full_batch_;
-  bool cast_before_mirror_;
+  bool gradient_fp32_sync_;
   bool loss_repeated_mean_;
   int32_t device_num_;
   int32_t global_rank_;
-  std::string communication_backend_;
   std::string parallel_mode_;
   std::string strategy_search_mode_;
   bool parameter_broadcast_;
@@ -43,6 +43,7 @@
 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "utils/comm_manager.h"
 #include "utils/symbolic.h"
+#include "utils/ms_context.h"
 
 using mindspore::tensor::Tensor;
@@ -869,8 +870,8 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
 }
 
 bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
-  // only if cast_before_mirror is true, pre node is cast and type is not float32 return true
-  if (!ParallelContext::GetInstance()->cast_before_mirror()) {
+  // return true only if gradient_fp32_sync is true, the pre node is a cast and its type is not float32
+  if (!ParallelContext::GetInstance()->gradient_fp32_sync()) {
     return false;
   }
   auto pre_node = node->input(index);
@@ -2421,13 +2422,17 @@ Status ParallelInit() {
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   int32_t device_num = ParallelContext::GetInstance()->device_num();
   int32_t global_rank = ParallelContext::GetInstance()->global_rank();
-  std::string backend = ParallelContext::GetInstance()->communication_backend();
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
   std::string world_group;
-  if (backend == HCCL_BACKEND) {
+  std::string communication_backend;
+  if (backend == kAscendDevice || backend == kDavinciDevice) {
     world_group = HCCL_WORLD_GROUP;
-  } else if (backend == NCCL_BACKEND) {
+    communication_backend = HCCL_BACKEND;
+  } else if (backend == kGPUDevice) {
     world_group = NCCL_WORLD_GROUP;
+    communication_backend = NCCL_BACKEND;
   } else {
     MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
   }
@@ -2450,14 +2455,14 @@ Status ParallelInit() {
     MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
   }
 
-  if (!InitDevice(device_num, global_rank, backend)) {
+  if (!InitDevice(device_num, global_rank, communication_backend)) {
     MS_LOG(ERROR) << "Init device failed";
     return FAILED;
   }
 
   MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
                << ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean()
-               << ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror();
+               << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
 
   return SUCCESS;
 }
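Note: with this change ParallelInit no longer reads a backend stored on ParallelContext; it derives both the world group and the communication backend from the device target held by MsContext. A minimal Python sketch of that mapping follows. It is illustrative only; the literal string values ("Ascend", "Davinci", "GPU", "hccl", "nccl" and the *_world_group names) are assumed from the constants referenced above, not quoted from this diff.

    # Hedged sketch of the new backend-selection rule (not the actual C++ code).
    # The concrete string values are assumptions based on the constants used above.
    def select_backend(device_target):
        """Map a device target to (world_group, communication_backend)."""
        if device_target in ("Ascend", "Davinci"):   # kAscendDevice / kDavinciDevice
            return "hccl_world_group", "hccl"        # HCCL_WORLD_GROUP, HCCL_BACKEND
        if device_target == "GPU":                   # kGPUDevice
            return "nccl_world_group", "nccl"        # NCCL_WORLD_GROUP, NCCL_BACKEND
        raise RuntimeError("Invalid communication backend: " + device_target)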
@@ -117,12 +117,10 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
     .def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.")
     .def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.")
-    .def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.")
-    .def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.")
+    .def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get gradient fp32 sync.")
+    .def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set gradient fp32 sync.")
     .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.")
     .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
-    .def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.")
-    .def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.")
     .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
     .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
     .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
@@ -15,7 +15,6 @@
 """Communication management API"""
 import os
 from mindspore import context
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
     _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
     _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
@@ -86,9 +85,6 @@ def init(backend_name=None):
     else:
         raise RuntimeError("Backend name {} is not supported.".format(backend_name))
 
-    auto_parallel_context().set_communication_backend(backend_name)
-
 
 def release():
     """
     Release distributed resource. e.g., hccl/nccl.
@@ -323,7 +323,7 @@ def _context():
     return _k_context
 
 
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
+@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                  auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
 def set_auto_parallel_context(**kwargs):
@@ -343,9 +343,9 @@ def set_auto_parallel_context(**kwargs):
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
         mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
                      "stand_alone" do not support mirror_mean. Default: False.
-        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True.
+        gradient_fp32_sync (bool): Gradients are allreduced in fp32 even if they are fp16, when this flag is True.
                      "stand_alone", "data_parallel" and "hybrid_parallel" do not support
-                     cast_before_mirror. Default: True.
+                     gradient_fp32_sync. Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                      "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
@@ -381,7 +381,7 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(device_num=8)
         >>> context.set_auto_parallel_context(global_rank=0)
         >>> context.set_auto_parallel_context(mirror_mean=True)
-        >>> context.set_auto_parallel_context(cast_before_mirror=False)
+        >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
@@ -413,7 +413,7 @@ def reset_auto_parallel_context():
         - device_num: 1.
         - global_rank: 0.
         - mirror_mean: False.
-        - cast_before_mirror: True.
+        - gradient_fp32_sync: True.
         - parallel_mode: "stand_alone".
         - parameter_broadcast: False.
         - strategy_ckpt_load_file: "".
@@ -113,24 +113,24 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_mirror_mean()
 
-    def set_cast_before_mirror(self, cast_before_mirror):
+    def set_gradient_fp32_sync(self, gradient_fp32_sync):
         """
-        Set cast_before_mirror.
+        Set gradient_fp32_sync.
 
         Note:
-            If cast_before_mirror is true,
+            If gradient_fp32_sync is true,
             it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
 
         Args:
-            cast_before_mirror (bool): The cast_before_mirror flag.
+            gradient_fp32_sync (bool): The gradient_fp32_sync flag.
         """
         self.check_context_handle()
-        self._context_handle.set_cast_before_mirror(cast_before_mirror)
+        self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)
 
-    def get_cast_before_mirror(self):
-        """Get cast_before_mirror flag."""
+    def get_gradient_fp32_sync(self):
+        """Get gradient_fp32_sync flag."""
         self.check_context_handle()
-        return self._context_handle.get_cast_before_mirror()
+        return self._context_handle.get_gradient_fp32_sync()
 
     def set_loss_repeated_mean(self, loss_repeated_mean):
         """
@@ -152,21 +152,6 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_loss_repeated_mean()
 
-    def set_communication_backend(self, communication_backend):
-        """
-        Set communication backend.
-
-        Args:
-            communication_backend (str): The communication backend.
-        """
-        self.check_context_handle()
-        self._context_handle.set_communication_backend(communication_backend)
-
-    def get_communication_backend(self):
-        """Get communication backend."""
-        self.check_context_handle()
-        return self._context_handle.get_communication_backend()
-
     def set_parallel_mode(self, parallel_mode):
         """
         Set parallel mode for auto parallel.
@@ -469,7 +454,7 @@ _set_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().set_device_num,
     "global_rank": auto_parallel_context().set_global_rank,
     "mirror_mean": auto_parallel_context().set_mirror_mean,
-    "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
+    "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
     "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
@@ -484,7 +469,7 @@ _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
     "global_rank": auto_parallel_context().get_global_rank,
     "mirror_mean": auto_parallel_context().get_mirror_mean,
-    "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
+    "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
     "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
@@ -495,7 +480,7 @@ _get_auto_parallel_context_func_map = {
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
 
 
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
+@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
@@ -512,8 +497,9 @@ def _set_auto_parallel_context(**kwargs):
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
         mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
         loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
-                        calculations. Default: True.
-        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
+                     calculations. Default: True.
+        gradient_fp32_sync (bool): Gradients are allreduced in fp32 even if they are fp16, when this flag is True.
+                     Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                      "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
@@ -577,7 +563,7 @@ def _reset_auto_parallel_context():
         - device_num: 1.
         - global_rank: 0.
         - mirror_mean: False.
-        - cast_before_mirror: True.
+        - gradient_fp32_sync: True.
         - parallel_mode: "stand_alone".
         - parameter_broadcast: False.
         - strategy_ckpt_load_file: ""
@@ -61,7 +61,7 @@ def get_rank_id(group=None):
 def get_rank_size(group=None):
     hccl = Hccl()
-    if group is None:
+    if group is None or "nccl_world_group" in group:
         return hccl.rank_size
     if isinstance(group, str):
         return int(group.split("-")[0])
@@ -830,7 +830,7 @@ def test_matmul_cast():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror():
+def test_gradient_fp32_sync():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -843,7 +843,7 @@ def test_cast_before_mirror():
             out = self.matmul(out, b)
             return out
 
-    context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True)
+    context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True)
     strategy1 = ((2, 2), (2, 2))
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@@ -854,7 +854,7 @@ def test_cast_before_mirror():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror1():
+def test_gradient_fp32_sync1():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -867,7 +867,7 @@ def test_cast_before_mirror1():
            out = self.matmul(out, b)
            return out
 
-    context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True)
+    context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True)
     strategy1 = ((2, 2), (2, 2))
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@@ -878,7 +878,7 @@ def test_cast_before_mirror1():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror2():
+def test_gradient_fp32_sync2():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -891,7 +891,7 @@ def test_cast_before_mirror2():
            out = self.matmul(out, b)
            return out
 
-    context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=False)
+    context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=False)
     strategy1 = ((2, 2), (2, 2))
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
@@ -902,7 +902,7 @@ def test_cast_before_mirror2():
     compile_net(net, x, y, b)
 
 
-def test_cast_before_mirror3():
+def test_gradient_fp32_sync3():
     class Net(nn.Cell):
         def __init__(self, strategy1):
             super().__init__()
@@ -20,25 +20,21 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 def test_set_auto_parallel_context():
-    context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, cast_before_mirror=False,
+    context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, gradient_fp32_sync=False,
                                       parallel_mode="auto_parallel", parameter_broadcast=False)
     device_num = context.get_auto_parallel_context("device_num")
     global_rank = context.get_auto_parallel_context("global_rank")
     mirror_mean = context.get_auto_parallel_context("mirror_mean")
-    cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror")
+    gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
     parallel_mode = context.get_auto_parallel_context("parallel_mode")
     parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
     assert device_num == 4
     assert global_rank == 3
     assert mirror_mean
-    assert not cast_before_mirror
+    assert not gradient_fp32_sync
     assert parallel_mode == "auto_parallel"
     assert not parameter_broadcast
 
-    auto_parallel_context().set_communication_backend("hccl")
-    backend = auto_parallel_context().get_communication_backend()
-    assert backend == "hccl"
-
     auto_parallel_context().set_device_num(4)
     device_num = auto_parallel_context().get_device_num()
     device_num_is_set = auto_parallel_context().get_device_num_is_set()
@@ -53,9 +49,9 @@ def test_set_auto_parallel_context():
     mirror_mean = auto_parallel_context().get_mirror_mean()
     assert mirror_mean
 
-    auto_parallel_context().set_cast_before_mirror(False)
-    cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
-    assert not cast_before_mirror
+    auto_parallel_context().set_gradient_fp32_sync(False)
+    gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync()
+    assert not gradient_fp32_sync
 
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     assert parameter_broadcast_is_set
@@ -91,7 +87,7 @@ def test_reset_auto_parallel_context():
     device_num = context.get_auto_parallel_context("device_num")
     global_rank = context.get_auto_parallel_context("global_rank")
     mirror_mean = context.get_auto_parallel_context("mirror_mean")
-    cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror")
+    gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
     parallel_mode = context.get_auto_parallel_context("parallel_mode")
     parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
     device_num_is_set = auto_parallel_context().get_device_num_is_set()
@@ -99,7 +95,7 @@ def test_reset_auto_parallel_context():
     assert device_num == 1
     assert global_rank == 0
     assert not mirror_mean
-    assert cast_before_mirror
+    assert gradient_fp32_sync
     assert parallel_mode == "stand_alone"
     assert not parameter_broadcast
     assert not device_num_is_set