Browse Source

!13013 fix switch layer

From: @youui
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
7e8f58f79d
2 changed files with 58 additions and 46 deletions
  1. +56
    -45
      mindspore/ccsrc/backend/session/session_basic.cc
  2. +2
    -1
      mindspore/ccsrc/backend/session/session_basic.h

+ 56
- 45
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -934,46 +934,60 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
return cnode_inputs;
}

void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector<AnfNodePtr> &real_inputs) {
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph,
const std::vector<AnfNodePtr> &real_inputs) {
MS_EXCEPTION_IF_NULL(cnode);
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) {
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node.";
// func1 =switch(branch1, branch2)
// func2 = func1(param1)
// out = func2(param2)
// process the last cnode(func2), not func1 which abstract is AbstractFunction
if (cnode->abstract()->isa<abstract::AbstractFunction>()) {
return;
}
auto partial_input = cnode->input(kFirstDataInputIndex);
KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
MS_EXCEPTION_IF_NULL(partial_kernel_graph);
auto ret = partial_kernel_graph->get_return();
MS_EXCEPTION_IF_NULL(graph);
auto ret = graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto return_input = ret->input(kFirstDataInputIndex);
// return node is a function
std::vector<AnfNodePtr> call_inputs = {
partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
AnfNodePtr real_kernel_graph;
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
auto return_input_cnode = return_input->cast<CNodePtr>();
auto partial_inputs = return_input_cnode->inputs();
call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
real_kernel_graph = partial_inputs[kFirstDataInputIndex];
} else { // return node is kernel graph
} else if (IsValueNode<KernelGraph>(return_input)) { // return node is kernel graph
call_inputs.emplace_back(return_input);
real_kernel_graph = return_input;
} else { // return node is value node
KernelGraphPtr kernel_graph = NewKernelGraph();
auto valid_inputs = kernel_graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
auto graph_inputs = kernel_graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
std::vector<AnfNodePtr> cnode_inputs = {return_input};
for (auto real_input : real_inputs) {
auto new_parameter = kernel_graph->NewParameter(real_input->abstract());
valid_inputs->push_back(true);
graph_inputs->push_back(new_parameter);
cnode_inputs.push_back(new_parameter);
}
auto new_cnode = kernel_graph->NewCNode(cnode_inputs);
new_cnode->set_abstract(cnode->abstract());
std::vector<AnfNodePtr> return_inputs = {
kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))), new_cnode};
auto return_node = kernel_graph->NewCNode(return_inputs);
return_node->set_abstract(cnode->abstract());
kernel_graph->set_return(return_node);
call_inputs.push_back(std::make_shared<ValueNode>(kernel_graph));
}

// new call node inputs
for (auto real_input : real_inputs) {
auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get());
auto parameter_for_input = CreateNewParameterFromCNode(real_input, graph);
call_inputs.emplace_back(parameter_for_input);
}

auto call_node = partial_kernel_graph->NewCNode(call_inputs);
// update abstract
MS_EXCEPTION_IF_NULL(real_kernel_graph);
if (real_kernel_graph->isa<ValueNode>() && IsValueNode<FuncGraph>(real_kernel_graph)) {
KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(real_kernel_graph);
MS_EXCEPTION_IF_NULL(sub_partial_kernel_graph);
auto ret_partial = sub_partial_kernel_graph->get_return();
call_node->set_abstract(ret_partial->abstract());
}
auto call_node = graph->NewCNode(call_inputs);
call_node->set_abstract(cnode->abstract());
// update return input
ret->set_input(kFirstDataInputIndex, call_node);
}
@@ -1005,36 +1019,33 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
auto partial_idx = make_tuple_inputs[idx];
MS_EXCEPTION_IF_NULL(cnode->abstract());
std::vector<AnfNodePtr> new_partial_inputs;
KernelGraphPtr partial_kernel_graph;
// switch_layer node input is partial cnode
if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
auto partial_node = partial_idx->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_node);
// update kernel graph when switch_layer node return function
auto partial_input = partial_node->input(kFirstDataInputIndex);
KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
MS_EXCEPTION_IF_NULL(partial_kernel_graph);
auto ret = partial_kernel_graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto return_input = ret->input(kFirstDataInputIndex);
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || IsValueNode<KernelGraph>(return_input)) {
CreateCallNodeReturnFunction(partial_node, real_inputs);
}

std::vector<AnfNodePtr> new_partial_inputs = partial_node->inputs();
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
auto new_partial = graph->NewCNode(new_partial_inputs);
new_make_tuple_inputs.emplace_back(new_partial);
}
// switch_layer node input is kernel graph value node
if (IsValueNode<KernelGraph>(partial_idx)) {
// make_tuple inputs is KernelGraph
std::vector<AnfNodePtr> new_partial_inputs;
partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
new_partial_inputs = partial_node->inputs();
} else if (IsValueNode<KernelGraph>(partial_idx)) { // switch_layer node input is kernel graph value node
new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
new_partial_inputs.emplace_back(partial_idx);
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
auto new_partial = graph->NewCNode(new_partial_inputs);
new_make_tuple_inputs.emplace_back(new_partial);
}
partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_idx);
}
// when branch in swich_layer return function
MS_EXCEPTION_IF_NULL(partial_kernel_graph);
auto ret = partial_kernel_graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto return_input = ret->input(kFirstDataInputIndex);
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
CreateCallNodeReturnFunction(cnode, partial_kernel_graph.get(), real_inputs);
}
// partial node add input args
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
// create new partial node
auto new_partial = graph->NewCNode(new_partial_inputs);
new_make_tuple_inputs.emplace_back(new_partial);
}
auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
switch_layer_inputs.emplace_back(new_make_tuple);


+ 2
- 1
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -146,7 +146,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
void CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector<AnfNodePtr> &real_inputs);
void CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph,
const std::vector<AnfNodePtr> &real_inputs);

protected:
friend class Executor;


Loading…
Cancel
Save