Browse Source

add check for allreduce fusion

tags/v1.1.0
lichenever 5 years ago
parent
commit
cfffff2875
2 changed files with 12 additions and 2 deletions
  1. +5
    -1
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  2. +7
    -1
      mindspore/parallel/_auto_parallel_context.py

+ 5
- 1
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -1895,7 +1895,11 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(pre_node);

auto pre_cnode = pre_node->cast<CNodePtr>();
if (pre_cnode == nullptr) {
if (pre_cnode == nullptr || !IsValueNode<Primitive>(pre_cnode->input(0))) {
return loss_node_info;
}
if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
MS_LOG(DEBUG) << "pre_cnode:" << pre_cnode->ToString();
return loss_node_info;
}
auto prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));


+ 7
- 1
mindspore/parallel/_auto_parallel_context.py View File

@@ -294,6 +294,12 @@ class _AutoParallelContext:
else:
raise TypeError('indices must be a python list')

if len(set(indices)) != len(indices):
raise ValueError('indices has duplicate elements')

if sorted(indices) != indices:
raise ValueError('elements in indices must be sorted in ascending order')

if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
@@ -308,7 +314,7 @@ class _AutoParallelContext:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME

self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
if context.get_context("device_target") == "Ascend":
if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
_set_fusion_strategy_by_idx(indices)

def get_all_reduce_fusion_split_indices(self, group=""):


Loading…
Cancel
Save