Signed-off-by: zhoufeng <zhoufeng54@huawei.com>tags/v0.3.0-alpha
| @@ -78,6 +78,10 @@ const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed"); | |||||
| const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); | const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); | ||||
| const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | ||||
| const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto"); | |||||
| const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch"); | |||||
| const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet"); | |||||
| // Structure | // Structure | ||||
| const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); | const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); | ||||
| const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); | const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); | ||||
| @@ -84,6 +84,10 @@ extern const PrimitivePtr kPrimEmbed; | |||||
| extern const PrimitivePtr kPrimRefToEmbed; | extern const PrimitivePtr kPrimRefToEmbed; | ||||
| extern const PrimitivePtr kPrimCreateInstance; | extern const PrimitivePtr kPrimCreateInstance; | ||||
| extern const PrimitivePtr kPrimLabelGoto; | |||||
| extern const PrimitivePtr kPrimLabelSwitch; | |||||
| extern const PrimitivePtr kPrimLabelSet; | |||||
| // Structure | // Structure | ||||
| extern const PrimitivePtr kPrimStringEqual; | extern const PrimitivePtr kPrimStringEqual; | ||||
| extern const PrimitivePtr kPrimStringConcat; | extern const PrimitivePtr kPrimStringConcat; | ||||
| @@ -269,13 +269,41 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa | |||||
| bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } | bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } | ||||
| static bool IsCtrlSink() { | |||||
| auto ms_ctx = MsContext::GetInstance(); | |||||
| std::string device_target = ms_ctx->device_target(); | |||||
| if (device_target != kAscendDevice) { | |||||
| return false; | |||||
| } | |||||
| if (!ms_ctx->enable_task_sink()) { | |||||
| return false; | |||||
| } | |||||
| char *enable_ctrl_sink = std::getenv("ENABLE_CTRL_SINK"); | |||||
| if (enable_ctrl_sink == nullptr) { | |||||
| return false; | |||||
| } | |||||
| std::string enable_ctrl_sink_str(enable_ctrl_sink); | |||||
| if (enable_ctrl_sink_str == "0") { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool TaskEmitAction(const ResourcePtr &res) { | bool TaskEmitAction(const ResourcePtr &res) { | ||||
| if (res->func_graph() == nullptr) { | if (res->func_graph() == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "TaskEmit args error"; | MS_LOG(EXCEPTION) << "TaskEmit args error"; | ||||
| } | } | ||||
| FuncGraphPtr func_graph = res->func_graph(); | FuncGraphPtr func_graph = res->func_graph(); | ||||
| auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); | auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); | ||||
| if (IsCtrlSink()) { | |||||
| res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); | |||||
| return true; | |||||
| } | |||||
| std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; | std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; | ||||
| if (bc_ptr->name() == kMsConvert) { | if (bc_ptr->name() == kMsConvert) { | ||||
| cut_list = compile::GetMsNonlinearOps(); | cut_list = compile::GetMsNonlinearOps(); | ||||
| @@ -286,10 +314,31 @@ bool TaskEmitAction(const ResourcePtr &res) { | |||||
| } | } | ||||
| bool ExecuteAction(const ResourcePtr &res) { | bool ExecuteAction(const ResourcePtr &res) { | ||||
| if (res->results().count(kOutput) == 0 || !res->results()[kOutput].is<compile::FinalVMPtr>()) { | |||||
| if (res->results().count(kOutput) == 0) { | |||||
| MS_LOG(EXCEPTION) << "Execute args error"; | MS_LOG(EXCEPTION) << "Execute args error"; | ||||
| } | } | ||||
| if (IsCtrlSink()) { | |||||
| if (!res->results()[kOutput].is<GraphId>()) { | |||||
| MS_LOG(EXCEPTION) << "Execute args error"; | |||||
| } | |||||
| auto graph_id = res->results()[kOutput].cast<GraphId>(); | |||||
| auto bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::MsBackend>>(); | |||||
| compile::VmEvalFuncPtr run = | |||||
| std::make_shared<compile::VmEvalFunc>([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef { | |||||
| MS_LOG(INFO) << "Execute args size" << args.size(); | |||||
| auto outs = bc_ptr->RunGraph(graph_id, args); | |||||
| MS_LOG(DEBUG) << "out size" << outs.size(); | |||||
| return outs[0]; | |||||
| }); | |||||
| res->results()[kOutput] = run; | |||||
| return true; | |||||
| } | |||||
| if (!res->results()[kOutput].is<compile::FinalVMPtr>()) { | |||||
| MS_LOG(EXCEPTION) << "Execute args error"; | |||||
| } | |||||
| compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>(); | compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>(); | ||||
| if (vm == nullptr) { | if (vm == nullptr) { | ||||
| MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM"; | MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM"; | ||||
| @@ -138,7 +138,7 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL | |||||
| return graph_id; | return graph_id; | ||||
| } | } | ||||
| GraphId AscendSession::CompileGraph(const FuncGraphPtr &func_graph) { | |||||
| GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| MS_LOG(INFO) << "start"; | MS_LOG(INFO) << "start"; | ||||
| auto graph = ConstructKernelGraph(func_graph); | auto graph = ConstructKernelGraph(func_graph); | ||||
| // split switch | // split switch | ||||
| @@ -42,7 +42,7 @@ class AscendSession : public SessionBasic { | |||||
| context_ = std::make_shared<Context>(kAscendDevice, device_id); | context_ = std::make_shared<Context>(kAscendDevice, device_id); | ||||
| } | } | ||||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | ||||
| GraphId CompileGraph(const FuncGraphPtr &func_graph) override; | |||||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override; | |||||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | ||||
| void BuildGraph(GraphId) override; | void BuildGraph(GraphId) override; | ||||
| void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "ir/meta_tensor.h" | #include "ir/meta_tensor.h" | ||||
| #include "utils/any.h" | #include "utils/any.h" | ||||
| #include "utils/base_ref.h" | #include "utils/base_ref.h" | ||||
| #include "utils/contract.h" | |||||
| #include "pynative/pynative_execute.h" | #include "pynative/pynative_execute.h" | ||||
| #include "device/kernel_info.h" | #include "device/kernel_info.h" | ||||
| @@ -57,7 +58,7 @@ class SessionBasic { | |||||
| virtual ~SessionBasic() { summary_callback_ = nullptr; } | virtual ~SessionBasic() { summary_callback_ = nullptr; } | ||||
| virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; | virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; | ||||
| virtual GraphId CompileGraph(const FuncGraphPtr &) { return kInvalidGraphId; } | |||||
| virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; } | |||||
| // build graph, used to handle multiple child graphs | // build graph, used to handle multiple child graphs | ||||
| virtual void BuildGraph(GraphId) {} | virtual void BuildGraph(GraphId) {} | ||||
| @@ -327,5 +327,9 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_ | |||||
| sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); | sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); | ||||
| } | } | ||||
| GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return sess_->CompileGraph(fg); } | |||||
| VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); } | |||||
| } // namespace compile | } // namespace compile | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include "utils/contract.h" | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "vm/segment_runner.h" | #include "vm/segment_runner.h" | ||||
| #include "vm/vm.h" | #include "vm/vm.h" | ||||
| @@ -49,7 +50,7 @@ class Backend { | |||||
| virtual void SetSwitchActive(const BaseRef &, bool) {} | virtual void SetSwitchActive(const BaseRef &, bool) {} | ||||
| virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} | virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} | ||||
| virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {} | virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {} | ||||
| virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; } | |||||
| void set_curr_switch(const BaseRef &value) { | void set_curr_switch(const BaseRef &value) { | ||||
| curr_switch_ = value; | curr_switch_ = value; | ||||
| is_switch_call_ = true; | is_switch_call_ = true; | ||||
| @@ -104,6 +105,8 @@ class MsBackend : public Backend { | |||||
| void Link(GraphId) override; | void Link(GraphId) override; | ||||
| AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &); | AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &); | ||||
| LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override; | LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override; | ||||
| GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override; | |||||
| VectorRef RunGraph(GraphId graph_id, const VectorRef &args); | |||||
| private: | private: | ||||
| session::SessionPtr sess_; | session::SessionPtr sess_; | ||||