|
|
|
@@ -75,20 +75,23 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector< |
|
|
|
group++; |
|
|
|
} |
|
|
|
|
|
|
|
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes) { |
|
|
|
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes, |
|
|
|
const AnfNodePtr aggregate_node) { |
|
|
|
std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())), |
|
|
|
inplace_nodes[0].node, inplace_nodes[1].node}; |
|
|
|
auto control_depend_node = graph->NewCNode(inputs1); |
|
|
|
|
|
|
|
auto return_node = graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(return_node); |
|
|
|
// mount the `depend` before make_tuple, otherwise the output of graph will be `(tensor, )` rather than `tensor` |
|
|
|
auto return_input = return_node->input(kFirstDataInputIndex)->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(return_input); |
|
|
|
std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), |
|
|
|
return_input->input(kFirstDataInputIndex), control_depend_node}; |
|
|
|
aggregate_node, control_depend_node}; |
|
|
|
auto depend_node = graph->NewCNode(inputs2); |
|
|
|
return_node->set_input(kFirstDataInputIndex, depend_node); |
|
|
|
|
|
|
|
auto users = GetRealNodeUsedList(graph, aggregate_node); |
|
|
|
if (users->size() == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "No users found: " << aggregate_node->DebugString(); |
|
|
|
} |
|
|
|
auto mount_node = users->at(0).first->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(mount_node); |
|
|
|
mount_node->set_input(kFirstDataInputIndex, depend_node); |
|
|
|
} |
|
|
|
|
|
|
|
bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node, |
|
|
|
@@ -186,7 +189,7 @@ bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) { |
|
|
|
// 2. Set Node attr |
|
|
|
SetNodeAttr(aggregate_node, skip_node, &inplace_node); |
|
|
|
// 3. Set dependence for inplace nodes |
|
|
|
InsertControlDependToGraph(graph, inplace_node); |
|
|
|
InsertControlDependToGraph(graph, inplace_node, aggregate_node.node); |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
|