Browse Source

Add CloneCNodeWithInfos for kernel graph

tags/v1.3.0
l00591931 4 years ago
parent
commit
c7bc867ae4
5 changed files with 17 additions and 2 deletions
  1. +2
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
  2. +2
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc
  3. +10
    -0
      mindspore/ccsrc/backend/session/kernel_graph.cc
  4. +1
    -0
      mindspore/ccsrc/backend/session/kernel_graph.h
  5. +2
    -2
      mindspore/ccsrc/backend/session/session_basic.cc

+ 2
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc View File

@@ -93,6 +93,8 @@ CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfN
ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope;
cp_node->set_scope(scope);
cp_node->set_kernel_info(cnode->kernel_info_ptr());
cp_node->set_primal_attrs(cnode->primal_attrs());
cp_node->set_primal_debug_infos(cnode->primal_debug_infos());
(*node_map)[orig_node] = cp_node;
return cp_node->cast<CNodePtr>();
}


+ 2
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc View File

@@ -47,6 +47,8 @@ AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
ScopePtr scope = (anf_node->scope() != kDefaultScope) ? anf_node->scope() : kDefaultScope;
node->set_scope(scope);
node->set_kernel_info(cnode->kernel_info_ptr());
node->set_primal_attrs(cnode->primal_attrs());
node->set_primal_debug_infos(cnode->primal_debug_infos());
return node;
}



+ 10
- 0
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -425,6 +425,16 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
return cnode;
}

CNodePtr KernelGraph::NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode) {
auto cnode = NewCNode(inputs);
if (ori_cnode != nullptr) {
cnode->set_attrs(ori_cnode->attrs());
cnode->set_primal_attrs(ori_cnode->primal_attrs());
cnode->set_primal_debug_infos(ori_cnode->primal_debug_infos());
}
return cnode;
}

void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
MS_EXCEPTION_IF_NULL(func_graph);


+ 1
- 0
mindspore/ccsrc/backend/session/kernel_graph.h View File

@@ -108,6 +108,7 @@ class KernelGraph : public FuncGraph {
void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter);
std::vector<AnfNodePtr> outputs() const;
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
CNodePtr NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode = nullptr);
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
CNodePtr NewCNode(const CNodePtr &cnode);
void ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const;


+ 2
- 2
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -687,7 +687,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
GetCNodeInfo(cnode, &cnode_inputs);
GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
return new_cnode;
}

@@ -997,7 +997,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
// handle inputs of cnode except primitive
CreateCNodeInputs(cnode, graph, &cnode_inputs);
TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
// if the cnode is call switch, remove call
if (new_cnode->inputs().size() > 1) {
auto first_input = new_cnode->input(kFirstDataInputIndex);


Loading…
Cancel
Save