Browse Source

new control sink entry

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v0.3.0-alpha
zhoufeng 5 years ago
parent
commit
b78e54a5c4
8 changed files with 71 additions and 6 deletions
  1. +4
    -0
      mindspore/ccsrc/operator/ops.cc
  2. +4
    -0
      mindspore/ccsrc/operator/ops.h
  3. +51
    -2
      mindspore/ccsrc/pipeline/action.cc
  4. +1
    -1
      mindspore/ccsrc/session/ascend_session.cc
  5. +1
    -1
      mindspore/ccsrc/session/ascend_session.h
  6. +2
    -1
      mindspore/ccsrc/session/session_basic.h
  7. +4
    -0
      mindspore/ccsrc/vm/backend.cc
  8. +4
    -1
      mindspore/ccsrc/vm/backend.h

+ 4
- 0
mindspore/ccsrc/operator/ops.cc View File

@@ -78,6 +78,10 @@ const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");


const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto");
const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");

// Structure // Structure
const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");


+ 4
- 0
mindspore/ccsrc/operator/ops.h View File

@@ -84,6 +84,10 @@ extern const PrimitivePtr kPrimEmbed;
extern const PrimitivePtr kPrimRefToEmbed; extern const PrimitivePtr kPrimRefToEmbed;
extern const PrimitivePtr kPrimCreateInstance; extern const PrimitivePtr kPrimCreateInstance;


extern const PrimitivePtr kPrimLabelGoto;
extern const PrimitivePtr kPrimLabelSwitch;
extern const PrimitivePtr kPrimLabelSet;

// Structure // Structure
extern const PrimitivePtr kPrimStringEqual; extern const PrimitivePtr kPrimStringEqual;
extern const PrimitivePtr kPrimStringConcat; extern const PrimitivePtr kPrimStringConcat;


+ 51
- 2
mindspore/ccsrc/pipeline/action.cc View File

@@ -269,13 +269,41 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa


bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }


static bool IsCtrlSink() {
auto ms_ctx = MsContext::GetInstance();
std::string device_target = ms_ctx->device_target();
if (device_target != kAscendDevice) {
return false;
}

if (!ms_ctx->enable_task_sink()) {
return false;
}

char *enable_ctrl_sink = std::getenv("ENABLE_CTRL_SINK");
if (enable_ctrl_sink == nullptr) {
return false;
}
std::string enable_ctrl_sink_str(enable_ctrl_sink);
if (enable_ctrl_sink_str == "0") {
return false;
}

return true;
}

bool TaskEmitAction(const ResourcePtr &res) { bool TaskEmitAction(const ResourcePtr &res) {
if (res->func_graph() == nullptr) { if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "TaskEmit args error"; MS_LOG(EXCEPTION) << "TaskEmit args error";
} }
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();

auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();

if (IsCtrlSink()) {
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
return true;
}

std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
if (bc_ptr->name() == kMsConvert) { if (bc_ptr->name() == kMsConvert) {
cut_list = compile::GetMsNonlinearOps(); cut_list = compile::GetMsNonlinearOps();
@@ -286,10 +314,31 @@ bool TaskEmitAction(const ResourcePtr &res) {
} }


bool ExecuteAction(const ResourcePtr &res) { bool ExecuteAction(const ResourcePtr &res) {
if (res->results().count(kOutput) == 0 || !res->results()[kOutput].is<compile::FinalVMPtr>()) {
if (res->results().count(kOutput) == 0) {
MS_LOG(EXCEPTION) << "Execute args error"; MS_LOG(EXCEPTION) << "Execute args error";
} }


if (IsCtrlSink()) {
if (!res->results()[kOutput].is<GraphId>()) {
MS_LOG(EXCEPTION) << "Execute args error";
}

auto graph_id = res->results()[kOutput].cast<GraphId>();
auto bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::MsBackend>>();
compile::VmEvalFuncPtr run =
std::make_shared<compile::VmEvalFunc>([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef {
MS_LOG(INFO) << "Execute args size" << args.size();
auto outs = bc_ptr->RunGraph(graph_id, args);
MS_LOG(DEBUG) << "out size" << outs.size();
return outs[0];
});
res->results()[kOutput] = run;
return true;
}

if (!res->results()[kOutput].is<compile::FinalVMPtr>()) {
MS_LOG(EXCEPTION) << "Execute args error";
}
compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>(); compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>();
if (vm == nullptr) { if (vm == nullptr) {
MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM"; MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM";


+ 1
- 1
mindspore/ccsrc/session/ascend_session.cc View File

@@ -138,7 +138,7 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
return graph_id; return graph_id;
} }


GraphId AscendSession::CompileGraph(const FuncGraphPtr &func_graph) {
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start"; MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph); auto graph = ConstructKernelGraph(func_graph);
// split switch // split switch


+ 1
- 1
mindspore/ccsrc/session/ascend_session.h View File

@@ -42,7 +42,7 @@ class AscendSession : public SessionBasic {
context_ = std::make_shared<Context>(kAscendDevice, device_id); context_ = std::make_shared<Context>(kAscendDevice, device_id);
} }
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
GraphId CompileGraph(const FuncGraphPtr &func_graph) override;
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void BuildGraph(GraphId) override; void BuildGraph(GraphId) override;
void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,


+ 2
- 1
mindspore/ccsrc/session/session_basic.h View File

@@ -28,6 +28,7 @@
#include "ir/meta_tensor.h" #include "ir/meta_tensor.h"
#include "utils/any.h" #include "utils/any.h"
#include "utils/base_ref.h" #include "utils/base_ref.h"
#include "utils/contract.h"
#include "pynative/pynative_execute.h" #include "pynative/pynative_execute.h"
#include "device/kernel_info.h" #include "device/kernel_info.h"


@@ -57,7 +58,7 @@ class SessionBasic {
virtual ~SessionBasic() { summary_callback_ = nullptr; } virtual ~SessionBasic() { summary_callback_ = nullptr; }


virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
virtual GraphId CompileGraph(const FuncGraphPtr &) { return kInvalidGraphId; }
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
// build graph, used to handle multiple child graphs // build graph, used to handle multiple child graphs
virtual void BuildGraph(GraphId) {} virtual void BuildGraph(GraphId) {}




+ 4
- 0
mindspore/ccsrc/vm/backend.cc View File

@@ -327,5 +327,9 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_
sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
} }


GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return sess_->CompileGraph(fg); }

VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }

} // namespace compile } // namespace compile
} // namespace mindspore } // namespace mindspore

+ 4
- 1
mindspore/ccsrc/vm/backend.h View File

@@ -22,6 +22,7 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>


#include "utils/contract.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "vm/segment_runner.h" #include "vm/segment_runner.h"
#include "vm/vm.h" #include "vm/vm.h"
@@ -49,7 +50,7 @@ class Backend {
virtual void SetSwitchActive(const BaseRef &, bool) {} virtual void SetSwitchActive(const BaseRef &, bool) {}
virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {}
virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {} virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {}
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
void set_curr_switch(const BaseRef &value) { void set_curr_switch(const BaseRef &value) {
curr_switch_ = value; curr_switch_ = value;
is_switch_call_ = true; is_switch_call_ = true;
@@ -104,6 +105,8 @@ class MsBackend : public Backend {
void Link(GraphId) override; void Link(GraphId) override;
AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &); AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &);
LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override; LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override;
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);


private: private:
session::SessionPtr sess_; session::SessionPtr sess_;


Loading…
Cancel
Save