|
|
|
@@ -618,19 +618,11 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node |
|
|
|
for (auto &replace_input : replace_graph->first) { |
|
|
|
auto pre_node = node->input(IntToSize(replace_input.second)); |
|
|
|
manager->SetEdge(replace_input.first, 1, pre_node); |
|
|
|
auto replace_input_cnode = replace_input.first->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(replace_input_cnode); |
|
|
|
(void)replace_input_cnode->set_operator_info(node->operator_info()); |
|
|
|
replace_input_cnode->set_in_forward_flag(true); // mark this new cnode is forward node |
|
|
|
} |
|
|
|
// "(void)manager->Replace(replace_graph->first, pre_node);" can not be called |
|
|
|
auto replace_output = replace_graph->second; |
|
|
|
MS_EXCEPTION_IF_NULL(replace_output); |
|
|
|
(void)manager->Replace(node, replace_output); |
|
|
|
CNodePtr replace_output_cnode = replace_graph->second->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(replace_output_cnode); |
|
|
|
(void)replace_output_cnode->set_operator_info(node->operator_info()); |
|
|
|
replace_output_cnode->set_in_forward_flag(true); // mark this new cnode is forward node |
|
|
|
} |
|
|
|
|
|
|
|
int32_t GetTupleGetItemIndex(const CNodePtr &cnode) { |
|
|
|
@@ -1994,14 +1986,27 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt |
|
|
|
BackwardCommunication(distribute_operator, cnode, sens_loss_pairs); |
|
|
|
} |
|
|
|
|
|
|
|
// StepReplace |
|
|
|
StepReplace(distribute_operator, cnode); |
|
|
|
|
|
|
|
HandleSpecialNode(distribute_operator, cnode); |
|
|
|
} else if (IsValueNode<Tensor>(node)) { |
|
|
|
StepSplitTensor(node, manager); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &node : all_nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); |
|
|
|
if (distribute_operator == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// StepReplace |
|
|
|
StepReplace(distribute_operator, cnode); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
|