Browse Source

!4295 adapte to remove inline and generalization of ir

Merge pull request !4295 from wenchunjiang/remove_inline_1
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
d3dfeee195
10 changed files with 162 additions and 141 deletions
  1. +13
    -14
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  2. +1
    -1
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
  3. +44
    -66
      mindspore/ccsrc/backend/session/ascend_control_parser.cc
  4. +3
    -3
      mindspore/ccsrc/backend/session/ascend_control_parser.h
  5. +2
    -2
      mindspore/ccsrc/backend/session/ascend_session.cc
  6. +16
    -3
      mindspore/ccsrc/backend/session/kernel_graph.cc
  7. +1
    -0
      mindspore/ccsrc/backend/session/kernel_graph.h
  8. +73
    -46
      mindspore/ccsrc/backend/session/session_basic.cc
  9. +5
    -5
      mindspore/ccsrc/backend/session/session_basic.h
  10. +4
    -1
      mindspore/ccsrc/utils/utils.h

+ 13
- 14
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -1031,31 +1031,29 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node)
return func_graph;
}

std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
MS_EXCEPTION_IF_NULL(call_node);
if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << "is not a call node.";
std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch))) {
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch node.";
}
auto input1 = call_node->input(1);
MS_EXCEPTION_IF_NULL(input1);
if (input1->isa<ValueNode>()) {
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
auto input1 = cnode->input(kCallKernelGraphIndex);
MS_EXCEPTION_IF_NULL(input1);
auto value_node = input1->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto kernel_graph = value_node->value();
MS_EXCEPTION_IF_NULL(kernel_graph);
return {kernel_graph->cast<KernelGraphPtr>()};
} else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
auto switch_node = input1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_node);
auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr {
auto partial = switch_node->input(input_index);
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr {
auto partial = cnode->input(input_index);
MS_EXCEPTION_IF_NULL(partial);
if (IsValueNode<KernelGraph>(partial)) {
return GetValueNode<KernelGraphPtr>(partial);
}
auto partial_cnode = partial->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto graph_node = partial_cnode->input(1);
auto graph_node = partial_cnode->input(kCallKernelGraphIndex);
MS_EXCEPTION_IF_NULL(graph_node);
auto graph_value_node = graph_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(graph_value_node);
@@ -1064,7 +1062,8 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
auto child_graph = graph_value->cast<KernelGraphPtr>();
return child_graph;
};
return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex),
get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)};
}
return {};
}


+ 1
- 1
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h View File

@@ -201,7 +201,7 @@ class AnfRuntimeAlgorithm {
static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode);
static bool IsSwitchCall(const CNodePtr &call_node);
static bool IsScalarInput(const CNodePtr &cnode, size_t index);
static bool IsScalarOutput(const CNodePtr &cnode, size_t index);


+ 44
- 66
mindspore/ccsrc/backend/session/ascend_control_parser.cc View File

@@ -361,27 +361,22 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
}
}

std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallNode(
NotNull<CNodePtr> call_node) {
std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallSwitchNode(
NotNull<CNodePtr> cnode) {
std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ret;
if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) {
MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node.";
}
if (call_node->size() <= kCNodeCallArg) {
MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size();
}
const std::vector<AnfNodePtr> &call_node_inputs = call_node->inputs();
auto call_arg = call_node_inputs[kCNodeCallArg];
MS_EXCEPTION_IF_NULL(call_arg);
if (IsValueNode<KernelGraph>(call_arg)) {

if (IsPrimitiveCNode(cnode.get(), prim::kPrimCall)) {
if (cnode->size() <= kCNodeCallArg) {
MS_LOG(EXCEPTION) << "Call node " << cnode->DebugString() << " has invalid inputs size " << cnode->size();
}
auto call_arg = cnode->input(kCNodeCallArg);
MS_EXCEPTION_IF_NULL(call_arg);
ret.emplace_back(GetValueNode<KernelGraphPtr>(call_arg),
std::vector<AnfNodePtr>(call_node_inputs.begin() + kCNodeCallArg + 1, call_node_inputs.end()));
} else if (IsPrimitiveCNode(call_arg, prim::kPrimSwitch)) {
auto switch_cnode = call_arg->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
const std::vector<AnfNodePtr> &switch_inputs = switch_cnode->inputs();
if (switch_inputs.size() <= kCNodeSwitchCond) {
MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size "
std::vector<AnfNodePtr>(cnode->inputs().begin() + kCNodeCallArg + 1, cnode->inputs().end()));
} else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitch)) {
const std::vector<AnfNodePtr> &switch_inputs = cnode->inputs();
if (switch_inputs.size() < kCNodeSwitchLength) {
MS_LOG(EXCEPTION) << "Switch node " << cnode->DebugString() << " has invalid inputs size "
<< switch_inputs.size();
}
for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) {
@@ -389,7 +384,7 @@ std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlPar
ret.emplace_back(target_graph, args);
}
} else {
MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5);
MS_LOG(EXCEPTION) << "Unsupport call node: " << cnode->DebugString(5);
}
return ret;
}
@@ -406,11 +401,11 @@ void AscendControlParser::ChildGraphDataAssign(
const std::vector<CNodePtr> &nodes = kg->execution_order();

for (auto &node : nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimCall)) {
if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch))) {
continue;
}

auto child_graph_list = ParseCallNode(NOT_NULL(node));
auto child_graph_list = ParseCallSwitchNode(NOT_NULL(node));
for (auto &[child_graph, args] : child_graph_list) {
MS_EXCEPTION_IF_NULL(child_graph);
const std::vector<AnfNodePtr> &params = child_graph->inputs();
@@ -425,7 +420,6 @@ void AscendControlParser::ChildGraphDataAssign(
link_list->emplace_back(args[i], params[i]);
continue;
}

InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i]));
}
}
@@ -475,30 +469,20 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
for (size_t i = 0; i < nodes.size(); ++i) {
auto &cnode = nodes[i];
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() < kCNodePrim + 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex);
if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
MS_LOG(DEBUG) << "Continue node " << cnode->DebugString();
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) {
continue;
}
AnfNodePtr arg = cnode->input(kFirstDataInputIndex);
MS_EXCEPTION_IF_NULL(arg);
if (IsValueNode<KernelGraph>(arg)) {

if (IsPrimitiveCNode(cnode, prim::kPrimCall)) {
RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
} else if (!arg->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString();
} else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) {
auto arg_cnode = arg->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(arg_cnode);
cnode->set_inputs(arg_cnode->inputs());
} else if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
} else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) {
auto arg_cnode = arg->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(arg_cnode);
cnode->set_inputs(arg_cnode->inputs());
} else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
} else {
MS_LOG(EXCEPTION) << "Unexpected node: " << cnode->DebugString();
}
}
kg->SetExecOrderByDefault();
@@ -699,31 +683,22 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr
continue;
}
const auto &from_graph_exe_order = from_graph->execution_order();
std::vector<CNodePtr> real_exe_order(from_graph_exe_order.size());
size_t real_exe_order_size = 0;
std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(),
[&real_exe_order_size](const CNodePtr &node) -> bool {
return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial))
? false
: (++real_exe_order_size, true);
});
real_exe_order.resize(real_exe_order_size);
if (jump_node == nullptr) {
if (!real_exe_order.empty()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node));
if (!from_graph_exe_order.empty()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(from_graph_exe_order.rbegin())), NOT_NULL(assign_node));
} else {
InsertDependToGraph(from_graph, NOT_NULL(assign_node));
}
continue;
}

auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node);
if (jump_node_iter == real_exe_order.end()) {
auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node);
if (jump_node_iter == from_graph_exe_order.end()) {
MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph "
<< from_graph->ToString();
}
// insert assign between jump_node -1 and jump_node
if (jump_node_iter != real_exe_order.begin()) {
if (jump_node_iter != from_graph_exe_order.begin()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
}
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
@@ -772,6 +747,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
std::vector<CNodePtr> execution_order;
uint32_t child_order_index = 0;
for (auto &node : cnodes) {
uint32_t child_graph_index = 0;
execution_order.push_back(node);
if (node == graph->get_end_goto()) {
continue;
@@ -779,7 +755,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) {
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
if (!CheckLabelIndex(child_graph_index++, *iter, node)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
if (child_order_index >= graph->child_graph_order().size()) {
@@ -791,9 +767,12 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
}
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
if (!CheckLabelIndex(child_order_index, label_index, node, graph)) {
if (!CheckLabelIndex(child_graph_index, label_index, node)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
if (child_order_index >= graph->child_graph_order().size()) {
MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size();
}
auto child_graph = graph->child_graph_order()[child_order_index++];
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
@@ -804,15 +783,14 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
return execution_order;
}

bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
NotNull<KernelGraphPtr> graph) {
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label) {
auto child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cur_label, kAttrChildGraph);
// check index and child order size
if (child_graph_order.size() <= IntToSize(order_index)) {
MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size "
<< child_graph_order.size() << " goto index " << order_index;
if (child_graphs.size() <= IntToSize(index)) {
MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size "
<< child_graphs.size() << " goto index " << index;
}
auto child_graph = child_graph_order[order_index];
auto child_graph = child_graphs[index];
MS_EXCEPTION_IF_NULL(child_graph);

// get start_label_set_index of child graph
@@ -822,7 +800,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
MS_EXCEPTION_IF_NULL(cur_label);
MS_EXCEPTION_IF_NULL(start_label_set);
MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString()
<< " index " << start_label_set_index << " current child graph order : " << order_index;
<< " index " << start_label_set_index;
return false;
} else {
return true;


+ 3
- 3
mindspore/ccsrc/backend/session/ascend_control_parser.h View File

@@ -64,13 +64,13 @@ class AscendControlParser {
const CNodePtr &last_label);

static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallNode(NotNull<CNodePtr> call_node);
static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallSwitchNode(
NotNull<CNodePtr> call_node);
static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node);
static void AttachChildGraphToReturnNode(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo);
// root graph order
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
NotNull<KernelGraphPtr> graph);
static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode);
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo);
};


+ 2
- 2
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -885,7 +885,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
auto node_list = GetCNodes(TopoSort(graph->get_return()));
for (auto &node : node_list) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
// create a parameter to store the output of multiple branch and set the parameter as the condition graph's output
auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
MS_EXCEPTION_IF_NULL(graph->MutableInputs());
@@ -898,7 +898,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString()
<< ", depend node is " << depend->DebugString();
// insert assign in order to transfer child graph output to parameter
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node);
auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node);
for (auto &child_graph : child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
// If graph has no output, the graph is the true graph of while and will call condition graph, no need insert


+ 16
- 3
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -67,7 +67,7 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
return {node};
}
std::vector<AnfNodePtr> real_inputs;
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast<CNodePtr>());
auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node->cast<CNodePtr>());
for (const auto &child_graph : child_graphs) {
if (child_graph->get_output_null()) {
continue;
@@ -931,6 +931,18 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
return result;
}

std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const {
std::vector<CNodePtr> result;
for (const auto &anf : execution_order_) {
for (const auto &primitive : primitive_list) {
if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
result.push_back(anf->cast<CNodePtr>());
}
}
}
return result;
}

void KernelGraph::PrintGraphExecuteOrder() const {
MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order";
for (size_t i = 0; i < execution_order_.size(); i++) {
@@ -1078,11 +1090,12 @@ bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int outpu
void KernelGraph::UpdateChildGraphOrder() {
MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
SetExecOrderByDefault();
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
auto call_nodes = FindNodeByPrimitive(
{std::make_shared<Primitive>(prim::kPrimCall->name()), std::make_shared<Primitive>(prim::kPrimSwitch->name())});
std::vector<KernelGraphPtr> child_graph_order;
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast<CNodePtr>());
for (const auto &child_graph : call_child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
if (child_graph != parent_graph_) {


+ 1
- 0
mindspore/ccsrc/backend/session/kernel_graph.h View File

@@ -131,6 +131,7 @@ class KernelGraph : public FuncGraph {
void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
// find anf node in graph
std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
std::vector<CNodePtr> FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const;
// used to dump ir
std::string ToString() const override;



+ 73
- 46
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -547,45 +547,26 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra
MS_EXCEPTION_IF_NULL(node_input);
MS_EXCEPTION_IF_NULL(graph);
// switch input generalizes partial
if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) ||
AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) {
return node_input->cast<CNodePtr>();
}
if (node_input->isa<CNode>()) {
MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call.";
}
std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
partial_inputs.emplace_back(node_input);
auto partial_node = graph->NewCNode(partial_inputs);
return partial_node;
if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
auto partial_node = graph->GetBackendAnfByFrontAnf(node_input);
return partial_node->cast<CNodePtr>();
} else if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
} else {
KernelGraphPtr kernel_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), true, kernel_graph.get());
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
auto return_node = kernel_graph->NewCNode({primitive, parameter});
kernel_graph->set_return(return_node);
partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
}
KernelGraphPtr kernel_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input));
partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
auto partial_node = graph->NewCNode(partial_inputs);
return partial_node;
}

CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(graph);
auto node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
if (node->inputs().size() < kSwitchInputSize) {
MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize;
}
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimSwitch->name()));
std::vector<AnfNodePtr> switch_inputs = {primitive, node->input(1)};
for (size_t index = 2; index < node->inputs().size(); index++) {
auto input = CreateSwitchInput(node->input(index), graph);
switch_inputs.emplace_back(input);
}
auto switch_node = graph->NewCNode(switch_inputs);
return switch_node;
}

std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
@@ -611,14 +592,33 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
});
return cnode_inputs;
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
auto switch_node = HandleSwitchInputs(cnode_input, graph);
auto switch_cnode = cnode_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
switch_cnode->input(kFirstDataInputIndex)};
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
auto node = switch_cnode->input(index);
// there is real input in call, should put it to true and false branch in switch
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
auto partial_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_node);
std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
auto new_partial = graph->NewCNode(partial_inputs);
switch_inputs.emplace_back(new_partial);
}
}
if (switch_inputs.size() < kSwitchInputSize) {
MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
}
auto switch_node = graph->NewCNode(switch_inputs);
cnode_inputs.emplace_back(switch_node);
return cnode_inputs;
}
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
}

CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> cnode_inputs;
@@ -642,7 +642,22 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
}
}
} else if (attr_input->isa<CNode>()) {
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
if (cnode->inputs().size() < 2 && AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
auto switch_cnode = cnode_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
cnode_inputs = switch_cnode->inputs();
} else {
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
}
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
cnode_inputs = {graph->GetBackendAnfByFrontAnf(cnode->input(kAnfPrimitiveIndex)),
graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))};
for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
auto node_input = cnode->input(index);
auto switch_input = CreateSwitchInput(node_input, graph);
cnode_inputs.emplace_back(switch_input);
}
} else {
// get primitive of old node
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
@@ -651,21 +666,33 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
}

for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->input(input_idx);
MS_EXCEPTION_IF_NULL(anf);
// anf has been created before
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (IsValueNode<None>(anf)) {
continue;
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->input(input_idx);
MS_EXCEPTION_IF_NULL(anf);
// anf has been created before
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (IsValueNode<None>(anf)) {
continue;
}
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
}
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
}
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
TraceManager::EndTrace();

// if the cnode is call switch, remove call
if (new_cnode->inputs().size() > 1) {
auto first_input = new_cnode->input(kFirstDataInputIndex);
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
new_cnode = first_input->cast<CNodePtr>();
}
}

return new_cnode;
}



+ 5
- 5
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -86,11 +86,7 @@ class SessionBasic {

CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);

CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph);
CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph);
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph);

// get graph id in child graphs by ME front anf node pointer
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
@@ -112,6 +108,10 @@ class SessionBasic {
}
#endif

private:
CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph);
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);

protected:
virtual void SetSummaryNodes(KernelGraph *graph);
// Get graph by graph id ,if not exist return null ptr


+ 4
- 1
mindspore/ccsrc/utils/utils.h View File

@@ -277,11 +277,14 @@ const int kValueNodeTensorMask = 2;
// define special index in special node
constexpr auto kAnfPrimitiveIndex = 0;
constexpr auto kFirstDataInputIndex = 1;
constexpr auto kAnfPartialFuncGraphIndex = 1;
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
constexpr auto kTupleGetItemInputSize = 3;
constexpr auto kSwitchInputSize = 4;
constexpr auto kFirstBranchInSwitch = 2;
constexpr auto kCallKernelGraphIndex = 1;
constexpr auto kSwitchTrueKernelGraphIndex = 2;
constexpr auto kSwitchFalseKernelGraphIndex = 3;
// index define of control depend
constexpr auto kControlDependPriorIndex = 1;
constexpr auto kControlDependBehindIndex = 2;


Loading…
Cancel
Save