Browse Source

adapte to remove inline

merge me commit for remove inline

deal witch multiple cases of switch in ConstructKernelGraph

deal with switch and call cases in ConstructKernelGraph

fix bug and rebase master

ConstructKernelGraph adapte to remove inline

fix InsertMultipleAssignToGraph bug

add graph input to new graph which is created for switch input

replace CreateNewParameterFromCNode to NewParameter in order to set new
parameter's abstract and kernel_info

avoids create a new switch repeatedly when the cnode is a call switch without real input

null pointer check

update frontend code

Revert "update frontend code"

This reverts commit ce1f600d1e.

update frontend code PR_2948

fix bug of CheckLabalIndex

handle switch_layer in ConstructKernelGraph

add attr for assign node to avoid erasing by cse pass

cherry-pick ms commit[59b35f690d]:temporary avoid list getitem problem

rebase master

Revert "cherry-pick ms commit[59b35f690d]:temporary avoid list getitem problem"

This reverts commit 74c258f942.

Revert "handle switch_layer in ConstructKernelGraph"

This reverts commit cb5367f02d.

Revert "update frontend code PR_2948"

This reverts commit 234ac58340.

Revert "merge me commit for remove inline"

This reverts commit 55c0ebd42b.

fix diff after rebase master

doing remove inline in me

overwrite FindNodePrimitive

Revert "doing remove inline in me"

This reverts commit b42e893125.
tags/v0.7.0-beta
wenchunjiang 5 years ago
parent
commit
b24943d496
10 changed files with 162 additions and 141 deletions
  1. +13
    -14
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  2. +1
    -1
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
  3. +44
    -66
      mindspore/ccsrc/backend/session/ascend_control_parser.cc
  4. +3
    -3
      mindspore/ccsrc/backend/session/ascend_control_parser.h
  5. +2
    -2
      mindspore/ccsrc/backend/session/ascend_session.cc
  6. +16
    -3
      mindspore/ccsrc/backend/session/kernel_graph.cc
  7. +1
    -0
      mindspore/ccsrc/backend/session/kernel_graph.h
  8. +73
    -46
      mindspore/ccsrc/backend/session/session_basic.cc
  9. +5
    -5
      mindspore/ccsrc/backend/session/session_basic.h
  10. +4
    -1
      mindspore/ccsrc/utils/utils.h

+ 13
- 14
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -1031,31 +1031,29 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node)
return func_graph; return func_graph;
} }


std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
MS_EXCEPTION_IF_NULL(call_node);
if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << "is not a call node.";
std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch))) {
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch node.";
} }
auto input1 = call_node->input(1);
MS_EXCEPTION_IF_NULL(input1);
if (input1->isa<ValueNode>()) {
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
auto input1 = cnode->input(kCallKernelGraphIndex);
MS_EXCEPTION_IF_NULL(input1);
auto value_node = input1->cast<ValueNodePtr>(); auto value_node = input1->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
auto kernel_graph = value_node->value(); auto kernel_graph = value_node->value();
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
return {kernel_graph->cast<KernelGraphPtr>()}; return {kernel_graph->cast<KernelGraphPtr>()};
} else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
auto switch_node = input1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_node);
auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr {
auto partial = switch_node->input(input_index);
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr {
auto partial = cnode->input(input_index);
MS_EXCEPTION_IF_NULL(partial); MS_EXCEPTION_IF_NULL(partial);
if (IsValueNode<KernelGraph>(partial)) { if (IsValueNode<KernelGraph>(partial)) {
return GetValueNode<KernelGraphPtr>(partial); return GetValueNode<KernelGraphPtr>(partial);
} }
auto partial_cnode = partial->cast<CNodePtr>(); auto partial_cnode = partial->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode); MS_EXCEPTION_IF_NULL(partial_cnode);
auto graph_node = partial_cnode->input(1);
auto graph_node = partial_cnode->input(kCallKernelGraphIndex);
MS_EXCEPTION_IF_NULL(graph_node); MS_EXCEPTION_IF_NULL(graph_node);
auto graph_value_node = graph_node->cast<ValueNodePtr>(); auto graph_value_node = graph_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(graph_value_node); MS_EXCEPTION_IF_NULL(graph_value_node);
@@ -1064,7 +1062,8 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
auto child_graph = graph_value->cast<KernelGraphPtr>(); auto child_graph = graph_value->cast<KernelGraphPtr>();
return child_graph; return child_graph;
}; };
return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex),
get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)};
} }
return {}; return {};
} }


+ 1
- 1
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h View File

@@ -201,7 +201,7 @@ class AnfRuntimeAlgorithm {
static bool IsCommunicationOp(const AnfNodePtr &node); static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsGetNext(const NotNull<AnfNodePtr> &node); static bool IsGetNext(const NotNull<AnfNodePtr> &node);
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode);
static bool IsSwitchCall(const CNodePtr &call_node); static bool IsSwitchCall(const CNodePtr &call_node);
static bool IsScalarInput(const CNodePtr &cnode, size_t index); static bool IsScalarInput(const CNodePtr &cnode, size_t index);
static bool IsScalarOutput(const CNodePtr &cnode, size_t index); static bool IsScalarOutput(const CNodePtr &cnode, size_t index);


+ 44
- 66
mindspore/ccsrc/backend/session/ascend_control_parser.cc View File

@@ -361,27 +361,22 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
} }
} }


std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallNode(
NotNull<CNodePtr> call_node) {
std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallSwitchNode(
NotNull<CNodePtr> cnode) {
std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ret; std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ret;
if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) {
MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node.";
}
if (call_node->size() <= kCNodeCallArg) {
MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size();
}
const std::vector<AnfNodePtr> &call_node_inputs = call_node->inputs();
auto call_arg = call_node_inputs[kCNodeCallArg];
MS_EXCEPTION_IF_NULL(call_arg);
if (IsValueNode<KernelGraph>(call_arg)) {

if (IsPrimitiveCNode(cnode.get(), prim::kPrimCall)) {
if (cnode->size() <= kCNodeCallArg) {
MS_LOG(EXCEPTION) << "Call node " << cnode->DebugString() << " has invalid inputs size " << cnode->size();
}
auto call_arg = cnode->input(kCNodeCallArg);
MS_EXCEPTION_IF_NULL(call_arg);
ret.emplace_back(GetValueNode<KernelGraphPtr>(call_arg), ret.emplace_back(GetValueNode<KernelGraphPtr>(call_arg),
std::vector<AnfNodePtr>(call_node_inputs.begin() + kCNodeCallArg + 1, call_node_inputs.end()));
} else if (IsPrimitiveCNode(call_arg, prim::kPrimSwitch)) {
auto switch_cnode = call_arg->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
const std::vector<AnfNodePtr> &switch_inputs = switch_cnode->inputs();
if (switch_inputs.size() <= kCNodeSwitchCond) {
MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size "
std::vector<AnfNodePtr>(cnode->inputs().begin() + kCNodeCallArg + 1, cnode->inputs().end()));
} else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitch)) {
const std::vector<AnfNodePtr> &switch_inputs = cnode->inputs();
if (switch_inputs.size() < kCNodeSwitchLength) {
MS_LOG(EXCEPTION) << "Switch node " << cnode->DebugString() << " has invalid inputs size "
<< switch_inputs.size(); << switch_inputs.size();
} }
for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) { for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) {
@@ -389,7 +384,7 @@ std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlPar
ret.emplace_back(target_graph, args); ret.emplace_back(target_graph, args);
} }
} else { } else {
MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5);
MS_LOG(EXCEPTION) << "Unsupport call node: " << cnode->DebugString(5);
} }
return ret; return ret;
} }
@@ -406,11 +401,11 @@ void AscendControlParser::ChildGraphDataAssign(
const std::vector<CNodePtr> &nodes = kg->execution_order(); const std::vector<CNodePtr> &nodes = kg->execution_order();


for (auto &node : nodes) { for (auto &node : nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimCall)) {
if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch))) {
continue; continue;
} }


auto child_graph_list = ParseCallNode(NOT_NULL(node));
auto child_graph_list = ParseCallSwitchNode(NOT_NULL(node));
for (auto &[child_graph, args] : child_graph_list) { for (auto &[child_graph, args] : child_graph_list) {
MS_EXCEPTION_IF_NULL(child_graph); MS_EXCEPTION_IF_NULL(child_graph);
const std::vector<AnfNodePtr> &params = child_graph->inputs(); const std::vector<AnfNodePtr> &params = child_graph->inputs();
@@ -425,7 +420,6 @@ void AscendControlParser::ChildGraphDataAssign(
link_list->emplace_back(args[i], params[i]); link_list->emplace_back(args[i], params[i]);
continue; continue;
} }

InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i])); InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i]));
} }
} }
@@ -475,30 +469,20 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
for (size_t i = 0; i < nodes.size(); ++i) { for (size_t i = 0; i < nodes.size(); ++i) {
auto &cnode = nodes[i]; auto &cnode = nodes[i];
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() < kCNodePrim + 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex);
if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
MS_LOG(DEBUG) << "Continue node " << cnode->DebugString();
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) {
continue; continue;
} }
AnfNodePtr arg = cnode->input(kFirstDataInputIndex);
MS_EXCEPTION_IF_NULL(arg);
if (IsValueNode<KernelGraph>(arg)) {

if (IsPrimitiveCNode(cnode, prim::kPrimCall)) {
RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
} else if (!arg->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString();
} else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) {
auto arg_cnode = arg->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(arg_cnode);
cnode->set_inputs(arg_cnode->inputs());
} else if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
} else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) {
auto arg_cnode = arg->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(arg_cnode);
cnode->set_inputs(arg_cnode->inputs());
} else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
} else {
MS_LOG(EXCEPTION) << "Unexpected node: " << cnode->DebugString();
} }
} }
kg->SetExecOrderByDefault(); kg->SetExecOrderByDefault();
@@ -699,31 +683,22 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr
continue; continue;
} }
const auto &from_graph_exe_order = from_graph->execution_order(); const auto &from_graph_exe_order = from_graph->execution_order();
std::vector<CNodePtr> real_exe_order(from_graph_exe_order.size());
size_t real_exe_order_size = 0;
std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(),
[&real_exe_order_size](const CNodePtr &node) -> bool {
return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial))
? false
: (++real_exe_order_size, true);
});
real_exe_order.resize(real_exe_order_size);
if (jump_node == nullptr) { if (jump_node == nullptr) {
if (!real_exe_order.empty()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node));
if (!from_graph_exe_order.empty()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(from_graph_exe_order.rbegin())), NOT_NULL(assign_node));
} else { } else {
InsertDependToGraph(from_graph, NOT_NULL(assign_node)); InsertDependToGraph(from_graph, NOT_NULL(assign_node));
} }
continue; continue;
} }


auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node);
if (jump_node_iter == real_exe_order.end()) {
auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node);
if (jump_node_iter == from_graph_exe_order.end()) {
MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph " MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph "
<< from_graph->ToString(); << from_graph->ToString();
} }
// insert assign between jump_node -1 and jump_node // insert assign between jump_node -1 and jump_node
if (jump_node_iter != real_exe_order.begin()) {
if (jump_node_iter != from_graph_exe_order.begin()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
} }
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
@@ -772,6 +747,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
std::vector<CNodePtr> execution_order; std::vector<CNodePtr> execution_order;
uint32_t child_order_index = 0; uint32_t child_order_index = 0;
for (auto &node : cnodes) { for (auto &node : cnodes) {
uint32_t child_graph_index = 0;
execution_order.push_back(node); execution_order.push_back(node);
if (node == graph->get_end_goto()) { if (node == graph->get_end_goto()) {
continue; continue;
@@ -779,7 +755,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList); std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) {
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
if (!CheckLabelIndex(child_graph_index++, *iter, node)) {
MS_LOG(EXCEPTION) << "Check label index fail"; MS_LOG(EXCEPTION) << "Check label index fail";
} }
if (child_order_index >= graph->child_graph_order().size()) { if (child_order_index >= graph->child_graph_order().size()) {
@@ -791,9 +767,12 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
} }
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
if (!CheckLabelIndex(child_order_index, label_index, node, graph)) {
if (!CheckLabelIndex(child_graph_index, label_index, node)) {
MS_LOG(EXCEPTION) << "Check label index fail"; MS_LOG(EXCEPTION) << "Check label index fail";
} }
if (child_order_index >= graph->child_graph_order().size()) {
MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size();
}
auto child_graph = graph->child_graph_order()[child_order_index++]; auto child_graph = graph->child_graph_order()[child_order_index++];
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
@@ -804,15 +783,14 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
return execution_order; return execution_order;
} }


bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
NotNull<KernelGraphPtr> graph) {
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label) {
auto child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cur_label, kAttrChildGraph);
// check index and child order size // check index and child order size
if (child_graph_order.size() <= IntToSize(order_index)) {
MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size "
<< child_graph_order.size() << " goto index " << order_index;
if (child_graphs.size() <= IntToSize(index)) {
MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size "
<< child_graphs.size() << " goto index " << index;
} }
auto child_graph = child_graph_order[order_index];
auto child_graph = child_graphs[index];
MS_EXCEPTION_IF_NULL(child_graph); MS_EXCEPTION_IF_NULL(child_graph);


// get start_label_set_index of child graph // get start_label_set_index of child graph
@@ -822,7 +800,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
MS_EXCEPTION_IF_NULL(cur_label); MS_EXCEPTION_IF_NULL(cur_label);
MS_EXCEPTION_IF_NULL(start_label_set); MS_EXCEPTION_IF_NULL(start_label_set);
MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString()
<< " index " << start_label_set_index << " current child graph order : " << order_index;
<< " index " << start_label_set_index;
return false; return false;
} else { } else {
return true; return true;


+ 3
- 3
mindspore/ccsrc/backend/session/ascend_control_parser.h View File

@@ -64,13 +64,13 @@ class AscendControlParser {
const CNodePtr &last_label); const CNodePtr &last_label);


static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallNode(NotNull<CNodePtr> call_node);
static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallSwitchNode(
NotNull<CNodePtr> call_node);
static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node); static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node);
static void AttachChildGraphToReturnNode(NotNull<KernelGraphPtr> graph, static void AttachChildGraphToReturnNode(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo); const NotNull<std::set<KernelGraphPtr> *> memo);
// root graph order // root graph order
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
NotNull<KernelGraphPtr> graph);
static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode);
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo); const NotNull<std::set<KernelGraphPtr> *> memo);
}; };


+ 2
- 2
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -885,7 +885,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
std::map<AnfNodePtr, AnfNodePtr> need_replace_list; std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
auto node_list = GetCNodes(TopoSort(graph->get_return())); auto node_list = GetCNodes(TopoSort(graph->get_return()));
for (auto &node : node_list) { for (auto &node : node_list) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
// create a parameter to store the output of multiple branch and set the parameter as the condition graph's output // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output
auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
MS_EXCEPTION_IF_NULL(graph->MutableInputs()); MS_EXCEPTION_IF_NULL(graph->MutableInputs());
@@ -898,7 +898,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString() MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString()
<< ", depend node is " << depend->DebugString(); << ", depend node is " << depend->DebugString();
// insert assign in order to transfer child graph output to parameter // insert assign in order to transfer child graph output to parameter
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node);
auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node);
for (auto &child_graph : child_graphs) { for (auto &child_graph : child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph); MS_EXCEPTION_IF_NULL(child_graph);
// If graph has no output, the graph is the true graph of while and will call condition graph, no need insert // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert


+ 16
- 3
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -67,7 +67,7 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
return {node}; return {node};
} }
std::vector<AnfNodePtr> real_inputs; std::vector<AnfNodePtr> real_inputs;
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast<CNodePtr>());
auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node->cast<CNodePtr>());
for (const auto &child_graph : child_graphs) { for (const auto &child_graph : child_graphs) {
if (child_graph->get_output_null()) { if (child_graph->get_output_null()) {
continue; continue;
@@ -931,6 +931,18 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
return result; return result;
} }


std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const {
std::vector<CNodePtr> result;
for (const auto &anf : execution_order_) {
for (const auto &primitive : primitive_list) {
if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
result.push_back(anf->cast<CNodePtr>());
}
}
}
return result;
}

void KernelGraph::PrintGraphExecuteOrder() const { void KernelGraph::PrintGraphExecuteOrder() const {
MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order";
for (size_t i = 0; i < execution_order_.size(); i++) { for (size_t i = 0; i < execution_order_.size(); i++) {
@@ -1078,11 +1090,12 @@ bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int outpu
void KernelGraph::UpdateChildGraphOrder() { void KernelGraph::UpdateChildGraphOrder() {
MS_LOG(INFO) << "Update " << ToString() << " child graph order."; MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
SetExecOrderByDefault(); SetExecOrderByDefault();
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
auto call_nodes = FindNodeByPrimitive(
{std::make_shared<Primitive>(prim::kPrimCall->name()), std::make_shared<Primitive>(prim::kPrimSwitch->name())});
std::vector<KernelGraphPtr> child_graph_order; std::vector<KernelGraphPtr> child_graph_order;
for (auto &call_node : call_nodes) { for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node); MS_EXCEPTION_IF_NULL(call_node);
auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast<CNodePtr>());
for (const auto &child_graph : call_child_graphs) { for (const auto &child_graph : call_child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph); MS_EXCEPTION_IF_NULL(child_graph);
if (child_graph != parent_graph_) { if (child_graph != parent_graph_) {


+ 1
- 0
mindspore/ccsrc/backend/session/kernel_graph.h View File

@@ -131,6 +131,7 @@ class KernelGraph : public FuncGraph {
void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; } void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
// find anf node in graph // find anf node in graph
std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const; std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
std::vector<CNodePtr> FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const;
// used to dump ir // used to dump ir
std::string ToString() const override; std::string ToString() const override;




+ 73
- 46
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -547,45 +547,26 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra
MS_EXCEPTION_IF_NULL(node_input); MS_EXCEPTION_IF_NULL(node_input);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
// switch input generalizes partial // switch input generalizes partial
if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) ||
AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) {
return node_input->cast<CNodePtr>();
}
if (node_input->isa<CNode>()) {
MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call.";
}
std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))}; std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
partial_inputs.emplace_back(node_input);
auto partial_node = graph->NewCNode(partial_inputs);
return partial_node;
if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
auto partial_node = graph->GetBackendAnfByFrontAnf(node_input);
return partial_node->cast<CNodePtr>();
} else if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
} else {
KernelGraphPtr kernel_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), true, kernel_graph.get());
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
auto return_node = kernel_graph->NewCNode({primitive, parameter});
kernel_graph->set_return(return_node);
partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
} }
KernelGraphPtr kernel_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input));
partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
auto partial_node = graph->NewCNode(partial_inputs); auto partial_node = graph->NewCNode(partial_inputs);
return partial_node; return partial_node;
} }


CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(graph);
auto node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
if (node->inputs().size() < kSwitchInputSize) {
MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize;
}
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimSwitch->name()));
std::vector<AnfNodePtr> switch_inputs = {primitive, node->input(1)};
for (size_t index = 2; index < node->inputs().size(); index++) {
auto input = CreateSwitchInput(node->input(index), graph);
switch_inputs.emplace_back(input);
}
auto switch_node = graph->NewCNode(switch_inputs);
return switch_node;
}

std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
@@ -611,14 +592,33 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
}); });
return cnode_inputs; return cnode_inputs;
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
auto switch_node = HandleSwitchInputs(cnode_input, graph);
auto switch_cnode = cnode_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
switch_cnode->input(kFirstDataInputIndex)};
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
auto node = switch_cnode->input(index);
// there is real input in call, should put it to true and false branch in switch
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
auto partial_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_node);
std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
auto new_partial = graph->NewCNode(partial_inputs);
switch_inputs.emplace_back(new_partial);
}
}
if (switch_inputs.size() < kSwitchInputSize) {
MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
}
auto switch_node = graph->NewCNode(switch_inputs);
cnode_inputs.emplace_back(switch_node); cnode_inputs.emplace_back(switch_node);
return cnode_inputs; return cnode_inputs;
} }
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
} }


CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> cnode_inputs; std::vector<AnfNodePtr> cnode_inputs;
@@ -642,7 +642,22 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
} }
} }
} else if (attr_input->isa<CNode>()) { } else if (attr_input->isa<CNode>()) {
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
if (cnode->inputs().size() < 2 && AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
auto switch_cnode = cnode_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
cnode_inputs = switch_cnode->inputs();
} else {
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
}
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
cnode_inputs = {graph->GetBackendAnfByFrontAnf(cnode->input(kAnfPrimitiveIndex)),
graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))};
for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
auto node_input = cnode->input(index);
auto switch_input = CreateSwitchInput(node_input, graph);
cnode_inputs.emplace_back(switch_input);
}
} else { } else {
// get primitive of old node // get primitive of old node
auto prim = AnfAlgo::GetCNodePrimitive(cnode); auto prim = AnfAlgo::GetCNodePrimitive(cnode);
@@ -651,21 +666,33 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))}; cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
} }


for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->input(input_idx);
MS_EXCEPTION_IF_NULL(anf);
// anf has been created before
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (IsValueNode<None>(anf)) {
continue;
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->input(input_idx);
MS_EXCEPTION_IF_NULL(anf);
// anf has been created before
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (IsValueNode<None>(anf)) {
continue;
}
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
} }
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
} }
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs); auto new_cnode = graph->NewCNode(cnode_inputs);
TraceManager::EndTrace(); TraceManager::EndTrace();

// if the cnode is call switch, remove call
if (new_cnode->inputs().size() > 1) {
auto first_input = new_cnode->input(kFirstDataInputIndex);
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
new_cnode = first_input->cast<CNodePtr>();
}
}

return new_cnode; return new_cnode;
} }




+ 5
- 5
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -86,11 +86,7 @@ class SessionBasic {


CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);

CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph);
CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph);
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph);


// get graph id in child graphs by ME front anf node pointer // get graph id in child graphs by ME front anf node pointer
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
@@ -112,6 +108,10 @@ class SessionBasic {
} }
#endif #endif


private:
CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph);
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);

protected: protected:
virtual void SetSummaryNodes(KernelGraph *graph); virtual void SetSummaryNodes(KernelGraph *graph);
// Get graph by graph id ,if not exist return null ptr // Get graph by graph id ,if not exist return null ptr


+ 4
- 1
mindspore/ccsrc/utils/utils.h View File

@@ -277,11 +277,14 @@ const int kValueNodeTensorMask = 2;
// define special index in special node // define special index in special node
constexpr auto kAnfPrimitiveIndex = 0; constexpr auto kAnfPrimitiveIndex = 0;
constexpr auto kFirstDataInputIndex = 1; constexpr auto kFirstDataInputIndex = 1;
constexpr auto kAnfPartialFuncGraphIndex = 1;
constexpr auto kRealInputNodeIndexInTupleGetItem = 1; constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
constexpr auto kTupleGetItemInputSize = 3; constexpr auto kTupleGetItemInputSize = 3;
constexpr auto kSwitchInputSize = 4; constexpr auto kSwitchInputSize = 4;
constexpr auto kFirstBranchInSwitch = 2;
constexpr auto kCallKernelGraphIndex = 1;
constexpr auto kSwitchTrueKernelGraphIndex = 2;
constexpr auto kSwitchFalseKernelGraphIndex = 3;
// index define of control depend // index define of control depend
constexpr auto kControlDependPriorIndex = 1; constexpr auto kControlDependPriorIndex = 1;
constexpr auto kControlDependBehindIndex = 2; constexpr auto kControlDependBehindIndex = 2;


Loading…
Cancel
Save