Browse Source

new control sink entry

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v0.3.0-alpha
zhoufeng 5 years ago
parent
commit
b78e54a5c4
8 changed files with 71 additions and 6 deletions
  1. +4
    -0
      mindspore/ccsrc/operator/ops.cc
  2. +4
    -0
      mindspore/ccsrc/operator/ops.h
  3. +51
    -2
      mindspore/ccsrc/pipeline/action.cc
  4. +1
    -1
      mindspore/ccsrc/session/ascend_session.cc
  5. +1
    -1
      mindspore/ccsrc/session/ascend_session.h
  6. +2
    -1
      mindspore/ccsrc/session/session_basic.h
  7. +4
    -0
      mindspore/ccsrc/vm/backend.cc
  8. +4
    -1
      mindspore/ccsrc/vm/backend.h

+ 4
- 0
mindspore/ccsrc/operator/ops.cc View File

@@ -78,6 +78,10 @@ const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");

const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto");
const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");

// Structure
const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");


+ 4
- 0
mindspore/ccsrc/operator/ops.h View File

@@ -84,6 +84,10 @@ extern const PrimitivePtr kPrimEmbed;
extern const PrimitivePtr kPrimRefToEmbed;
extern const PrimitivePtr kPrimCreateInstance;

extern const PrimitivePtr kPrimLabelGoto;
extern const PrimitivePtr kPrimLabelSwitch;
extern const PrimitivePtr kPrimLabelSet;

// Structure
extern const PrimitivePtr kPrimStringEqual;
extern const PrimitivePtr kPrimStringConcat;


+ 51
- 2
mindspore/ccsrc/pipeline/action.cc View File

@@ -269,13 +269,41 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa

bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }

static bool IsCtrlSink() {
auto ms_ctx = MsContext::GetInstance();
std::string device_target = ms_ctx->device_target();
if (device_target != kAscendDevice) {
return false;
}

if (!ms_ctx->enable_task_sink()) {
return false;
}

char *enable_ctrl_sink = std::getenv("ENABLE_CTRL_SINK");
if (enable_ctrl_sink == nullptr) {
return false;
}
std::string enable_ctrl_sink_str(enable_ctrl_sink);
if (enable_ctrl_sink_str == "0") {
return false;
}

return true;
}

bool TaskEmitAction(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "TaskEmit args error";
}
FuncGraphPtr func_graph = res->func_graph();

auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();

if (IsCtrlSink()) {
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
return true;
}

std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
if (bc_ptr->name() == kMsConvert) {
cut_list = compile::GetMsNonlinearOps();
@@ -286,10 +314,31 @@ bool TaskEmitAction(const ResourcePtr &res) {
}

bool ExecuteAction(const ResourcePtr &res) {
if (res->results().count(kOutput) == 0 || !res->results()[kOutput].is<compile::FinalVMPtr>()) {
if (res->results().count(kOutput) == 0) {
MS_LOG(EXCEPTION) << "Execute args error";
}

if (IsCtrlSink()) {
if (!res->results()[kOutput].is<GraphId>()) {
MS_LOG(EXCEPTION) << "Execute args error";
}

auto graph_id = res->results()[kOutput].cast<GraphId>();
auto bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::MsBackend>>();
compile::VmEvalFuncPtr run =
std::make_shared<compile::VmEvalFunc>([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef {
MS_LOG(INFO) << "Execute args size" << args.size();
auto outs = bc_ptr->RunGraph(graph_id, args);
MS_LOG(DEBUG) << "out size" << outs.size();
return outs[0];
});
res->results()[kOutput] = run;
return true;
}

if (!res->results()[kOutput].is<compile::FinalVMPtr>()) {
MS_LOG(EXCEPTION) << "Execute args error";
}
compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>();
if (vm == nullptr) {
MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM";


+ 1
- 1
mindspore/ccsrc/session/ascend_session.cc View File

@@ -138,7 +138,7 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
return graph_id;
}

GraphId AscendSession::CompileGraph(const FuncGraphPtr &func_graph) {
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph);
// split switch


+ 1
- 1
mindspore/ccsrc/session/ascend_session.h View File

@@ -42,7 +42,7 @@ class AscendSession : public SessionBasic {
context_ = std::make_shared<Context>(kAscendDevice, device_id);
}
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
GraphId CompileGraph(const FuncGraphPtr &func_graph) override;
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void BuildGraph(GraphId) override;
void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,


+ 2
- 1
mindspore/ccsrc/session/session_basic.h View File

@@ -28,6 +28,7 @@
#include "ir/meta_tensor.h"
#include "utils/any.h"
#include "utils/base_ref.h"
#include "utils/contract.h"
#include "pynative/pynative_execute.h"
#include "device/kernel_info.h"

@@ -57,7 +58,7 @@ class SessionBasic {
virtual ~SessionBasic() { summary_callback_ = nullptr; }

virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
virtual GraphId CompileGraph(const FuncGraphPtr &) { return kInvalidGraphId; }
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
// build graph, used to handle multiple child graphs
virtual void BuildGraph(GraphId) {}



+ 4
- 0
mindspore/ccsrc/vm/backend.cc View File

@@ -327,5 +327,9 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_
sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
}

GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return sess_->CompileGraph(fg); }

VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }

} // namespace compile
} // namespace mindspore

+ 4
- 1
mindspore/ccsrc/vm/backend.h View File

@@ -22,6 +22,7 @@
#include <unordered_map>
#include <utility>

#include "utils/contract.h"
#include "ir/anf.h"
#include "vm/segment_runner.h"
#include "vm/vm.h"
@@ -49,7 +50,7 @@ class Backend {
virtual void SetSwitchActive(const BaseRef &, bool) {}
virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {}
virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {}
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
void set_curr_switch(const BaseRef &value) {
curr_switch_ = value;
is_switch_call_ = true;
@@ -104,6 +105,8 @@ class MsBackend : public Backend {
void Link(GraphId) override;
AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &);
LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override;
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);

private:
session::SessionPtr sess_;


Loading…
Cancel
Save