Browse Source

refactor CreateNewCNode

tags/v0.7.0-beta
Margaret_wangrui 5 years ago
parent
commit
614258dc26
3 changed files with 51 additions and 25 deletions
  1. +6
    -1
      mindspore/ccsrc/backend/session/kernel_graph.cc
  2. +43
    -24
      mindspore/ccsrc/backend/session/session_basic.cc
  3. +2
    -0
      mindspore/ccsrc/backend/session/session_basic.h

+ 6
- 1
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -563,7 +563,12 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode
MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
}
if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
auto front_node = front_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(front_node);
auto attr_input = front_node->input(kAnfPrimitiveIndex);
if (!attr_input->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
}
}
front_backend_anf_map_[front_anf] = backend_anf;
backend_front_anf_map_[backend_anf] = front_anf;


+ 43
- 24
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -606,6 +606,10 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
auto switch_cnode = cnode_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
if (cnode->inputs().size() < 2) {
cnode_inputs = switch_cnode->inputs();
return cnode_inputs;
}
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
switch_cnode->input(kFirstDataInputIndex)};
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
@@ -630,7 +634,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
}

CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
std::vector<AnfNodePtr> SessionBasic::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> cnode_inputs;
@@ -641,7 +645,7 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(fg);
auto new_fg = BasicClone(fg);
cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
} else if (IsValueNode<FuncGraph>(attr_input)) {
} else {
// create primitive of cnode:call
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
// create a ValueNode<KernelGraph> as input of cnode:call
@@ -653,38 +657,27 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
cnode_inputs.emplace_back(new_value_node);
}
}
} else if (attr_input->isa<CNode>()) {
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
if (cnode->inputs().size() < 2 && AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
auto switch_cnode = cnode_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
cnode_inputs = switch_cnode->inputs();
} else {
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
}
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
cnode_inputs = {graph->GetBackendAnfByFrontAnf(cnode->input(kAnfPrimitiveIndex)),
graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))};
}
return cnode_inputs;
}

void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
auto node_input = cnode->input(index);
auto switch_input = CreateSwitchInput(node_input, graph);
cnode_inputs.emplace_back(switch_input);
cnode_inputs->emplace_back(switch_input);
}
} else {
// get primitive of old node
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
// push attr to inputs[0] of new cnode
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
}

if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->input(input_idx);
MS_EXCEPTION_IF_NULL(anf);
// anf has been created before
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (IsValueNode<None>(anf)) {
continue;
@@ -692,6 +685,32 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
}
}
}

CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> cnode_inputs;
auto attr_input = cnode->input(kAnfPrimitiveIndex);
MS_EXCEPTION_IF_NULL(attr_input);
if (IsValueNode<FuncGraph>(attr_input)) {
// cnode is a graph or a call
cnode_inputs = CreateValueNode(cnode, graph);
} else if (attr_input->isa<CNode>()) {
// cnode ia a call (partial/switch/switch_layer)
// 1. take the args of call to the partial node, as the real_args to call switch's or switch_layer's child graph
// 2. the call in frontend is map to the partial/switch/switch_layer in backend and haven't been created
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
} else {
// get primitive of old node
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
// push attr to inputs[0] of new cnode
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
}
// handle inputs of cnode except primitive
CreateCNodeInputs(cnode, graph, &cnode_inputs);

TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
TraceManager::EndTrace();


+ 2
- 0
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -111,6 +111,8 @@ class SessionBasic {
private:
CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph);
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph);
void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);

protected:
virtual void SetSummaryNodes(KernelGraph *graph);


Loading…
Cancel
Save