Browse Source

stop TaskEmitAction when the graph return const or parameter

tags/v1.2.0-rc1
buxue 4 years ago
parent
commit
cd9770c0fe
2 changed files with 36 additions and 18 deletions
  1. +18
    -0
      mindspore/ccsrc/pipeline/jit/action.cc
  2. +18
    -18
      mindspore/ccsrc/pipeline/jit/pipeline.cc

+ 18
- 0
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -380,7 +380,21 @@ static bool IsCtrlSink() {
return true;
}

bool CheckGraphOutputConstOrParameter(const FuncGraphPtr &func_graph) {
if (func_graph != nullptr) {
AnfNodePtr output = func_graph->output();
if (output != nullptr && (output->isa<ValueNode>() || output->isa<Parameter>())) {
return true;
}
}
return false;
}

bool TaskEmitAction(const ResourcePtr &res) {
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
CheckGraphOutputConstOrParameter(res->func_graph())) {
return true;
}
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "TaskEmit args error";
}
@@ -415,6 +429,10 @@ bool TaskEmitAction(const ResourcePtr &res) {
}

bool ExecuteAction(const ResourcePtr &res) {
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
CheckGraphOutputConstOrParameter(res->func_graph())) {
return true;
}
if (res->results().count(kOutput) == 0) {
MS_LOG(EXCEPTION) << "Execute args error";
}


+ 18
- 18
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -723,12 +723,14 @@ void Pipeline::Run() {
if (!result) {
MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;
}

FuncGraphPtr graph = resource_->func_graph();
#ifdef ENABLE_DUMP_IR
if (mindspore::RecorderManager::Instance().RdrEnable()) {
MS_LOG(INFO) << "Recording FuncGraph in pipeline using RDR.";
std::string tag = GetBaseNameForIR(i, action.first);
if (resource_->func_graph() != nullptr) {
auto graph_clone = BasicClone(resource_->func_graph());
if (graph != nullptr) {
auto graph_clone = BasicClone(graph);
if (graph_clone != nullptr) {
mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, tag, graph_clone, false, ".ir");
} else {
@@ -740,23 +742,21 @@ void Pipeline::Run() {
MS_LOG(INFO) << "Recording FuncGraph in pipeline end.";
}
#endif
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && resource_->func_graph() != nullptr) {
auto graph = resource_->func_graph();
if (graph != nullptr) {
user_graph = graph;
std::string base_name = GetBaseNameForIR(i, action.first);

// generate IR file in dot format, which can be converted to svg file using graphviz dot command
draw::Draw(base_name + ".dot", graph);
// generate IR file in human readable format
if (i == actions_.size() - 1) {
DumpIR(base_name + ".ir", graph, false, kWholeStack);
} else {
DumpIR(base_name + ".ir", graph, false, kTopStack);
}
// generate IR file in a heavily commented format, which can also be reloaded
ExportIR(base_name + ".dat", std::to_string(i), graph);

if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && graph != nullptr) {
user_graph = graph;
std::string base_name = GetBaseNameForIR(i, action.first);

// generate IR file in dot format, which can be converted to svg file using graphviz dot command
draw::Draw(base_name + ".dot", graph);
// generate IR file in human readable format
if (i == actions_.size() - 1) {
DumpIR(base_name + ".ir", graph, false, kWholeStack);
} else {
DumpIR(base_name + ".ir", graph, false, kTopStack);
}
// generate IR file in a heavily commented format, which can also be reloaded
ExportIR(base_name + ".dat", std::to_string(i), graph);
}
i++;
#ifdef ENABLE_TIMELINE


Loading…
Cancel
Save