diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 84cb5705ce..d475bf02fd 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1895,7 +1895,7 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(pre_node);
 
   auto pre_cnode = pre_node->cast<CNodePtr>();
-  if (pre_cnode == nullptr) {
+  if (pre_cnode == nullptr || !IsValueNode<Primitive>(pre_cnode->input(0))) {
     return loss_node_info;
   }
   auto prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index 67074eb576..2185cf50a2 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -294,6 +294,12 @@ class _AutoParallelContext:
         else:
             raise TypeError('indices must be a python list')
 
+        if len(set(indices)) != len(indices):
+            raise ValueError('indices has duplicate elements')
+
+        if sorted(indices) != indices:
+            raise ValueError('elements in indices must be sorted in ascending order')
+
         if isinstance(group, (str)):
             group_len = len(group)
             if group_len > _MAX_GROUP_NAME_LEN:
@@ -308,6 +314,6 @@ class _AutoParallelContext:
             group = _DEFAULT_NCCL_FUSION_GROUP_NAME
         self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
-        if context.get_context("device_target") == "Ascend":
+        if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
             _set_fusion_strategy_by_idx(indices)
 
     def get_all_reduce_fusion_split_indices(self, group=""):