Browse Source

!9459 for switch layer

From: @youui
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
4249940ba6
7 changed files with 165 additions and 47 deletions
  1. +29
    -21
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  2. +20
    -10
      mindspore/ccsrc/backend/session/ascend_control_parser.cc
  3. +3
    -2
      mindspore/ccsrc/backend/session/ascend_session.cc
  4. +3
    -2
      mindspore/ccsrc/backend/session/kernel_graph.cc
  5. +106
    -11
      mindspore/ccsrc/backend/session/session_basic.cc
  6. +3
    -1
      mindspore/ccsrc/backend/session/session_basic.h
  7. +1
    -0
      mindspore/ccsrc/utils/utils.h

+ 29
- 21
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -360,7 +360,7 @@ size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero" MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"
<< " trace: " << trace::DumpSourceLines(node); << " trace: " << trace::DumpSourceLines(node);
} }
// exclude intputs[0],which is value_node storing attr,inputs left are real input
// exclude inputs[0],which is value_node storing attr,inputs left are real input
return input_num - 1; return input_num - 1;
} }


@@ -1191,10 +1191,28 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node)


std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) { std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch))) {
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch node."
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) {
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch or switch_layer node."
<< " trace: " << trace::DumpSourceLines(cnode); << " trace: " << trace::DumpSourceLines(cnode);
} }
auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr {
auto partial = cnode->input(input_index);
MS_EXCEPTION_IF_NULL(partial);
if (IsValueNode<KernelGraph>(partial)) {
return GetValueNode<KernelGraphPtr>(partial);
}
auto partial_cnode = partial->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto graph_node = partial_cnode->input(kCallKernelGraphIndex);
MS_EXCEPTION_IF_NULL(graph_node);
auto graph_value_node = graph_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(graph_value_node);
auto graph_value = graph_value_node->value();
MS_EXCEPTION_IF_NULL(graph_value);
auto child_graph = graph_value->cast<KernelGraphPtr>();
return child_graph;
};
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
auto input1 = cnode->input(kCallKernelGraphIndex); auto input1 = cnode->input(kCallKernelGraphIndex);
MS_EXCEPTION_IF_NULL(input1); MS_EXCEPTION_IF_NULL(input1);
@@ -1204,25 +1222,15 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
return {kernel_graph->cast<KernelGraphPtr>()}; return {kernel_graph->cast<KernelGraphPtr>()};
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr {
auto partial = cnode->input(input_index);
MS_EXCEPTION_IF_NULL(partial);
if (IsValueNode<KernelGraph>(partial)) {
return GetValueNode<KernelGraphPtr>(partial);
}
auto partial_cnode = partial->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto graph_node = partial_cnode->input(kCallKernelGraphIndex);
MS_EXCEPTION_IF_NULL(graph_node);
auto graph_value_node = graph_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(graph_value_node);
auto graph_value = graph_value_node->value();
MS_EXCEPTION_IF_NULL(graph_value);
auto child_graph = graph_value->cast<KernelGraphPtr>();
return child_graph;
};
return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex), return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex),
get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)}; get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)};
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
std::vector<KernelGraphPtr> child_graphs;
for (size_t idx = kMakeTupleInSwitchLayerIndex; idx < cnode->inputs().size(); idx++) {
auto child_graph = get_switch_kernel_graph(idx);
child_graphs.emplace_back(child_graph);
}
return child_graphs;
} }
return {}; return {};
} }
@@ -1627,7 +1635,7 @@ void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::
MS_EXCEPTION_IF_NULL(result); MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(visited); MS_EXCEPTION_IF_NULL(visited);
if (visited->find(anf_node) != visited->end()) { if (visited->find(anf_node) != visited->end()) {
MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
return; return;
} }
visited->insert(anf_node); visited->insert(anf_node);


+ 20
- 10
mindspore/ccsrc/backend/session/ascend_control_parser.cc View File

@@ -156,7 +156,7 @@ static std::vector<CNodePtr> GetTargetLabelSetNodes(NotNull<CNodePtr> jump_node,
for (auto label_id : target_label_list) { for (auto label_id : target_label_list) {
auto iter = label_id_to_label_set.find(label_id); auto iter = label_id_to_label_set.find(label_id);
if (iter == label_id_to_label_set.end()) { if (iter == label_id_to_label_set.end()) {
MS_LOG(EXCEPTION) << "Connot find LabelSet node has label id " << label_id;
MS_LOG(EXCEPTION) << "Cannot find LabelSet node has label id " << label_id;
} }
target_labelset_nodes.push_back(iter->second); target_labelset_nodes.push_back(iter->second);
} }
@@ -413,6 +413,16 @@ std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlPar
const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter)); const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter));
ret.emplace_back(target_graph, args); ret.emplace_back(target_graph, args);
} }
} else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitchLayer)) {
const std::vector<AnfNodePtr> &switch_layer_inputs = cnode->inputs();
if (switch_layer_inputs.size() <= kCNodeSwitchLayerBranch) {
MS_LOG(EXCEPTION) << "Switch layer node " << cnode->DebugString() << " has invalid inputs size "
<< switch_layer_inputs.size();
}
for (auto iter = switch_layer_inputs.begin() + kCNodeSwitchLayerBranch; iter != switch_layer_inputs.end(); ++iter) {
const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter));
ret.emplace_back(target_graph, args);
}
} else { } else {
MS_LOG(EXCEPTION) << "Unsupported call node: " << cnode->DebugString(5); MS_LOG(EXCEPTION) << "Unsupported call node: " << cnode->DebugString(5);
} }
@@ -431,7 +441,8 @@ void AscendControlParser::ChildGraphDataAssign(
const std::vector<CNodePtr> &nodes = kg->execution_order(); const std::vector<CNodePtr> &nodes = kg->execution_order();


for (auto &node : nodes) { for (auto &node : nodes) {
if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch))) {
if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) ||
IsPrimitiveCNode(node, prim::kPrimSwitchLayer))) {
continue; continue;
} }


@@ -647,12 +658,10 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
} }


auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch);
MS_EXCEPTION_IF_NULL(branch_tuple);
if (!branch_tuple->isa<CNode>()) {
MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode";
std::vector<AnfNodePtr> branch_partial;
for (size_t idx = kCNodeSwitchLayerBranch; idx < cur_node->inputs().size(); idx++) {
branch_partial.emplace_back(cur_node->input(idx));
} }
const std::vector<AnfNodePtr> &branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
// 1 return label // 1 return label
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
// 2 add depend relationship // 2 add depend relationship
@@ -673,16 +682,17 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
// 3.1 branch kernel graph and args // 3.1 branch kernel graph and args
KernelGraphPtr branch_fg; KernelGraphPtr branch_fg;
std::vector<AnfNodePtr> origin_inputs; std::vector<AnfNodePtr> origin_inputs;
std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i + kCNodeSwitchLayerBranch]));
child_graphs.push_back(branch_fg); child_graphs.push_back(branch_fg);
// 3.2 recurse sub graph // 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label); new_switch_inputs.push_back(branch_label);
AttachOriginalInputsToGraph(kg, origin_inputs); AttachOriginalInputsToGraph(kg, origin_inputs);
} }
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
cur_node->set_inputs(new_switch_inputs); cur_node->set_inputs(new_switch_inputs);
cur_node->set_abstract(nullptr);
cur_node->set_abstract(std::make_shared<abstract::AbstractNone>());
// To adapt to the true and false branches of the switch, the sequence of the branches is reversed.
std::reverse(child_graphs.begin(), child_graphs.end());
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get()); AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get());
MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString();
} }


+ 3
- 2
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -875,7 +875,7 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
// generate and load task info to device if it is sink mode // generate and load task info to device if it is sink mode
Load(graph); Load(graph);
} }
// sync the inital const tensor to device
// sync the initial const tensor to device
SyncInitialTenosrToDevice(); SyncInitialTenosrToDevice();
DumpAllGraphs({graph}); DumpAllGraphs({graph});
MS_LOG(INFO) << "End"; MS_LOG(INFO) << "End";
@@ -1634,7 +1634,8 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
std::map<AnfNodePtr, AnfNodePtr> need_replace_list; std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
auto node_list = GetCNodes(TopoSort(graph->get_return())); auto node_list = GetCNodes(TopoSort(graph->get_return()));
for (auto &node : node_list) { for (auto &node : node_list) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
// create a parameter to store the output of multiple branch and set the parameter as the condition graph's output // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output
auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
MS_EXCEPTION_IF_NULL(graph->MutableInputs()); MS_EXCEPTION_IF_NULL(graph->MutableInputs());


+ 3
- 2
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -1186,8 +1186,9 @@ bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int outpu
void KernelGraph::UpdateChildGraphOrder() { void KernelGraph::UpdateChildGraphOrder() {
MS_LOG(INFO) << "Update " << ToString() << " child graph order."; MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
SetExecOrderByDefault(); SetExecOrderByDefault();
auto call_nodes = FindNodeByPrimitive(
{std::make_shared<Primitive>(prim::kPrimCall->name()), std::make_shared<Primitive>(prim::kPrimSwitch->name())});
auto call_nodes = FindNodeByPrimitive({std::make_shared<Primitive>(prim::kPrimCall->name()),
std::make_shared<Primitive>(prim::kPrimSwitch->name()),
std::make_shared<Primitive>(prim::kPrimSwitchLayer->name())});
std::vector<std::weak_ptr<KernelGraph>> child_graph_order; std::vector<std::weak_ptr<KernelGraph>> child_graph_order;
for (auto &call_node : call_nodes) { for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node); MS_EXCEPTION_IF_NULL(call_node);


+ 106
- 11
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -148,7 +148,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
} }
} }
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
// if in paynative mode,data only copyed to host when user want to print data
// if in pynative mode,data only copied to host when user want to print data
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
@@ -499,10 +499,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr
auto graph_inputs = graph->MutableInputs(); auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs); MS_EXCEPTION_IF_NULL(graph_inputs);
auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
auto parameter = graph->NewParameter();
MS_EXCEPTION_IF_NULL(parameter);
parameter->set_abstract(abstract);
auto new_parameter = graph->NewParameter(parameter);
auto new_parameter = graph->NewParameter(abstract);
parameters.push_back(new_parameter); parameters.push_back(new_parameter);
valid_inputs->push_back(true); valid_inputs->push_back(true);
graph_inputs->push_back(new_parameter); graph_inputs->push_back(new_parameter);
@@ -662,7 +659,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
return new_cnode; return new_cnode;
} }


CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) {
CNodePtr SessionBasic::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node_input); MS_EXCEPTION_IF_NULL(node_input);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
// switch input generalizes partial // switch input generalizes partial
@@ -675,9 +672,11 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra
} else { } else {
KernelGraphPtr kernel_graph = NewKernelGraph(); KernelGraphPtr kernel_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), kernel_graph.get());
auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
parameter->set_abstract(cnode->abstract());
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())); auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
auto return_node = kernel_graph->NewCNode({primitive, parameter}); auto return_node = kernel_graph->NewCNode({primitive, parameter});
return_node->set_abstract(cnode->abstract());
kernel_graph->set_return(return_node); kernel_graph->set_return(return_node);
partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph)); partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
@@ -722,10 +721,97 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
return cnode_inputs; return cnode_inputs;
} }


void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(real_input);
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) {
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node.";
}
auto partial_input = cnode->input(kFirstDataInputIndex);
KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
MS_EXCEPTION_IF_NULL(partial_kernel_graph);
auto ret = partial_kernel_graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto return_input = ret->input(kFirstDataInputIndex);
// if kernel graph return node is a function
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
std::vector<AnfNodePtr> call_inputs = {
partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
auto return_input_cnode = return_input->cast<CNodePtr>();

auto partial_inputs = return_input_cnode->inputs();
call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get());
call_inputs.emplace_back(parameter_for_input);
auto call_node = partial_kernel_graph->NewCNode(call_inputs);
// update abstract
KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_inputs[kFirstDataInputIndex]);
auto ret_partial = sub_partial_kernel_graph->get_return();
call_node->set_abstract(ret_partial->abstract());
// update return input
ret->set_input(kFirstDataInputIndex, call_node);
}
}

std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> cnode_inputs = {
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
auto attr_input = cnode->input(kAnfPrimitiveIndex);
MS_EXCEPTION_IF_NULL(attr_input);
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
auto switch_layer_cnode = cnode_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_layer_cnode);
std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
switch_layer_cnode->input(kFirstDataInputIndex)};
auto make_tuple_node = switch_layer_cnode->input(kMakeTupleInSwitchLayerIndex);
MS_EXCEPTION_IF_NULL(make_tuple_node);
auto node = make_tuple_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto make_tuple_inputs = node->inputs();
// there is real input in call, should put it to make_tuple in switch_layer
auto real_input = cnode->input(kFirstDataInputIndex);
auto real_input_back = graph->GetBackendAnfByFrontAnf(real_input);
std::vector<AnfNodePtr> new_make_tuple_inputs = {
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
auto partial_idx = make_tuple_inputs[idx];
MS_EXCEPTION_IF_NULL(cnode->abstract());
// switch_layer node input is partial cnode
if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
auto partial_node = partial_idx->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_node);
// update kernel graph when switch_layer node return function
CreateCallNodeReturnFunction(partial_node, real_input_back);

std::vector<AnfNodePtr> new_partial_inputs = partial_node->inputs();
new_partial_inputs.emplace_back(real_input_back);
auto new_partial = graph->NewCNode(new_partial_inputs);
new_make_tuple_inputs.emplace_back(new_partial);
}
// switch_layer node input is kernel graph value node
if (IsValueNode<KernelGraph>(partial_idx)) {
// make_tuple inputs is KernelGraph
std::vector<AnfNodePtr> new_partial_inputs;
new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
new_partial_inputs.emplace_back(partial_idx);
new_partial_inputs.emplace_back(real_input_back);
auto new_partial = graph->NewCNode(new_partial_inputs);
new_make_tuple_inputs.emplace_back(new_partial);
}
}
auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
switch_layer_inputs.emplace_back(new_make_tuple);
auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
cnode_inputs.emplace_back(new_switch_layer);
return cnode_inputs;
}

std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
// create primitive of cnode:call(partial or switch)
// create primitive of cnode:call(partial or switch or switch_layer)
std::vector<AnfNodePtr> cnode_inputs = { std::vector<AnfNodePtr> cnode_inputs = {
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
auto attr_input = cnode->input(kAnfPrimitiveIndex); auto attr_input = cnode->input(kAnfPrimitiveIndex);
@@ -748,9 +834,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
return cnode_inputs; return cnode_inputs;
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
return CreateCallSwitchInputs(cnode, graph); return CreateCallSwitchInputs(cnode, graph);
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
return CreateCallSwitchLayerInputs(cnode, graph);
} }
MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString() MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString()
<< "must be partial or switch.";
<< "must be partial or switch or switch_layer.";
return {}; return {};
} }


@@ -788,7 +876,7 @@ void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) { for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
auto node_input = cnode->input(index); auto node_input = cnode->input(index);
auto switch_input = CreateSwitchInput(node_input, graph);
auto switch_input = CreateSwitchInput(cnode, node_input, graph);
cnode_inputs->emplace_back(switch_input); cnode_inputs->emplace_back(switch_input);
} }
} else { } else {
@@ -841,10 +929,17 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
// if the cnode is call switch, remove call // if the cnode is call switch, remove call
if (new_cnode->inputs().size() > 1) { if (new_cnode->inputs().size() > 1) {
auto first_input = new_cnode->input(kFirstDataInputIndex); auto first_input = new_cnode->input(kFirstDataInputIndex);
MS_EXCEPTION_IF_NULL(first_input);
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) && if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) { AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
new_cnode = first_input->cast<CNodePtr>(); new_cnode = first_input->cast<CNodePtr>();
} }
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
auto abstract = cnode->abstract();
new_cnode = first_input->cast<CNodePtr>();
new_cnode->set_abstract(abstract);
}
} }


return new_cnode; return new_cnode;
@@ -1842,7 +1937,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
// PS embeddingLookup cache check. // PS embeddingLookup cache check.
if (ps::PsDataPrefetch::GetInstance().cache_enable()) { if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
MS_LOG(EXCEPTION) << "The other parameter cann't set ps mode when the embeddingLookup cache is enabled in "
MS_LOG(EXCEPTION) << "The other parameter can't set ps mode when the embeddingLookup cache is enabled in "
"parameter server training mode."; "parameter server training mode.";
} }
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());


+ 3
- 1
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -125,7 +125,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
#endif #endif


private: private:
CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph);
CNodePtr CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph);
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph); std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph);
void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs); void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);
@@ -133,6 +133,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs); void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs);
void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs, void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
void CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input);


protected: protected:
friend class Executor; friend class Executor;


+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -407,6 +407,7 @@ constexpr auto kFirstBranchInSwitch = 2;
constexpr auto kCallKernelGraphIndex = 1; constexpr auto kCallKernelGraphIndex = 1;
constexpr auto kSwitchTrueKernelGraphIndex = 2; constexpr auto kSwitchTrueKernelGraphIndex = 2;
constexpr auto kSwitchFalseKernelGraphIndex = 3; constexpr auto kSwitchFalseKernelGraphIndex = 3;
constexpr auto kMakeTupleInSwitchLayerIndex = 2;
// index define of control depend // index define of control depend
constexpr auto kControlDependPriorIndex = 1; constexpr auto kControlDependPriorIndex = 1;
constexpr auto kControlDependBehindIndex = 2; constexpr auto kControlDependBehindIndex = 2;


Loading…
Cancel
Save