| @@ -26,13 +26,13 @@ namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { | |||
| const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0}; | |||
| auto op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| auto parallel_context_instance = parallel::ParallelContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(parallel_context_instance); | |||
| if (parallel_context_instance->enable_parallel_optimizer()) { | |||
| if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) { | |||
| return kOpFormat_DEFAULT; | |||
| } | |||
| const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0}; | |||
| auto op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); | |||
| if (op_name != kReduceScatter && op_name != kAllGatherOpName) { | |||
| return format; | |||
| @@ -65,6 +65,8 @@ void ParallelContext::Reset() { | |||
| strategy_ckpt_load_file_ = ""; | |||
| strategy_ckpt_save_file_ = ""; | |||
| enable_parallel_optimizer_ = false; | |||
| all_reduce_fusion_split_indices_.clear(); | |||
| all_reduce_fusion_split_sizes_.clear(); | |||
| } | |||
| void ParallelContext::set_device_num(int32_t device_num) { | |||
| @@ -371,5 +371,5 @@ class AdamWeightDecay(Optimizer): | |||
| self.parameters, self.moments1, self.moments2, | |||
| gradients, self.decay_flags, self.optim_filter) | |||
| if self.use_parallel: | |||
| optim_result = self.broadcast_params(optim_result) | |||
| self.broadcast_params(optim_result) | |||
| return optim_result | |||
| @@ -312,7 +312,7 @@ class Lamb(Optimizer): | |||
| self.decay_flags, self.optim_filter) | |||
| if self.use_parallel: | |||
| optim_result = self.broadcast_params(optim_result) | |||
| self.broadcast_params(optim_result) | |||
| if not self.dynamic_lr: | |||
| F.control_depend(lr, self.assignadd(self.global_step, 1)) | |||
| @@ -466,7 +466,7 @@ class Optimizer(Cell): | |||
| param_group.append(F.make_tuple()) | |||
| key_group.append(F.make_tuple()) | |||
| for i in range(self.param_length): | |||
| param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],) | |||
| param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],) | |||
| key = P.MakeRefKey(self.param_names[i])() | |||
| key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,) | |||
| new_param_group = [] | |||
| @@ -476,9 +476,9 @@ class Optimizer(Cell): | |||
| new_param_group.append(next_params) | |||
| for i in range(F.tuple_len(next_params)): | |||
| F.assign(key_group[root][i], next_params[i]) | |||
| status = True | |||
| status = F.control_depend(optim_result, new_param_group[0][0]) | |||
| for i in range(self.dev_num - 1): | |||
| status = F.control_depend(new_param_group[i][0], new_param_group[i+1]) | |||
| status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status) | |||
| return status | |||
| @@ -25,7 +25,7 @@ import mindspore.common.dtype as mstype | |||
| reduce_opt = C.MultitypeFuncGraph("reduce_opt") | |||
| def _init_allreduce_operators(length): | |||
| def _init_allreduce_operators(length, split_indices): | |||
| """ initialize allreduce communication operators""" | |||
| group = 1 | |||
| fusion = () | |||
| @@ -318,7 +318,7 @@ class DistributedGradReducer(Cell): | |||
| split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() | |||
| if is_parallel_optimizer and split_indices: | |||
| self.split_fusion = True | |||
| self.op_list = _init_allreduce_operators(len(parameters)) | |||
| self.op_list = _init_allreduce_operators(len(parameters), split_indices) | |||
| else: | |||
| self.split_fusion = False | |||
| self.allreduce = AllReduce().add_prim_attr('fusion', 1) | |||
| @@ -344,10 +344,10 @@ class DistributedGradReducer(Cell): | |||
| if self.split_fusion: | |||
| if self.enable_parameter_server: | |||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), | |||
| self.opt_list, self.allreduce_filter, grads, self.ps_parameters) | |||
| self.op_list, self.allreduce_filter, grads, self.ps_parameters) | |||
| else: | |||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), | |||
| self.opt_list, self.allreduce_filter, grads) | |||
| self.op_list, self.allreduce_filter, grads) | |||
| else: | |||
| if self.enable_parameter_server: | |||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, | |||
| @@ -16,8 +16,6 @@ | |||
| import ctypes | |||
| from mindspore import log as logger | |||
| _MAX_GROUP_NAME_LEN = 127 | |||
| _HCCL_LIB = 'libhccl.so' | |||
| @@ -25,8 +23,8 @@ _HCCL_LIB = 'libhccl.so' | |||
| def _load_lib(): | |||
| try: | |||
| hccl_lib = ctypes.CDLL(_HCCL_LIB) | |||
| except RuntimeError: | |||
| logger.error('Get hccl lib error') | |||
| except Exception: | |||
| raise RuntimeError('Get hccl lib error') | |||
| return hccl_lib | |||
| @@ -69,8 +67,9 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"): | |||
| try: | |||
| lib_ctype = _load_lib() | |||
| except RuntimeError: | |||
| logger.error('Load HCCL lib failed') | |||
| import hccl_test.manage.api as hccl | |||
| hccl.set_fusion_strategy_by_idx() | |||
| return | |||
| if isinstance(group, (str)): | |||
| group_len = len(group) | |||
| if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0): | |||
| @@ -126,7 +125,9 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"): | |||
| try: | |||
| lib_ctype = _load_lib() | |||
| except RuntimeError: | |||
| logger.error('Load HCCL lib failed') | |||
| import hccl_test.manage.api as hccl | |||
| hccl.set_fusion_strategy_by_size() | |||
| return | |||
| if isinstance(group, (str)): | |||
| group_len = len(group) | |||
| if group_len > _MAX_GROUP_NAME_LEN or group_len == 0: | |||
| @@ -86,3 +86,13 @@ def create_group(group, rank_size, rank_ids): | |||
| # pylint: disable=unused-argument | |||
| def destroy_group(group): | |||
| pass | |||
| # pylint: disable=unused-argument | |||
| def set_fusion_strategy_by_idx(): | |||
| pass | |||
| # pylint: disable=unused-argument | |||
| def set_fusion_strategy_by_size(): | |||
| pass | |||
| @@ -23,7 +23,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb | |||
| from mindspore.ops import operations as P | |||
| from mindspore import context | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| class Net(nn.Cell): | |||
| """Net definition""" | |||
| @@ -64,6 +64,7 @@ def test_AdamWeightDecay(): | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | |||
| _executor.compile(train_network, inputs, label) | |||
| context.reset_auto_parallel_context() | |||
| def test_lamb_compile(): | |||
| @@ -79,7 +80,24 @@ def test_lamb_compile(): | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | |||
| _executor.compile(train_network, inputs, label) | |||
| context.reset_auto_parallel_context() | |||
| def test_lamb_split_fusion(): | |||
| """ test_Lamb_split_fusion """ | |||
| context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) | |||
| auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8]) | |||
| inputs = Tensor(np.ones([32, 128]).astype(np.float32)) | |||
| label = Tensor(np.zeros([32, 768]).astype(np.float32)) | |||
| net = Net() | |||
| net.set_train() | |||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| optimizer = Lamb(net.trainable_params(), learning_rate=0.1) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | |||
| _executor.compile(train_network, inputs, label) | |||
| context.reset_auto_parallel_context() | |||
| def test_edge_case(): | |||
| """ test_edge_case """ | |||
| @@ -93,3 +111,4 @@ def test_edge_case(): | |||
| with pytest.raises(RuntimeError): | |||
| context.set_auto_parallel_context(device_num=16) | |||
| Lamb(net.trainable_params(), learning_rate=0.1) | |||
| context.reset_auto_parallel_context() | |||