Merge pull request !3405 from gziyan/fix_optimizer_parallel
@@ -26,13 +26,13 @@ namespace mindspore {
 namespace kernel {
 namespace {
 std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
+  const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
+  auto op_name = AnfAlgo::GetCNodeName(kernel_node);
   auto parallel_context_instance = parallel::ParallelContext::GetInstance();
   MS_EXCEPTION_IF_NULL(parallel_context_instance);
-  if (parallel_context_instance->enable_parallel_optimizer()) {
+  if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) {
     return kOpFormat_DEFAULT;
   }
-  const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
-  auto op_name = AnfAlgo::GetCNodeName(kernel_node);
   auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
   if (op_name != kReduceScatter && op_name != kAllGatherOpName) {
     return format;
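Note on the hunk above: GetKernelFormat previously returned kOpFormat_DEFAULT for every kernel whenever the parallel optimizer was enabled; the check is now scoped to Broadcast nodes (op_name == kBroadcast), which is why the op_name and kReduceNoSupportedSet declarations are hoisted above it.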
@@ -65,6 +65,8 @@ void ParallelContext::Reset() {
   strategy_ckpt_load_file_ = "";
   strategy_ckpt_save_file_ = "";
   enable_parallel_optimizer_ = false;
+  all_reduce_fusion_split_indices_.clear();
+  all_reduce_fusion_split_sizes_.clear();
 }

 void ParallelContext::set_device_num(int32_t device_num) {
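The two clear() calls matter for anyone configuring all-reduce fusion through the Python front end: before this patch, the split indices survived a context reset. A minimal sketch of the now-expected behaviour (the index values are illustrative, taken from the new test at the bottom of this PR):

# Minimal sketch: fusion split indices live on the global ParallelContext
# singleton, and Reset() now clears them as well.
from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context

auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8])
print(auto_parallel_context().get_all_reduce_fusion_split_indices())  # [2, 4, 6, 8]

context.reset_auto_parallel_context()
# With this patch applied, the indices are cleared by the reset instead of
# silently leaking into the next training job or test case.
print(auto_parallel_context().get_all_reduce_fusion_split_indices())  # expected: empty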
@@ -371,5 +371,5 @@ class AdamWeightDecay(Optimizer):
                                       self.parameters, self.moments1, self.moments2,
                                       gradients, self.decay_flags, self.optim_filter)
         if self.use_parallel:
-            optim_result = self.broadcast_params(optim_result)
+            self.broadcast_params(optim_result)
         return optim_result
@@ -312,7 +312,7 @@ class Lamb(Optimizer):
                                           self.decay_flags, self.optim_filter)
         if self.use_parallel:
-            optim_result = self.broadcast_params(optim_result)
+            self.broadcast_params(optim_result)
         if not self.dynamic_lr:
             F.control_depend(lr, self.assignadd(self.global_step, 1))
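In both AdamWeightDecay and Lamb the assignment is dropped because broadcast_params (next two hunks) now returns a control-dependency status rather than the updated parameters; the optimizers call it for its graph side effects and keep returning optim_result. A hedged usage sketch of the path these hunks fix, modelled on the new test below (nn.Dense stands in for the test file's Net; device_num=2 is illustrative):

import mindspore.nn as nn
from mindspore import context
from mindspore.nn.optim import Lamb

# Enable the parallel-optimizer path, which presumably makes Lamb.construct
# take the use_parallel branch and call self.broadcast_params(optim_result).
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2,
                                  enable_parallel_optimizer=True)
net = nn.Dense(128, 768)  # stand-in for the test file's Net
optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
context.reset_auto_parallel_context()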
@@ -466,7 +466,7 @@ class Optimizer(Cell):
             param_group.append(F.make_tuple())
             key_group.append(F.make_tuple())
         for i in range(self.param_length):
-            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
+            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
             key = P.MakeRefKey(self.param_names[i])()
             key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
         new_param_group = []
@@ -476,9 +476,9 @@ class Optimizer(Cell):
             new_param_group.append(next_params)
             for i in range(F.tuple_len(next_params)):
                 F.assign(key_group[root][i], next_params[i])
-        status = True
+        status = F.control_depend(optim_result, new_param_group[0][0])
         for i in range(self.dev_num - 1):
-            status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
+            status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status)
         return status
@@ -25,7 +25,7 @@ import mindspore.common.dtype as mstype
 reduce_opt = C.MultitypeFuncGraph("reduce_opt")

-def _init_allreduce_operators(length):
+def _init_allreduce_operators(length, split_indices):
     """ initialize allreduce communication operators"""
     group = 1
     fusion = ()
@@ -318,7 +318,7 @@ class DistributedGradReducer(Cell):
         split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
         if is_parallel_optimizer and split_indices:
             self.split_fusion = True
-            self.op_list = _init_allreduce_operators(len(parameters))
+            self.op_list = _init_allreduce_operators(len(parameters), split_indices)
         else:
             self.split_fusion = False
             self.allreduce = AllReduce().add_prim_attr('fusion', 1)
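For reference, split_indices partitions the flat parameter list into all-reduce fusion groups. A hedged plain-Python sketch of the grouping `_init_allreduce_operators(length, split_indices)` presumably derives, inferred from the `group`/`fusion` locals visible above; treat the exact boundary rule as an assumption:

# Assign a fusion id to each of `length` gradients: the id is bumped every
# time the parameter index crosses the next split boundary.
def fusion_ids(length, split_indices):
    ids, group = [], 1
    for i in range(length):
        if group - 1 < len(split_indices) and i >= split_indices[group - 1]:
            group += 1
        ids.append(group)
    return ids

print(fusion_ids(10, [2, 4, 6, 8]))  # [1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
# Each id would map to one AllReduce carrying a matching 'fusion' attribute.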
@@ -344,10 +344,10 @@ class DistributedGradReducer(Cell):
         if self.split_fusion:
             if self.enable_parameter_server:
                 new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
-                                     self.opt_list, self.allreduce_filter, grads, self.ps_parameters)
+                                     self.op_list, self.allreduce_filter, grads, self.ps_parameters)
             else:
                 new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
-                                     self.opt_list, self.allreduce_filter, grads)
+                                     self.op_list, self.allreduce_filter, grads)
         else:
             if self.enable_parameter_server:
                 new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,
@@ -16,8 +16,6 @@
 import ctypes

-from mindspore import log as logger
-
 _MAX_GROUP_NAME_LEN = 127
 _HCCL_LIB = 'libhccl.so'

@@ -25,8 +23,8 @@ _HCCL_LIB = 'libhccl.so'
 def _load_lib():
     try:
         hccl_lib = ctypes.CDLL(_HCCL_LIB)
-    except RuntimeError:
-        logger.error('Get hccl lib error')
+    except Exception:
+        raise RuntimeError('Get hccl lib error')
     return hccl_lib
@@ -69,8 +67,9 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
     try:
         lib_ctype = _load_lib()
     except RuntimeError:
-        logger.error('Load HCCL lib failed')
+        import hccl_test.manage.api as hccl
+        hccl.set_fusion_strategy_by_idx()
+        return
     if isinstance(group, (str)):
         group_len = len(group)
         if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
@@ -126,7 +125,9 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
     try:
         lib_ctype = _load_lib()
     except RuntimeError:
-        logger.error('Load HCCL lib failed')
+        import hccl_test.manage.api as hccl
+        hccl.set_fusion_strategy_by_size()
+        return
     if isinstance(group, (str)):
         group_len = len(group)
         if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
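Both setters now share the same failure path: if libhccl.so cannot be loaded, they fall back to the in-tree hccl_test stub and return early, instead of logging an error and then using an unbound lib_ctype. A minimal sketch of that shape (the helper names here are hypothetical):

import ctypes

def _load_lib(path='libhccl.so'):
    # Mirrors the _load_lib hunk above: any load failure becomes RuntimeError.
    try:
        return ctypes.CDLL(path)
    except Exception as err:
        raise RuntimeError('Get hccl lib error') from err

def _set_strategy(real_call, stub_call):
    try:
        lib = _load_lib()
    except RuntimeError:
        stub_call()   # test environment: delegate to the hccl_test stub
        return
    real_call(lib)    # production path: drive the real library via ctypes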
@@ -86,3 +86,13 @@ def create_group(group, rank_size, rank_ids):
 # pylint: disable=unused-argument
 def destroy_group(group):
     pass
+
+
+# pylint: disable=unused-argument
+def set_fusion_strategy_by_idx():
+    pass
+
+
+# pylint: disable=unused-argument
+def set_fusion_strategy_by_size():
+    pass
@@ -23,7 +23,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb
 from mindspore.ops import operations as P
 from mindspore import context
+from mindspore.parallel._auto_parallel_context import auto_parallel_context

 class Net(nn.Cell):
     """Net definition"""
@@ -64,6 +64,7 @@ def test_AdamWeightDecay():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
+    context.reset_auto_parallel_context()

 def test_lamb_compile():
@@ -79,7 +80,24 @@ def test_lamb_compile():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
+    context.reset_auto_parallel_context()
+
+
+def test_lamb_split_fusion():
+    """ test_Lamb_split_fusion """
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
+    auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8])
+    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+    label = Tensor(np.zeros([32, 768]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+    context.reset_auto_parallel_context()

 def test_edge_case():
     """ test_edge_case """
@@ -93,3 +111,4 @@ def test_edge_case():
     with pytest.raises(RuntimeError):
         context.set_auto_parallel_context(device_num=16)
         Lamb(net.trainable_params(), learning_rate=0.1)
+    context.reset_auto_parallel_context()
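Note: every test touched here now ends with context.reset_auto_parallel_context(). Since the fusion split indices live on the global ParallelContext singleton, and the Reset() change above makes the reset actually clear them, the explicit call keeps test_lamb_split_fusion's [2, 4, 6, 8] configuration from leaking into later tests.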