Browse Source

Mark null output subgraph

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v0.5.0-beta
zhoufeng 5 years ago
parent
commit
b7fae521c0
11 changed files with 142 additions and 88 deletions
  1. +0
    -2
      mindspore/ccsrc/device/ascend/ascend_label_assign.cc
  2. +0
    -1
      mindspore/ccsrc/device/ascend/ascend_label_assign.h
  3. +0
    -12
      mindspore/ccsrc/session/anf_runtime_algorithm.cc
  4. +0
    -1
      mindspore/ccsrc/session/anf_runtime_algorithm.h
  5. +6
    -5
      mindspore/ccsrc/session/ascend_control_parser.cc
  6. +7
    -5
      mindspore/ccsrc/session/ascend_control_parser.h
  7. +93
    -43
      mindspore/ccsrc/session/ascend_session.cc
  8. +4
    -7
      mindspore/ccsrc/session/ascend_session.h
  9. +17
    -5
      mindspore/ccsrc/session/kernel_graph.cc
  10. +4
    -1
      mindspore/ccsrc/session/kernel_graph.h
  11. +11
    -6
      mindspore/ccsrc/session/session_basic.cc

+ 0
- 2
mindspore/ccsrc/device/ascend/ascend_label_assign.cc View File

@@ -26,7 +26,6 @@ static constexpr uint32_t kLabelSwitchLabelId = 2;
namespace mindspore {
namespace device {
namespace ascend {

static void UpdateLabelGoto(NotNull<CNodePtr> node) {
if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) {
return;
@@ -164,7 +163,6 @@ uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> gr
uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
return GetLabelNum(NOT_NULL(graph.get().get()));
}

} // namespace ascend
} // namespace device
} // namespace mindspore

+ 0
- 1
mindspore/ccsrc/device/ascend/ascend_label_assign.h View File

@@ -25,7 +25,6 @@
namespace mindspore {
namespace device {
namespace ascend {

class AscendLabelAssign {
public:
static AscendLabelAssign &GetInstance() {


+ 0
- 12
mindspore/ccsrc/session/anf_runtime_algorithm.cc View File

@@ -974,17 +974,5 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
}
MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
}

bool AnfRuntimeAlgorithm::IsWhileTrueGraph(const KernelGraphPtr &child_graph) {
auto call_nodes = child_graph->FindNodeByPrimitive(prim::kPrimCall);
for (const auto &call_node : call_nodes) {
auto graphs = GetCallNodeKernelGraph(call_node);
if (graphs.size() == 1 && graphs[0] == child_graph->parent_graph()) {
return true;
}
}
return false;
}

} // namespace session
} // namespace mindspore

+ 0
- 1
mindspore/ccsrc/session/anf_runtime_algorithm.h View File

@@ -185,7 +185,6 @@ class AnfRuntimeAlgorithm {
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
static bool IsSwitchCall(const CNodePtr &call_node);
static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;


+ 6
- 5
mindspore/ccsrc/session/ascend_control_parser.cc View File

@@ -83,7 +83,7 @@ CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &lis

NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label,
NotNull<std::set<KernelGraphPtr> *> memo) {
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString();

// 1. recursive condition
@@ -180,7 +180,7 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
}

void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
NotNull<std::set<KernelGraphPtr> *> memo) {
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process call func " << cur_node->DebugString();

// 1 get kernel graph
@@ -212,7 +212,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
}

void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
const CNodePtr &next_node, NotNull<std::set<KernelGraphPtr> *> memo) {
const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process switch node " << cur_node->DebugString();

if (cur_node->size() < kCNodeSwitchLength) {
@@ -249,7 +249,8 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
}

void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
const CNodePtr &next_node, NotNull<std::set<KernelGraphPtr> *> memo) {
const CNodePtr &next_node,
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process switch node " << cur_node->DebugString();

if (cur_node->size() < kCNodeSwitchLayerLength) {
@@ -353,7 +354,7 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
}

std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
NotNull<std::set<KernelGraphPtr> *> memo) {
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
auto print_vector = [&](std::vector<CNodePtr> vec) -> void {
MS_LOG(INFO) << "graph:" << graph->graph_id() << "execution order";


+ 7
- 5
mindspore/ccsrc/session/ascend_control_parser.h View File

@@ -40,13 +40,14 @@ class AscendControlParser {

private:
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo);
const CNodePtr &last_label,
const NotNull<std::set<KernelGraphPtr> *> memo);
static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
NotNull<std::set<KernelGraphPtr> *> memo);
const NotNull<std::set<KernelGraphPtr> *> memo);
static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
NotNull<std::set<KernelGraphPtr> *> memo);
const NotNull<std::set<KernelGraphPtr> *> memo);
static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
NotNull<std::set<KernelGraphPtr> *> memo);
const NotNull<std::set<KernelGraphPtr> *> memo);

static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label);
@@ -63,7 +64,8 @@ class AscendControlParser {
static std::vector<uint32_t> GetLabelSwitchList(const CNodePtr &node);
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
NotNull<KernelGraphPtr> graph);
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo);
};
} // namespace session
} // namespace mindspore


+ 93
- 43
mindspore/ccsrc/session/ascend_session.cc View File

@@ -171,19 +171,35 @@ std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
return cnodes;
}

std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) {
size_t after_call_index = 0;
static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePtr> &cnodes,
const std::set<PrimitivePtr> &cut_prims) {
size_t after_cut_index = 0;
std::vector<std::vector<CNodePtr>> ret;
for (size_t i = 0; i < cnodes.size(); i++) {
if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) {
auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]);
auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i);
auto call_list = std::vector<CNodePtr>(1, cnodes[i]);
after_call_index = i + 1;
ret.push_back(prev_call_list);
ret.push_back(call_list);
} else if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) {
ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end()));
for (size_t i = 0; i < cnodes.size(); ++i) {
bool is_cut_node = false;
for (auto &prim : cut_prims) {
if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim)) {
is_cut_node = true;
break;
}
}
if (is_cut_node) {
// is call and not switch call,cut to 3 lists
if (!AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall)) {
// if is not a call,cut to 2 lists
ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i);
after_cut_index = i;
} else if (!AnfAlgo::IsSwitchCall(cnodes[i])) {
ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i);
ret.emplace_back(1, cnodes[i]);
after_cut_index = i + 1;
continue;
}
}
// get last child graph list
if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) {
ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.end());
continue;
}
}
return ret;
@@ -191,7 +207,7 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co

// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
static void UpdateRealInput(KernelGraph *graph) {
static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
auto bind_call_arg_with_parameter = [&](const std::vector<AnfNodePtr> &parameters,
const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void {
@@ -253,16 +269,17 @@ static void UpdateRealInput(KernelGraph *graph) {
}
}

void RecurseToUpdateCallRealInput(KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo) {
memo->insert(graph.get());
MS_LOG(INFO) << "start graph id:" << graph->graph_id();
for (auto &child_graph : graph->child_graph_order()) {
if (child_graph == graph->parent_graph()) {
if (memo->find(child_graph) != memo->end()) {
MS_LOG(INFO) << "Child graph:" << child_graph->graph_id()
<< ",parent graph:" << graph->parent_graph()->graph_id();
continue;
}
RecurseToUpdateCallRealInput(child_graph.get());
RecurseToUpdateCallRealInput(NOT_NULL(child_graph), memo);
}
// this action should from bottom to top
graph->UpdateCallRealInput();
@@ -282,7 +299,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph);
// split switch
SplitGraphs(graph);
SplitGraphs(NOT_NULL(graph));
// insert goto labels and label_sets
LinkChildGraphs(NOT_NULL(graph));
// resource initialize
@@ -290,7 +307,8 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
// assign label
AssignLabel(NOT_NULL(graph));
// recurse compile child graph
RecurseCompileGraph(graph);
std::set<KernelGraphPtr> memo;
RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo));
// root graph valiate,include genearte execute order and so on
RootGraphExecutorValidate(NOT_NULL(graph));
// adjust kernel
@@ -1423,24 +1441,43 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
}
MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
std::vector<AnfNodePtr> call_node_inputs;
auto graph_inputs = new_kernel_graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
std::vector<AnfNodePtr> new_graph_inputs;
// create new parameter from cnode
for (auto &anf_node : list) {
auto cnode = anf_node->cast<CNodePtr>();
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto input = cnode->inputs()[input_idx];
MS_EXCEPTION_IF_NULL(input);
if (input->isa<Parameter>()) {
graph_inputs->push_back(input);
AnfNodePtr new_parameter = nullptr;
// value node consider move to new graph
if (input->isa<ValueNode>()) {
cnode->set_input(input_idx, input);
continue;
} else if (input->isa<Parameter>()) {
// parameter reuse and should attention mulptiple use of one parameter
cnode->set_input(input_idx, input);
new_parameter = input;
} else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
// if is cnode and not in current child graph
new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
cnode->set_input(input_idx, new_parameter);
} else {
// if is a cnode and in current graph
continue;
}
// if mulptiple use of one parameter or cnode, only set one parameter in graph inputs and one arg in call node
// args
if (std::find(call_node_inputs.begin(), call_node_inputs.end(), new_parameter) == call_node_inputs.end()) {
new_graph_inputs.push_back(new_parameter);
call_node_inputs.push_back(input);
}
call_node_inputs.push_back(input);
}
}
// set graph inputs of new graph
auto graph_inputs = new_kernel_graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
graph_inputs->clear();
std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs));
MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
@@ -1461,20 +1498,30 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
return call_node_inputs;
}

void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) {
SplitGraph(root_graph);
void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
std::set<KernelGraphPtr> memo;
// if root graph output is a call node ,the root graph is condition graph of 'if' sentence
auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first;
if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) {
SplitGraph(root_graph, {prim::kPrimReturn});
for (auto &child_graph : root_graph->child_graph_order()) {
RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo));
}
} else {
RecurseSplitGraph(root_graph, NOT_NULL(&memo));
}
memo.clear();
// replace the real input if the real input is a call
RecurseToUpdateCallRealInput(root_graph.get());
RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo));
}

void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
MS_EXCEPTION_IF_NULL(graph);
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
// update the root graph child graph order
AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph));
AscendControlParser::UpdateChildGraphOrder(graph);
// get child list from current graph
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list);
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims);
auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
// if child graph list only has a call ,then return the exist call
if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
@@ -1521,20 +1568,22 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
pre_call_node = cur_call_node;
cur_call_node = *iter;
if (pre_call_node != nullptr && cur_call_node != nullptr) {
AscendControlParser::InsertControlDependToGraph(NOT_NULL(graph), NOT_NULL(cur_call_node),
NOT_NULL(pre_call_node));
AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node));
}
}
}
AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph));
UpdateRealInput(graph.get());
auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id()));
DumpIR(graph_name, graph);
AscendControlParser::UpdateChildGraphOrder(graph);
UpdateRealInput(graph);
MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end";
// recurse to split child graph
}

void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) {
memo->insert(graph.get());
SplitGraph(graph, {prim::kPrimCall});
for (auto &child_graph : graph->child_graph_order()) {
if (child_graph != graph->parent_graph()) {
SplitGraph(child_graph);
if (memo->find(child_graph) == memo->end()) {
RecurseSplitGraph(NOT_NULL(child_graph), memo);
}
}
}
@@ -1545,13 +1594,14 @@ void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) {
AscendControlParser::ExecutorValidate(graph);
}

void AscendSession::RecurseCompileGraph(const KernelGraphPtr &graph) {
void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) {
memo->insert(graph.get());
CompileChildGraph(graph);
for (auto child_graph : graph->child_graph_order()) {
if (child_graph == graph->parent_graph()) {
if (memo->find(child_graph) != memo->end()) {
continue;
}
RecurseCompileGraph(child_graph);
RecurseCompileGraph(NOT_NULL(child_graph), memo);
}
}



+ 4
- 7
mindspore/ccsrc/session/ascend_session.h View File

@@ -98,18 +98,15 @@ class AscendSession : public SessionBasic {
void SetFinalGraphOutput(const ValuePtr &value);
void SetFinalGraphOutput(const VectorRef &vec_output);

void SplitGraph(const KernelGraphPtr &graph);
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims);
// split graphs with recurse from root graph
void SplitGraphs(const KernelGraphPtr &root_graph);
void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
void IRFusion(const KernelGraphPtr &graph) {}
void SelectKernelGraphKernel(const KernelGraph &graph) {}
void ConvertPredictModel(const KernelGraphPtr graph) {}
void HardwareOptimizeGraphs(const KernelGraphPtr graph) {}
void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
const std::vector<CNodePtr> &list);
void RecurseCompileGraph(const KernelGraphPtr &graph);
void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);

// merge execution order list of child graphs
void MergeGraphExecOrder();


+ 17
- 5
mindspore/ccsrc/session/kernel_graph.cc View File

@@ -50,7 +50,7 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
std::vector<AnfNodePtr> real_inputs;
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(item_with_index.first->cast<CNodePtr>());
for (const auto &child_graph : child_graphs) {
if (AnfAlgo::IsWhileTrueGraph(child_graph)) {
if (child_graph->get_output_null()) {
continue;
}
auto real_input = child_graph->output();
@@ -592,7 +592,11 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
MS_EXCEPTION_IF_NULL(output_node.first);
auto output_cnode = output_node.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
const auto &output_node_inputs = output_cnode->inputs();
auto &output_node_inputs = output_cnode->inputs();
// don't replace node if it is a control edge => output_node.second == 0
if (output_node.second == 0) {
continue;
}
for (size_t i = 1; i < output_node_inputs.size(); i++) {
if (output_node_inputs[i] == old_anf_node) {
output_cnode->set_input(i, new_anf_node);
@@ -686,10 +690,12 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {

void KernelGraph::UpdateCallRealInput() {
MS_LOG(INFO) << "Update graph id: " << graph_id_;
std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_map;
std::vector<std::pair<AnfNodePtr, AnfNodePtr>> replace_list;
for (auto &it : real_inputs_) {
auto &parameter = it.first;
auto parameter = it.first;
MS_EXCEPTION_IF_NULL(parameter);
auto &real_inputs = it.second;
auto real_inputs = it.second;
std::vector<AnfNodePtr> new_real_inputs;
std::set<AnfNodePtr> erase_real_inputs;
for (auto &real_input : real_inputs) {
@@ -711,10 +717,16 @@ void KernelGraph::UpdateCallRealInput() {
<< " insert real input:" << new_real_input->DebugString();
(void)real_inputs.insert(new_real_input);
if (new_real_input->isa<Parameter>()) {
ReplaceNode(parameter, new_real_input);
replace_list.emplace_back(parameter, new_real_input);
parameter = new_real_input;
}
}
real_inputs_map[parameter] = real_inputs;
}
for (auto [parameter, arg] : replace_list) {
ReplaceNode(parameter, arg);
}
real_inputs_ = real_inputs_map;
}

std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }


+ 4
- 1
mindspore/ccsrc/session/kernel_graph.h View File

@@ -36,7 +36,7 @@ namespace session {
using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
class KernelGraph : public FuncGraph {
public:
KernelGraph() : graph_id_(0) {
KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false) {
inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
execution_order_ = {};
executable_ = true;
@@ -134,6 +134,8 @@ class KernelGraph : public FuncGraph {
CNodePtr get_start_label() { return start_label_; }
void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; }
CNodePtr get_end_goto() { return end_goto_; }
bool get_output_null() { return null_output_; }
void set_output_null(bool is_output_null) { null_output_ = is_output_null; }

private:
// remove value node form graph
@@ -188,6 +190,7 @@ class KernelGraph : public FuncGraph {

CNodePtr start_label_;
CNodePtr end_goto_;
bool null_output_;
};
} // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;


+ 11
- 6
mindspore/ccsrc/session/session_basic.cc View File

@@ -543,15 +543,12 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con

std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) {
MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph.";
return front_backend_graph_map_[func_graph];
}
auto node_list = TopoSort(func_graph->get_return());
auto graph = NewKernelGraph();
front_backend_graph_map_[func_graph] = graph;
MS_LOG(INFO) << "Create graph: " << graph->graph_id();

bool is_trace_back = false;
for (const auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
@@ -564,8 +561,14 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
(void)CreateNewValueNode(node, graph.get());
} else {
// if input is a ValueNode<FuncGraph>
auto child_graph = ConstructKernelGraph(AnfAlgo::GetValueNodeFuncGraph(node));
auto new_value_node = CreateValueNodeKernelGraph(node, graph.get());
FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
MS_LOG(INFO) << "FuncGraph: " << child_graph->ToString() << " has been transformed to KernelGraph.";
is_trace_back = true;
} else {
(void)ConstructKernelGraph(child_graph);
}
(void)CreateValueNodeKernelGraph(node, graph.get());
}
continue;
} else {
@@ -582,6 +585,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
}
}
}
// if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
graph->set_output_null(is_trace_back);
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
graph_inputs->clear();


Loading…
Cancel
Save