Browse Source

Fix issue where the network's AllReduce (AR) and the optimizer could not run in parallel.

tags/v1.4.0
linqingke 4 years ago
parent
commit
7a463a885e
3 changed files with 19 additions and 8 deletions
  1. +6
    -1
      mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
  2. +3
    -3
      mindspore/ccsrc/backend/session/kernel_graph.cc
  3. +10
    -4
      mindspore/nn/wrap/grad_reducer.py

+ 6
- 1
mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc View File

@@ -151,9 +151,14 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
uint32_t last_index = 0;
for (size_t i = 0; i < split_indices.size(); ++i) {
uint32_t index = split_indices[i];
if ((index <= last_index && i != 0) || index >= communication_op_node_size) {
if (index <= last_index && i != 0) {
MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index;
}
if (index >= communication_op_node_size) {
MS_LOG(WARNING) << op_name_ << "'s split index " << index << " is large than total gradient's number "
<< communication_op_node_size;
continue;
}
segment_index->push_back(index);
last_index = index;
segments++;


+ 3
- 3
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -212,13 +212,13 @@ void KernelGraph::SetExecOrderByDefault() {
execution_order_.clear();
std::unordered_set<AnfNodePtr> visited_nodes;
std::queue<AnfNodePtr> zero_input_nodes;
std::stack<AnfNodePtr> delay_comm_stack;
std::queue<AnfNodePtr> delay_comm_stack;
std::queue<AnfNodePtr> communication_descendants;
std::string optimized_comm_group;
while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
// seed nodes first, then delay comm nodes
if (seed_nodes.empty()) {
EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
EnqueueActiveNodes(delay_comm_stack.front(), &communication_descendants, &visited_nodes, false);
delay_comm_stack.pop();
} else {
zero_input_nodes.push(seed_nodes.front());
@@ -253,7 +253,7 @@ void KernelGraph::SetExecOrderByDefault() {
}
if (optimize_comm) {
while (!delay_comm_stack.empty()) {
EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
EnqueueActiveNodes(delay_comm_stack.front(), &communication_descendants, &visited_nodes, false);
delay_comm_stack.pop();
}
delay_comm_stack.push(node);


+ 10
- 4
mindspore/nn/wrap/grad_reducer.py View File

@@ -46,12 +46,13 @@ def _init_allreduce_operators(length, split_indices):
return op_list


def _init_allreduce_operators_by_parameters(parameters):
def _init_allreduce_operators_by_parameters(parameters, split_indices):
""" initialize allreduce communication operators by parameters"""
op_list = ()
param_fusion = False
last_comm_fusion = None
first_parameter_flag = True
index = 1
for parameter in parameters:
comm_fusion = parameter.comm_fusion
if first_parameter_flag:
@@ -63,10 +64,15 @@ def _init_allreduce_operators_by_parameters(parameters):
last_comm_fusion = comm_fusion
op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
op.add_prim_attr('fusion', comm_fusion)
op.add_prim_attr('index', comm_fusion)
op.add_prim_attr('index', index)
index += 1
op_list = op_list + (op,)
if not param_fusion:
op_list = ()
if split_indices and split_indices[-1] == len(parameters) - 1:
op_list = _init_allreduce_operators(len(parameters), split_indices)
param_fusion = True
else:
op_list = ()
return op_list, param_fusion


@@ -385,7 +391,7 @@ class DistributedGradReducer(Cell):
self.op_list = _init_allreduce_operators(len(parameters), split_indices)
else:
self.split_fusion = True
self.op_list, param_fusion = _init_allreduce_operators_by_parameters(parameters)
self.op_list, param_fusion = _init_allreduce_operators_by_parameters(parameters, split_indices)
if not param_fusion:
self.split_fusion = False
self.allreduce = AllReduce().add_prim_attr('fusion', fusion_type)


Loading…
Cancel
Save