Merge pull request !1415 from kisnwang/support-mix-targettags/v0.5.0-beta
| @@ -35,6 +35,7 @@ class AscendDeviceAddress : public DeviceAddress { | |||||
| ~AscendDeviceAddress() override; | ~AscendDeviceAddress() override; | ||||
| bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override; | bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override; | ||||
| bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override; | bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override; | ||||
| DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; } | |||||
| #ifdef ENABLE_DUMP_E2E | #ifdef ENABLE_DUMP_E2E | ||||
| bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt, | bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt, | ||||
| const std::vector<int> &host_shape, TypeId host_type) const; | const std::vector<int> &host_shape, TypeId host_type) const; | ||||
| @@ -259,6 +259,15 @@ bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { | |||||
| if (AnfAlgo::OutputAddrExist(kernel, index)) { | |||||
| auto address = AnfAlgo::GetOutputAddr(kernel, index); | |||||
| MS_EXCEPTION_IF_NULL(address); | |||||
| return address->DeviceType() == DeviceAddressType::kAscend; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | ||||
| TypeId type_id) { | TypeId type_id) { | ||||
| return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id); | return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id); | ||||
| @@ -45,6 +45,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||||
| protected: | protected: | ||||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | ||||
| TypeId type_id) override; | TypeId type_id) override; | ||||
| bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override; | |||||
| bool SyncStream() override; | bool SyncStream() override; | ||||
| private: | private: | ||||
| @@ -34,6 +34,7 @@ class CPUDeviceAddress : public DeviceAddress { | |||||
| bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override; | bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override; | ||||
| bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override; | bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override; | ||||
| DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; } | |||||
| }; | }; | ||||
| } // namespace cpu | } // namespace cpu | ||||
| } // namespace device | } // namespace device | ||||
| @@ -48,6 +48,7 @@ class GPUMemoryManager; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice }; | enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice }; | ||||
| enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU }; | |||||
| class DeviceAddress { | class DeviceAddress { | ||||
| public: | public: | ||||
| @@ -64,6 +65,7 @@ class DeviceAddress { | |||||
| TypeId type_id() const { return type_id_; } | TypeId type_id() const { return type_id_; } | ||||
| virtual void set_status(DeviceAddressStatus status) {} | virtual void set_status(DeviceAddressStatus status) {} | ||||
| virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } | virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } | ||||
| virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } | |||||
| protected: | protected: | ||||
| const void *ptr() const { return ptr_; } | const void *ptr() const { return ptr_; } | ||||
| @@ -35,6 +35,7 @@ class GPUDeviceAddress : public DeviceAddress { | |||||
| bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override; | bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override; | ||||
| void set_status(DeviceAddressStatus status) { status_ = status; } | void set_status(DeviceAddressStatus status) { status_ = status; } | ||||
| DeviceAddressStatus status() const { return status_; } | DeviceAddressStatus status() const { return status_; } | ||||
| DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; } | |||||
| private: | private: | ||||
| DeviceAddressStatus status_{DeviceAddressStatus::kInDevice}; | DeviceAddressStatus status_{DeviceAddressStatus::kInDevice}; | ||||
| @@ -102,6 +102,13 @@ bool KernelRuntime::RunTask(const session::KernelGraph *graph) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { | |||||
| if (AnfAlgo::OutputAddrExist(kernel, index)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) { | size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { | if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { | ||||
| @@ -255,7 +262,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||||
| if (i < graph_valid_input.size() && !graph_valid_input[i]) { | if (i < graph_valid_input.size() && !graph_valid_input[i]) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (AnfAlgo::OutputAddrExist(item, 0)) { | |||||
| if (NodeOutputDeviceAddressExist(item, 0)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto output_size = AnfAlgo::GetOutputTensorNum(item); | auto output_size = AnfAlgo::GetOutputTensorNum(item); | ||||
| @@ -431,7 +438,7 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in | |||||
| if ((kGetAllOuts != index) && (SizeToInt(i) != index)) { | if ((kGetAllOuts != index) && (SizeToInt(i) != index)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (AnfAlgo::OutputAddrExist(node, i)) { | |||||
| if (NodeOutputDeviceAddressExist(node, i)) { | |||||
| MS_LOG(INFO) << "Already malloc index:" << i; | MS_LOG(INFO) << "Already malloc index:" << i; | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -493,7 +500,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | MS_EXCEPTION_IF_NULL(ms_context); | ||||
| for (auto &value_node : graph->graph_value_nodes()) { | for (auto &value_node : graph->graph_value_nodes()) { | ||||
| MS_EXCEPTION_IF_NULL(value_node); | MS_EXCEPTION_IF_NULL(value_node); | ||||
| if (AnfAlgo::OutputAddrExist(value_node, 0)) { | |||||
| if (NodeOutputDeviceAddressExist(value_node, 0)) { | |||||
| MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist"; | MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist"; | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -67,6 +67,7 @@ class KernelRuntime { | |||||
| protected: | protected: | ||||
| virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | ||||
| TypeId type_id) = 0; | TypeId type_id) = 0; | ||||
| virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index); | |||||
| virtual bool SyncStream() = 0; | virtual bool SyncStream() = 0; | ||||
| void AssignStaticMemory(session::KernelGraph *graph); | void AssignStaticMemory(session::KernelGraph *graph); | ||||
| void AssignDynamicMemory(session::KernelGraph *graph); | void AssignDynamicMemory(session::KernelGraph *graph); | ||||
| @@ -307,17 +307,27 @@ bool TaskEmitAction(const ResourcePtr &res) { | |||||
| } | } | ||||
| FuncGraphPtr func_graph = res->func_graph(); | FuncGraphPtr func_graph = res->func_graph(); | ||||
| auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); | auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); | ||||
| if (IsCtrlSink()) { | if (IsCtrlSink()) { | ||||
| res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); | res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); | ||||
| return true; | return true; | ||||
| } | } | ||||
| std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; | std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; | ||||
| if (bc_ptr->name() == kMsConvert) { | if (bc_ptr->name() == kMsConvert) { | ||||
| cut_list = compile::GetMsNonlinearOps(); | cut_list = compile::GetMsNonlinearOps(); | ||||
| } | } | ||||
| std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list); | std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list); | ||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| if (compile->ContainMixedTarget(func_graph)) { | |||||
| bc_ptr->set_is_multi_graph_sink(false); | |||||
| context_ptr->set_loop_sink_flag(false); | |||||
| } else if (context_ptr->execution_mode() != kPynativeMode) { | |||||
| std::string device_target = context_ptr->device_target(); | |||||
| if (device_target == kAscendDevice) { | |||||
| bc_ptr->set_is_multi_graph_sink(true); | |||||
| } | |||||
| } | |||||
| res->results()[kOutput] = compile->CompileAndLink(func_graph); | res->results()[kOutput] = compile->CompileAndLink(func_graph); | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -778,7 +778,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||||
| MS_EXCEPTION_IF_NULL(convert_fn); | MS_EXCEPTION_IF_NULL(convert_fn); | ||||
| // Convert CNodeList to LinConvertResult. | // Convert CNodeList to LinConvertResult. | ||||
| ConfigManager::GetInstance().set_iter_num(1); | ConfigManager::GetInstance().set_iter_num(1); | ||||
| auto runner = convert_fn({app_init}); | |||||
| auto runner = convert_fn({app_init}, ""); | |||||
| if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { | if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { | ||||
| backend->Link(runner.graph_id); | backend->Link(runner.graph_id); | ||||
| } | } | ||||
| @@ -28,6 +28,23 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(anf); | |||||
| if (!anf->isa<Parameter>()) { | |||||
| MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; | |||||
| } | |||||
| auto valid_inputs = graph->MutableValidInputs(); | |||||
| MS_EXCEPTION_IF_NULL(valid_inputs); | |||||
| auto graph_inputs = graph->MutableInputs(); | |||||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); | |||||
| ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); | |||||
| TraceManager::EndTrace(); | |||||
| graph_inputs->push_back(new_parameter); | |||||
| valid_inputs->push_back(valid_input); | |||||
| return new_parameter; | |||||
| } | |||||
| GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | ||||
| auto graph_id = graph_sum_; | auto graph_id = graph_sum_; | ||||
| auto graph = ConstructKernelGraph(lst, outputs); | auto graph = ConstructKernelGraph(lst, outputs); | ||||
| @@ -35,6 +35,9 @@ class CPUSession : public SessionBasic { | |||||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | ||||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | ||||
| protected: | |||||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override; | |||||
| private: | private: | ||||
| void SetKernelInfo(const KernelGraph *kernel_graph); | void SetKernelInfo(const KernelGraph *kernel_graph); | ||||
| void BuildKernel(const KernelGraph *kernel_graph); | void BuildKernel(const KernelGraph *kernel_graph); | ||||
| @@ -482,7 +482,13 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de | |||||
| depend_nodes = GetOutputNodes(depend_node); | depend_nodes = GetOutputNodes(depend_node); | ||||
| } | } | ||||
| for (auto &first_node : prior_nodes) { | for (auto &first_node : prior_nodes) { | ||||
| if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| for (auto &second_node : depend_nodes) { | for (auto &second_node : depend_nodes) { | ||||
| if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(first_node); | MS_EXCEPTION_IF_NULL(first_node); | ||||
| MS_EXCEPTION_IF_NULL(second_node); | MS_EXCEPTION_IF_NULL(second_node); | ||||
| MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); | MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); | ||||
| @@ -311,7 +311,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf | |||||
| if (python_paras_ == nullptr) { | if (python_paras_ == nullptr) { | ||||
| python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>(); | python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>(); | ||||
| } | } | ||||
| if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) == kInvalidGraphId) { | |||||
| if (python_paras_->find(m_tensor) != python_paras_->end()) { | |||||
| new_parameter = (*python_paras_)[m_tensor]; | new_parameter = (*python_paras_)[m_tensor]; | ||||
| } else { | } else { | ||||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); | TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); | ||||
| @@ -114,7 +114,7 @@ class SessionBasic { | |||||
| BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); | BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); | ||||
| // create a new kernel graph and update the graph sum | // create a new kernel graph and update the graph sum | ||||
| KernelGraphPtr NewKernelGraph(); | KernelGraphPtr NewKernelGraph(); | ||||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | |||||
| virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | |||||
| ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); | ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); | ||||
| ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); | ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); | ||||
| AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | ||||
| @@ -92,7 +92,7 @@ class MsContext { | |||||
| bool ir_fusion_flag() const { return ir_fusion_flag_; } | bool ir_fusion_flag() const { return ir_fusion_flag_; } | ||||
| bool loop_sink_flag() const { return enable_loop_sink_; } | bool loop_sink_flag() const { return enable_loop_sink_; } | ||||
| void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; } | |||||
| void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; } | void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; } | ||||
| bool enable_mem_reuse() const { return enable_mem_reuse_; } | bool enable_mem_reuse() const { return enable_mem_reuse_; } | ||||
| @@ -39,14 +39,14 @@ LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) { | |||||
| multi_result_.inputs = g->parameters(); | multi_result_.inputs = g->parameters(); | ||||
| final_output_ = NewValueNode("fake_output"); | final_output_ = NewValueNode("fake_output"); | ||||
| multi_result_.outputs = {final_output_}; | multi_result_.outputs = {final_output_}; | ||||
| GraphId final_g = sess_->GetFinalRunGraph(); | |||||
| GraphId final_g = target_sess_->GetFinalRunGraph(); | |||||
| multi_result_.run = std::make_shared<RunFunc>( | multi_result_.run = std::make_shared<RunFunc>( | ||||
| [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args); }); | |||||
| [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); }); | |||||
| return multi_result_; | return multi_result_; | ||||
| } | } | ||||
| LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) { | |||||
| LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) { | |||||
| MS_LOG(DEBUG) << "MsConvert"; | MS_LOG(DEBUG) << "MsConvert"; | ||||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | ||||
| auto cached = g_ConvertCache.find(lst); | auto cached = g_ConvertCache.find(lst); | ||||
| @@ -64,17 +64,24 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) { | |||||
| result.inputs = inputs; | result.inputs = inputs; | ||||
| result.outputs = outputs; | result.outputs = outputs; | ||||
| result.graph_id = kInvalidGraphId; | result.graph_id = kInvalidGraphId; | ||||
| auto graph_id = sess_->CompileGraph(lst, outputs); | |||||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { | |||||
| sess_->BuildGraph(graph_id); | |||||
| GraphId graph_id = kInvalidGraphId; | |||||
| if (target == kCPUDevice) { | |||||
| graph_id = cpu_sess_->CompileGraph(lst, outputs); | |||||
| } else { | |||||
| graph_id = target_sess_->CompileGraph(lst, outputs); | |||||
| } | } | ||||
| if (MsContext::GetInstance()->precompile_only()) { | if (MsContext::GetInstance()->precompile_only()) { | ||||
| MS_LOG(INFO) << "PrecompileOnly, stop run graph"; | MS_LOG(INFO) << "PrecompileOnly, stop run graph"; | ||||
| return result; | return result; | ||||
| } | } | ||||
| if (target == kCPUDevice) { | |||||
| cpu_sess_->BuildGraph(graph_id); | |||||
| } else if (!is_multi_graph_sink_) { | |||||
| target_sess_->BuildGraph(graph_id); | |||||
| } | |||||
| result.run = std::make_shared<RunFunc>( | result.run = std::make_shared<RunFunc>( | ||||
| [graph_id, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args); }); | |||||
| [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); }); | |||||
| MS_EXCEPTION_IF_NULL(result.run); | MS_EXCEPTION_IF_NULL(result.run); | ||||
| result.simu_run = std::make_shared<RunFunc>( | result.simu_run = std::make_shared<RunFunc>( | ||||
| @@ -92,7 +99,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) { | |||||
| GraphId cond_g = kInvalidGraphId; | GraphId cond_g = kInvalidGraphId; | ||||
| if (utils::isa<AnfNodePtr>(c)) { | if (utils::isa<AnfNodePtr>(c)) { | ||||
| cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c)); | |||||
| cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c)); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString(); | MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString(); | ||||
| } | } | ||||
| @@ -116,7 +123,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) { | |||||
| MS_LOG(DEBUG) << "invoke set active:" << active_g; | MS_LOG(DEBUG) << "invoke set active:" << active_g; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g; | MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g; | ||||
| sess_->SetActive(active_g, cond_g); | |||||
| target_sess_->SetActive(active_g, cond_g); | |||||
| } | } | ||||
| void MsBackend::SetSwitchGraph() { | void MsBackend::SetSwitchGraph() { | ||||
| @@ -135,12 +142,12 @@ void MsBackend::SetSwitchGraph() { | |||||
| } | } | ||||
| GraphId cond_g = kInvalidGraphId; | GraphId cond_g = kInvalidGraphId; | ||||
| if (utils::isa<AnfNodePtr>(curr_switch_)) { | if (utils::isa<AnfNodePtr>(curr_switch_)) { | ||||
| cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_)); | |||||
| cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_)); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString(); | MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString(); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g; | MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g; | ||||
| sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_)); | |||||
| target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_)); | |||||
| } | } | ||||
| is_switch_call_ = false; | is_switch_call_ = false; | ||||
| MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_; | MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_; | ||||
| @@ -202,7 +209,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef | |||||
| old_args[i] = args[it->second]; | old_args[i] = args[it->second]; | ||||
| } | } | ||||
| } | } | ||||
| sess_->SetChildGraphInput(graph, old_args); | |||||
| target_sess_->SetChildGraphInput(graph, old_args); | |||||
| } | } | ||||
| graph_inputs_.erase(c); | graph_inputs_.erase(c); | ||||
| } | } | ||||
| @@ -211,7 +218,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef | |||||
| VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { | VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { | ||||
| MS_LOG(DEBUG) << "set graph input:" << g; | MS_LOG(DEBUG) << "set graph input:" << g; | ||||
| // switch maybe twice | // switch maybe twice | ||||
| sess_->SetChildGraphInput(g, args); | |||||
| target_sess_->SetChildGraphInput(g, args); | |||||
| if (is_switch_call_) { | if (is_switch_call_) { | ||||
| if (!curr_switch_.is_null()) { | if (!curr_switch_.is_null()) { | ||||
| @@ -236,7 +243,7 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { | |||||
| return VectorRef(outputs); | return VectorRef(outputs); | ||||
| } | } | ||||
| VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) { | |||||
| VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) { | |||||
| MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g; | MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g; | ||||
| // Run graph | // Run graph | ||||
| std::vector<tensor::TensorPtr> inputs; | std::vector<tensor::TensorPtr> inputs; | ||||
| @@ -271,7 +278,12 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) { | |||||
| VectorRef outputs; | VectorRef outputs; | ||||
| // call ms rungraph (graphId, input ,output) | // call ms rungraph (graphId, input ,output) | ||||
| sess_->RunGraph(g, inputs, &outputs); | |||||
| if (target == kCPUDevice) { | |||||
| cpu_sess_->RunGraph(g, inputs, &outputs); | |||||
| } else { | |||||
| target_sess_->RunGraph(g, inputs, &outputs); | |||||
| } | |||||
| MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size(); | MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size(); | ||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| @@ -300,17 +312,17 @@ void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) { | |||||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args), | (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args), | ||||
| [](const AnfNodePtr &v) { return v; }); | [](const AnfNodePtr &v) { return v; }); | ||||
| MS_LOG(DEBUG) << "Simulate start"; | MS_LOG(DEBUG) << "Simulate start"; | ||||
| (void)sess_->SetFinalGraphInput(parameters); | |||||
| (void)target_sess_->SetFinalGraphInput(parameters); | |||||
| BaseRef output = rt->Eval(VectorRef(args)); | BaseRef output = rt->Eval(VectorRef(args)); | ||||
| sess_->SetFinalGraphOutput(output); | |||||
| target_sess_->SetFinalGraphOutput(output); | |||||
| MS_LOG(DEBUG) << "Simulate Eval end"; | MS_LOG(DEBUG) << "Simulate Eval end"; | ||||
| } | } | ||||
| void MsBackend::Link(GraphId graph_id) { | void MsBackend::Link(GraphId graph_id) { | ||||
| if (graph_id == kInvalidGraphId) { | if (graph_id == kInvalidGraphId) { | ||||
| graph_id = sess_->GetFinalRunGraph(); | |||||
| graph_id = target_sess_->GetFinalRunGraph(); | |||||
| } | } | ||||
| sess_->BuildGraph(graph_id); | |||||
| target_sess_->BuildGraph(graph_id); | |||||
| } | } | ||||
| Backend::Backend(const std::string &name) : name_(name) { | Backend::Backend(const std::string &name) : name_(name) { | ||||
| @@ -322,16 +334,26 @@ Backend::Backend(const std::string &name) : name_(name) { | |||||
| } | } | ||||
| MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { | MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { | ||||
| convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1); | |||||
| sess_ = session::SessionFactory::Get().Create(target); | |||||
| if (sess_ == nullptr) { | |||||
| convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2); | |||||
| target_sess_ = session::SessionFactory::Get().Create(target); | |||||
| if (target_sess_ == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available."; | MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available."; | ||||
| } | } | ||||
| sess_->Init(device_id); | |||||
| sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); | |||||
| target_sess_->Init(device_id); | |||||
| target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); | |||||
| if (target == kCPUDevice) { | |||||
| cpu_sess_ = target_sess_; | |||||
| } else { | |||||
| cpu_sess_ = session::SessionFactory::Get().Create(kCPUDevice); | |||||
| if (cpu_sess_ == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Create cpu session failed with target " << target << "."; | |||||
| } | |||||
| cpu_sess_->Init(0); | |||||
| cpu_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); | |||||
| } | |||||
| } | } | ||||
| GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return sess_->CompileGraph(fg); } | |||||
| GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); } | |||||
| VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); } | VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); } | ||||
| @@ -91,8 +91,8 @@ class MsBackend : public Backend { | |||||
| MsBackend(const std::string &name, const std::string &target, uint32_t device_id); | MsBackend(const std::string &name, const std::string &target, uint32_t device_id); | ||||
| ~MsBackend() override = default; | ~MsBackend() override = default; | ||||
| LinConvertResult MsConvert(const AnfNodePtrList &lst); | |||||
| VectorRef MsRunGraph(const GraphId &g, const VectorRef &args); | |||||
| LinConvertResult MsConvert(const AnfNodePtrList &lst, const std::string &target = ""); | |||||
| VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); | |||||
| VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args); | VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args); | ||||
| void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override; | void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override; | ||||
| @@ -109,7 +109,8 @@ class MsBackend : public Backend { | |||||
| VectorRef RunGraph(GraphId graph_id, const VectorRef &args); | VectorRef RunGraph(GraphId graph_id, const VectorRef &args); | ||||
| private: | private: | ||||
| session::SessionPtr sess_; | |||||
| session::SessionPtr target_sess_; | |||||
| session::SessionPtr cpu_sess_; | |||||
| std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_; | std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_; | ||||
| std::unordered_map<GraphId, LinConvertResult> graph_id_map_; | std::unordered_map<GraphId, LinConvertResult> graph_id_map_; | ||||
| std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_; | std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_; | ||||
| @@ -148,7 +148,7 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||||
| // This implementation will convert the nodes into a subgraph | // This implementation will convert the nodes into a subgraph | ||||
| // that will run using the MsVM. | // that will run using the MsVM. | ||||
| template <typename T> | template <typename T> | ||||
| LinConvertResult Convert(const AnfNodePtrList &lst) { | |||||
| LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { | |||||
| auto cached = g_ConvertCache.find(lst); | auto cached = g_ConvertCache.find(lst); | ||||
| if (cached != g_ConvertCache.end()) { | if (cached != g_ConvertCache.end()) { | ||||
| return cached->second; | return cached->second; | ||||
| @@ -43,7 +43,7 @@ struct LinConvertResult { | |||||
| uint32_t graph_id; | uint32_t graph_id; | ||||
| }; | }; | ||||
| using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &)>; | |||||
| using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &, const std::string &)>; | |||||
| using ConvertCache = std::unordered_map<BaseRef, LinConvertResult, BaseRefHash>; | using ConvertCache = std::unordered_map<BaseRef, LinConvertResult, BaseRefHash>; | ||||
| extern LinkFuncType MsVmConvert; | extern LinkFuncType MsVmConvert; | ||||
| extern LinkFuncType GeVmConvert; | extern LinkFuncType GeVmConvert; | ||||
| @@ -20,6 +20,8 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <map> | #include <map> | ||||
| #include <queue> | |||||
| #include <set> | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -47,6 +49,86 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() { | |||||
| return ms_nonlinear_ops; | return ms_nonlinear_ops; | ||||
| } | } | ||||
| namespace { | |||||
| std::string GetCNodeTarget(const AnfNodePtr &node) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string default_target = context_ptr->device_target(); | |||||
| if (!node->isa<CNode>()) { | |||||
| return default_target; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto attr_input = cnode->input(kAnfPrimitiveIndex); | |||||
| if (attr_input == nullptr) { | |||||
| return default_target; | |||||
| } | |||||
| auto value_node = attr_input->cast<ValueNodePtr>(); | |||||
| if (value_node == nullptr) { | |||||
| return default_target; | |||||
| } | |||||
| auto value = value_node->value(); | |||||
| if (value == nullptr) { | |||||
| return default_target; | |||||
| } | |||||
| if (!value->isa<Primitive>()) { | |||||
| return default_target; | |||||
| } | |||||
| auto primitive = value->cast<PrimitivePtr>(); | |||||
| ValuePtr att_target = primitive->GetAttr("target"); | |||||
| if (att_target != nullptr) { | |||||
| std::string target = GetValue<std::string>(att_target); | |||||
| return target; | |||||
| } | |||||
| return default_target; | |||||
| } | |||||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string last_target = context_ptr->device_target(); | |||||
| for (auto &node : nodes) { | |||||
| if (node->isa<CNode>()) { | |||||
| std::string cur_target = GetCNodeTarget(node); | |||||
| if (last_target != cur_target) { | |||||
| return true; | |||||
| } | |||||
| last_target = cur_target; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) { | |||||
| std::queue<AnfNodePtr> queue; | |||||
| queue.push(graph->get_return()); | |||||
| std::set<AnfNodePtr> visited; | |||||
| while (!queue.empty()) { | |||||
| auto &node = queue.front(); | |||||
| queue.pop(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| for (auto &input : cnode->inputs()) { | |||||
| auto iter = nodes_ref->find(input); | |||||
| if (iter != nodes_ref->end()) { | |||||
| iter->second++; | |||||
| } else { | |||||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1)); | |||||
| } | |||||
| if (visited.find(input) != visited.end()) { | |||||
| continue; | |||||
| } | |||||
| visited.insert(input); | |||||
| queue.push(input); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) | CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) | ||||
| : backend_(backend), cut_list_(cut_list) { | : backend_(backend), cut_list_(cut_list) { | ||||
| MS_EXCEPTION_IF_NULL(backend_); | MS_EXCEPTION_IF_NULL(backend_); | ||||
| @@ -98,12 +180,67 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> CompileGraph::SplitSort(const FuncGraphPtr &graph) { | |||||
| std::vector<AnfNodePtr> result; | |||||
| std::queue<AnfNodePtr> queue; | |||||
| std::queue<AnfNodePtr> next_queue; | |||||
| std::map<AnfNodePtr, size_t> nodes_ref; | |||||
| CalcNodeRefCount(graph, &nodes_ref); | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string queue_target = context_ptr->device_target(); | |||||
| std::string next_target = ""; | |||||
| queue.push(graph->get_return()); | |||||
| while (!queue.empty() || !next_queue.empty()) { | |||||
| if (queue.empty()) { | |||||
| queue.swap(next_queue); | |||||
| queue_target = next_target; | |||||
| } | |||||
| auto &node = queue.front(); | |||||
| queue.pop(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| result.emplace_back(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| for (auto &input : cnode->inputs()) { | |||||
| auto iter = nodes_ref.find(input); | |||||
| if (iter != nodes_ref.end()) { | |||||
| iter->second--; | |||||
| if (iter->second != 0) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| if (!input->isa<CNode>()) { | |||||
| queue.push(input); | |||||
| continue; | |||||
| } | |||||
| std::string input_target = GetCNodeTarget(input); | |||||
| if (input_target == queue_target) { | |||||
| queue.push(input); | |||||
| } else if (next_queue.empty() || input_target == next_target) { | |||||
| next_queue.push(input); | |||||
| next_target = input_target; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "only support two different target"; | |||||
| } | |||||
| } | |||||
| } | |||||
| std::reverse(result.begin(), result.end()); | |||||
| return result; | |||||
| } | |||||
| VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| VectorRef splits; | VectorRef splits; | ||||
| VectorRef split; | VectorRef split; | ||||
| std::vector<AnfNodePtr> nodes = TopoSort(graph->get_return()); | |||||
| auto nodes = TopoSort(graph->get_return()); | |||||
| if (ContainMultiTarget(nodes)) { | |||||
| nodes = SplitSort(graph); | |||||
| } | |||||
| std::string last_target; | |||||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | ||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| @@ -114,7 +251,13 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||||
| } | } | ||||
| splits.push_back(node); | splits.push_back(node); | ||||
| split.clear(); | split.clear(); | ||||
| } else if (!(node->isa<ValueNode>() || node->isa<Parameter>())) { | |||||
| } else if (node->isa<CNode>()) { | |||||
| std::string cur_target = GetCNodeTarget(node); | |||||
| if (cur_target != last_target && !last_target.empty() && split.size() != 0) { | |||||
| splits.push_back(split); | |||||
| split.clear(); | |||||
| } | |||||
| last_target = cur_target; | |||||
| split.push_back(node); | split.push_back(node); | ||||
| MS_LOG(DEBUG) << "Insert node:" << node->DebugString(10) << ", size:" << split.size(); | MS_LOG(DEBUG) << "Insert node:" << node->DebugString(10) << ", size:" << split.size(); | ||||
| } | } | ||||
| @@ -200,14 +343,14 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) { | |||||
| } | } | ||||
| } | } | ||||
| int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list) { | |||||
| int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, const std::string &target) { | |||||
| MS_LOG(DEBUG) << "LinConvert start"; | MS_LOG(DEBUG) << "LinConvert start"; | ||||
| LinConvertResult result; | LinConvertResult result; | ||||
| if (backend_->simu_flag()) { | if (backend_->simu_flag()) { | ||||
| result = backend_->GetMultiGraphRun(graph); | result = backend_->GetMultiGraphRun(graph); | ||||
| } else { | } else { | ||||
| result = lin_convert_(node_list); | |||||
| result = lin_convert_(node_list, target); | |||||
| } | } | ||||
| if (result.run == nullptr) { | if (result.run == nullptr) { | ||||
| @@ -316,7 +459,12 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { | |||||
| auto vec_ref = utils::cast<VectorRef>(split); | auto vec_ref = utils::cast<VectorRef>(split); | ||||
| (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args), | (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args), | ||||
| [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); }); | [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); }); | ||||
| ret = LinConvert(graph, args); | |||||
| if (args.size() > 0) { | |||||
| std::string cur_target = GetCNodeTarget(args[0]); | |||||
| ret = LinConvert(graph, args, cur_target); | |||||
| } else { | |||||
| ret = LinConvert(graph, args); | |||||
| } | |||||
| MS_LOG(DEBUG) << "End a extern LinConvert"; | MS_LOG(DEBUG) << "End a extern LinConvert"; | ||||
| if (ret == RET_FAILED) { | if (ret == RET_FAILED) { | ||||
| return false; | return false; | ||||
| @@ -637,6 +785,19 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { | |||||
| return rt; | return rt; | ||||
| } | } | ||||
| bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) { | |||||
| auto graph_manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(graph_manager); | |||||
| FuncGraphSet graphs = graph_manager->func_graphs(); | |||||
| for (auto &g : graphs) { | |||||
| auto nodes = TopoSort(g->get_return()); | |||||
| if (ContainMultiTarget(nodes)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| BackendPtr CreateBackend() { | BackendPtr CreateBackend() { | ||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| @@ -79,8 +79,9 @@ class CompileGraph { | |||||
| private: | private: | ||||
| void PushParameters(const FuncGraphPtr &func_graph); | void PushParameters(const FuncGraphPtr &func_graph); | ||||
| std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph); | |||||
| bool SplitGraph(const FuncGraphPtr &func_graph); | bool SplitGraph(const FuncGraphPtr &func_graph); | ||||
| int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list); | |||||
| int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); | |||||
| int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); | int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); | ||||
| int AddCall(const FuncGraphPtr &graph, const CNodePtr &node); | int AddCall(const FuncGraphPtr &graph, const CNodePtr &node); | ||||
| void AddSinkSwitch(const CNodePtr &node); | void AddSinkSwitch(const CNodePtr &node); | ||||
| @@ -124,6 +125,7 @@ class CompileGraphs { | |||||
| void Compile(const FuncGraphPtr &func_graph); | void Compile(const FuncGraphPtr &func_graph); | ||||
| FinalVMPtr Link(const FuncGraphPtr &func_graph); | FinalVMPtr Link(const FuncGraphPtr &func_graph); | ||||
| FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); | FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); | ||||
| bool ContainMixedTarget(const FuncGraphPtr &graph); | |||||
| private: | private: | ||||
| InstSet insts_; | InstSet insts_; | ||||
| @@ -65,7 +65,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) { | |||||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | for (auto &item : utils::cast<VectorRef>(todos[0])) { | ||||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | anf_list.push_back(utils::cast<AnfNodePtr>(item)); | ||||
| } | } | ||||
| auto convertResult = MsVmConvert(anf_list); | |||||
| auto convertResult = MsVmConvert(anf_list, ""); | |||||
| auto runResult = (*(convertResult.run))(args); | auto runResult = (*(convertResult.run))(args); | ||||
| ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0); | ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0); | ||||
| } | } | ||||
| @@ -89,7 +89,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) { | |||||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | for (auto &item : utils::cast<VectorRef>(todos[0])) { | ||||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | anf_list.push_back(utils::cast<AnfNodePtr>(item)); | ||||
| } | } | ||||
| auto convertResult = MsVmConvert(anf_list); | |||||
| auto convertResult = MsVmConvert(anf_list, ""); | |||||
| auto runResult = (*(convertResult.run))(args); | auto runResult = (*(convertResult.run))(args); | ||||
| ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0); | ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0); | ||||
| } | } | ||||
| @@ -113,7 +113,7 @@ TEST_F(TestCompileSegmentRunner, test_if) { | |||||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | for (auto &item : utils::cast<VectorRef>(todos[0])) { | ||||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | anf_list.push_back(utils::cast<AnfNodePtr>(item)); | ||||
| } | } | ||||
| auto convertResult = MsVmConvert(anf_list); | |||||
| auto convertResult = MsVmConvert(anf_list, ""); | |||||
| auto runResult = (*(convertResult.run))(args); | auto runResult = (*(convertResult.run))(args); | ||||
| auto result = py::cast<bool>(BaseRefToPyData(runResult[0])); | auto result = py::cast<bool>(BaseRefToPyData(runResult[0])); | ||||