Browse Source

Modify the ControlDepend mount node

tags/v1.1.0
wilfChen 5 years ago
parent
commit
e877f72bcf
2 changed files with 13 additions and 9 deletions
  1. +12
    -9
      mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc
  2. +1
    -0
      mindspore/ccsrc/backend/session/gpu_session.cc

+ 12
- 9
mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc View File

@@ -75,20 +75,23 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<
group++;
}

void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes) {
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes,
const AnfNodePtr aggregate_node) {
std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
inplace_nodes[0].node, inplace_nodes[1].node};
auto control_depend_node = graph->NewCNode(inputs1);

auto return_node = graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
// mount the `depend` before make_tuple, otherwise the output of graph will be `(tensor, )` rather than `tensor`
auto return_input = return_node->input(kFirstDataInputIndex)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(return_input);
std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
return_input->input(kFirstDataInputIndex), control_depend_node};
aggregate_node, control_depend_node};
auto depend_node = graph->NewCNode(inputs2);
return_node->set_input(kFirstDataInputIndex, depend_node);

auto users = GetRealNodeUsedList(graph, aggregate_node);
if (users->size() == 0) {
MS_LOG(EXCEPTION) << "No users found: " << aggregate_node->DebugString();
}
auto mount_node = users->at(0).first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mount_node);
mount_node->set_input(kFirstDataInputIndex, depend_node);
}

bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
@@ -186,7 +189,7 @@ bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
// 2. Set Node attr
SetNodeAttr(aggregate_node, skip_node, &inplace_node);
// 3. Set dependence for inplace nodes
InsertControlDependToGraph(graph, inplace_node);
InsertControlDependToGraph(graph, inplace_node, aggregate_node.node);
}

return true;


+ 1
- 0
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -108,6 +108,7 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
pm->AddPass(std::make_shared<opt::GetitemTuple>());
pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));


Loading…
Cancel
Save