Merge pull request !1415 from kisnwang/support-mix-targettags/v0.5.0-beta
| @@ -35,6 +35,7 @@ class AscendDeviceAddress : public DeviceAddress { | |||
| ~AscendDeviceAddress() override; | |||
| bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override; | |||
| bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override; | |||
| DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; } | |||
| #ifdef ENABLE_DUMP_E2E | |||
| bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt, | |||
| const std::vector<int> &host_shape, TypeId host_type) const; | |||
| @@ -259,6 +259,15 @@ bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { | |||
| return true; | |||
| } | |||
| bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { | |||
| if (AnfAlgo::OutputAddrExist(kernel, index)) { | |||
| auto address = AnfAlgo::GetOutputAddr(kernel, index); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| return address->DeviceType() == DeviceAddressType::kAscend; | |||
| } | |||
| return false; | |||
| } | |||
| DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| TypeId type_id) { | |||
| return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id); | |||
| @@ -45,6 +45,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||
| protected: | |||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| TypeId type_id) override; | |||
| bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override; | |||
| bool SyncStream() override; | |||
| private: | |||
| @@ -34,6 +34,7 @@ class CPUDeviceAddress : public DeviceAddress { | |||
| bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override; | |||
| bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override; | |||
| DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; } | |||
| }; | |||
| } // namespace cpu | |||
| } // namespace device | |||
| @@ -48,6 +48,7 @@ class GPUMemoryManager; | |||
| namespace mindspore { | |||
| namespace device { | |||
| enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice }; | |||
| enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU }; | |||
| class DeviceAddress { | |||
| public: | |||
| @@ -64,6 +65,7 @@ class DeviceAddress { | |||
| TypeId type_id() const { return type_id_; } | |||
| virtual void set_status(DeviceAddressStatus status) {} | |||
| virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } | |||
| virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } | |||
| protected: | |||
| const void *ptr() const { return ptr_; } | |||
| @@ -35,6 +35,7 @@ class GPUDeviceAddress : public DeviceAddress { | |||
| bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override; | |||
| void set_status(DeviceAddressStatus status) { status_ = status; } | |||
| DeviceAddressStatus status() const { return status_; } | |||
| DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; } | |||
| private: | |||
| DeviceAddressStatus status_{DeviceAddressStatus::kInDevice}; | |||
| @@ -102,6 +102,13 @@ bool KernelRuntime::RunTask(const session::KernelGraph *graph) { | |||
| return false; | |||
| } | |||
| bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { | |||
| if (AnfAlgo::OutputAddrExist(kernel, index)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { | |||
| @@ -255,7 +262,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| if (i < graph_valid_input.size() && !graph_valid_input[i]) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::OutputAddrExist(item, 0)) { | |||
| if (NodeOutputDeviceAddressExist(item, 0)) { | |||
| continue; | |||
| } | |||
| auto output_size = AnfAlgo::GetOutputTensorNum(item); | |||
| @@ -431,7 +438,7 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in | |||
| if ((kGetAllOuts != index) && (SizeToInt(i) != index)) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::OutputAddrExist(node, i)) { | |||
| if (NodeOutputDeviceAddressExist(node, i)) { | |||
| MS_LOG(INFO) << "Already malloc index:" << i; | |||
| continue; | |||
| } | |||
| @@ -493,7 +500,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| for (auto &value_node : graph->graph_value_nodes()) { | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| if (AnfAlgo::OutputAddrExist(value_node, 0)) { | |||
| if (NodeOutputDeviceAddressExist(value_node, 0)) { | |||
| MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist"; | |||
| continue; | |||
| } | |||
| @@ -67,6 +67,7 @@ class KernelRuntime { | |||
| protected: | |||
| virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| TypeId type_id) = 0; | |||
| virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index); | |||
| virtual bool SyncStream() = 0; | |||
| void AssignStaticMemory(session::KernelGraph *graph); | |||
| void AssignDynamicMemory(session::KernelGraph *graph); | |||
| @@ -307,17 +307,27 @@ bool TaskEmitAction(const ResourcePtr &res) { | |||
| } | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); | |||
| if (IsCtrlSink()) { | |||
| res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); | |||
| return true; | |||
| } | |||
| std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; | |||
| if (bc_ptr->name() == kMsConvert) { | |||
| cut_list = compile::GetMsNonlinearOps(); | |||
| } | |||
| std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (compile->ContainMixedTarget(func_graph)) { | |||
| bc_ptr->set_is_multi_graph_sink(false); | |||
| context_ptr->set_loop_sink_flag(false); | |||
| } else if (context_ptr->execution_mode() != kPynativeMode) { | |||
| std::string device_target = context_ptr->device_target(); | |||
| if (device_target == kAscendDevice) { | |||
| bc_ptr->set_is_multi_graph_sink(true); | |||
| } | |||
| } | |||
| res->results()[kOutput] = compile->CompileAndLink(func_graph); | |||
| return true; | |||
| } | |||
| @@ -778,7 +778,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||
| MS_EXCEPTION_IF_NULL(convert_fn); | |||
| // Convert CNodeList to LinConvertResult. | |||
| ConfigManager::GetInstance().set_iter_num(1); | |||
| auto runner = convert_fn({app_init}); | |||
| auto runner = convert_fn({app_init}, ""); | |||
| if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { | |||
| backend->Link(runner.graph_id); | |||
| } | |||
| @@ -28,6 +28,23 @@ | |||
| namespace mindspore { | |||
| namespace session { | |||
| ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| if (!anf->isa<Parameter>()) { | |||
| MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; | |||
| } | |||
| auto valid_inputs = graph->MutableValidInputs(); | |||
| MS_EXCEPTION_IF_NULL(valid_inputs); | |||
| auto graph_inputs = graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); | |||
| ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); | |||
| TraceManager::EndTrace(); | |||
| graph_inputs->push_back(new_parameter); | |||
| valid_inputs->push_back(valid_input); | |||
| return new_parameter; | |||
| } | |||
| GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| auto graph_id = graph_sum_; | |||
| auto graph = ConstructKernelGraph(lst, outputs); | |||
| @@ -35,6 +35,9 @@ class CPUSession : public SessionBasic { | |||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||
| protected: | |||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override; | |||
| private: | |||
| void SetKernelInfo(const KernelGraph *kernel_graph); | |||
| void BuildKernel(const KernelGraph *kernel_graph); | |||
| @@ -482,7 +482,13 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de | |||
| depend_nodes = GetOutputNodes(depend_node); | |||
| } | |||
| for (auto &first_node : prior_nodes) { | |||
| if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| for (auto &second_node : depend_nodes) { | |||
| if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(first_node); | |||
| MS_EXCEPTION_IF_NULL(second_node); | |||
| MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); | |||
| @@ -311,7 +311,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf | |||
| if (python_paras_ == nullptr) { | |||
| python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>(); | |||
| } | |||
| if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) == kInvalidGraphId) { | |||
| if (python_paras_->find(m_tensor) != python_paras_->end()) { | |||
| new_parameter = (*python_paras_)[m_tensor]; | |||
| } else { | |||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); | |||
| @@ -114,7 +114,7 @@ class SessionBasic { | |||
| BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); | |||
| // create a new kernel graph and update the graph sum | |||
| KernelGraphPtr NewKernelGraph(); | |||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | |||
| virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | |||
| ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); | |||
| ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); | |||
| AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | |||
| @@ -92,7 +92,7 @@ class MsContext { | |||
| bool ir_fusion_flag() const { return ir_fusion_flag_; } | |||
| bool loop_sink_flag() const { return enable_loop_sink_; } | |||
| void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; } | |||
| void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; } | |||
| bool enable_mem_reuse() const { return enable_mem_reuse_; } | |||
| @@ -39,14 +39,14 @@ LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) { | |||
| multi_result_.inputs = g->parameters(); | |||
| final_output_ = NewValueNode("fake_output"); | |||
| multi_result_.outputs = {final_output_}; | |||
| GraphId final_g = sess_->GetFinalRunGraph(); | |||
| GraphId final_g = target_sess_->GetFinalRunGraph(); | |||
| multi_result_.run = std::make_shared<RunFunc>( | |||
| [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args); }); | |||
| [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); }); | |||
| return multi_result_; | |||
| } | |||
| LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) { | |||
| LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) { | |||
| MS_LOG(DEBUG) << "MsConvert"; | |||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | |||
| auto cached = g_ConvertCache.find(lst); | |||
| @@ -64,17 +64,24 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) { | |||
| result.inputs = inputs; | |||
| result.outputs = outputs; | |||
| result.graph_id = kInvalidGraphId; | |||
| auto graph_id = sess_->CompileGraph(lst, outputs); | |||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { | |||
| sess_->BuildGraph(graph_id); | |||
| GraphId graph_id = kInvalidGraphId; | |||
| if (target == kCPUDevice) { | |||
| graph_id = cpu_sess_->CompileGraph(lst, outputs); | |||
| } else { | |||
| graph_id = target_sess_->CompileGraph(lst, outputs); | |||
| } | |||
| if (MsContext::GetInstance()->precompile_only()) { | |||
| MS_LOG(INFO) << "PrecompileOnly, stop run graph"; | |||
| return result; | |||
| } | |||
| if (target == kCPUDevice) { | |||
| cpu_sess_->BuildGraph(graph_id); | |||
| } else if (!is_multi_graph_sink_) { | |||
| target_sess_->BuildGraph(graph_id); | |||
| } | |||
| result.run = std::make_shared<RunFunc>( | |||
| [graph_id, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args); }); | |||
| [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); }); | |||
| MS_EXCEPTION_IF_NULL(result.run); | |||
| result.simu_run = std::make_shared<RunFunc>( | |||
| @@ -92,7 +99,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) { | |||
| GraphId cond_g = kInvalidGraphId; | |||
| if (utils::isa<AnfNodePtr>(c)) { | |||
| cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c)); | |||
| cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString(); | |||
| } | |||
| @@ -116,7 +123,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) { | |||
| MS_LOG(DEBUG) << "invoke set active:" << active_g; | |||
| } | |||
| MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g; | |||
| sess_->SetActive(active_g, cond_g); | |||
| target_sess_->SetActive(active_g, cond_g); | |||
| } | |||
| void MsBackend::SetSwitchGraph() { | |||
| @@ -135,12 +142,12 @@ void MsBackend::SetSwitchGraph() { | |||
| } | |||
| GraphId cond_g = kInvalidGraphId; | |||
| if (utils::isa<AnfNodePtr>(curr_switch_)) { | |||
| cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_)); | |||
| cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString(); | |||
| } | |||
| MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g; | |||
| sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_)); | |||
| target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_)); | |||
| } | |||
| is_switch_call_ = false; | |||
| MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_; | |||
| @@ -202,7 +209,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef | |||
| old_args[i] = args[it->second]; | |||
| } | |||
| } | |||
| sess_->SetChildGraphInput(graph, old_args); | |||
| target_sess_->SetChildGraphInput(graph, old_args); | |||
| } | |||
| graph_inputs_.erase(c); | |||
| } | |||
| @@ -211,7 +218,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef | |||
| VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "set graph input:" << g; | |||
| // switch maybe twice | |||
| sess_->SetChildGraphInput(g, args); | |||
| target_sess_->SetChildGraphInput(g, args); | |||
| if (is_switch_call_) { | |||
| if (!curr_switch_.is_null()) { | |||
| @@ -236,7 +243,7 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { | |||
| return VectorRef(outputs); | |||
| } | |||
| VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) { | |||
| VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) { | |||
| MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g; | |||
| // Run graph | |||
| std::vector<tensor::TensorPtr> inputs; | |||
| @@ -271,7 +278,12 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) { | |||
| VectorRef outputs; | |||
| // call ms rungraph (graphId, input ,output) | |||
| sess_->RunGraph(g, inputs, &outputs); | |||
| if (target == kCPUDevice) { | |||
| cpu_sess_->RunGraph(g, inputs, &outputs); | |||
| } else { | |||
| target_sess_->RunGraph(g, inputs, &outputs); | |||
| } | |||
| MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size(); | |||
| return outputs; | |||
| } | |||
| @@ -300,17 +312,17 @@ void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) { | |||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args), | |||
| [](const AnfNodePtr &v) { return v; }); | |||
| MS_LOG(DEBUG) << "Simulate start"; | |||
| (void)sess_->SetFinalGraphInput(parameters); | |||
| (void)target_sess_->SetFinalGraphInput(parameters); | |||
| BaseRef output = rt->Eval(VectorRef(args)); | |||
| sess_->SetFinalGraphOutput(output); | |||
| target_sess_->SetFinalGraphOutput(output); | |||
| MS_LOG(DEBUG) << "Simulate Eval end"; | |||
| } | |||
| void MsBackend::Link(GraphId graph_id) { | |||
| if (graph_id == kInvalidGraphId) { | |||
| graph_id = sess_->GetFinalRunGraph(); | |||
| graph_id = target_sess_->GetFinalRunGraph(); | |||
| } | |||
| sess_->BuildGraph(graph_id); | |||
| target_sess_->BuildGraph(graph_id); | |||
| } | |||
| Backend::Backend(const std::string &name) : name_(name) { | |||
| @@ -322,16 +334,26 @@ Backend::Backend(const std::string &name) : name_(name) { | |||
| } | |||
| MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { | |||
| convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1); | |||
| sess_ = session::SessionFactory::Get().Create(target); | |||
| if (sess_ == nullptr) { | |||
| convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2); | |||
| target_sess_ = session::SessionFactory::Get().Create(target); | |||
| if (target_sess_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available."; | |||
| } | |||
| sess_->Init(device_id); | |||
| sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); | |||
| target_sess_->Init(device_id); | |||
| target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); | |||
| if (target == kCPUDevice) { | |||
| cpu_sess_ = target_sess_; | |||
| } else { | |||
| cpu_sess_ = session::SessionFactory::Get().Create(kCPUDevice); | |||
| if (cpu_sess_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Create cpu session failed with target " << target << "."; | |||
| } | |||
| cpu_sess_->Init(0); | |||
| cpu_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); | |||
| } | |||
| } | |||
| GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return sess_->CompileGraph(fg); } | |||
| GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); } | |||
| VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); } | |||
| @@ -91,8 +91,8 @@ class MsBackend : public Backend { | |||
| MsBackend(const std::string &name, const std::string &target, uint32_t device_id); | |||
| ~MsBackend() override = default; | |||
| LinConvertResult MsConvert(const AnfNodePtrList &lst); | |||
| VectorRef MsRunGraph(const GraphId &g, const VectorRef &args); | |||
| LinConvertResult MsConvert(const AnfNodePtrList &lst, const std::string &target = ""); | |||
| VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); | |||
| VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args); | |||
| void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override; | |||
| @@ -109,7 +109,8 @@ class MsBackend : public Backend { | |||
| VectorRef RunGraph(GraphId graph_id, const VectorRef &args); | |||
| private: | |||
| session::SessionPtr sess_; | |||
| session::SessionPtr target_sess_; | |||
| session::SessionPtr cpu_sess_; | |||
| std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_; | |||
| std::unordered_map<GraphId, LinConvertResult> graph_id_map_; | |||
| std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_; | |||
| @@ -148,7 +148,7 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||
| // This implementation will convert the nodes into a subgraph | |||
| // that will run using the MsVM. | |||
| template <typename T> | |||
| LinConvertResult Convert(const AnfNodePtrList &lst) { | |||
| LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { | |||
| auto cached = g_ConvertCache.find(lst); | |||
| if (cached != g_ConvertCache.end()) { | |||
| return cached->second; | |||
| @@ -43,7 +43,7 @@ struct LinConvertResult { | |||
| uint32_t graph_id; | |||
| }; | |||
| using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &)>; | |||
| using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &, const std::string &)>; | |||
| using ConvertCache = std::unordered_map<BaseRef, LinConvertResult, BaseRefHash>; | |||
| extern LinkFuncType MsVmConvert; | |||
| extern LinkFuncType GeVmConvert; | |||
| @@ -20,6 +20,8 @@ | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <queue> | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -47,6 +49,86 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() { | |||
| return ms_nonlinear_ops; | |||
| } | |||
| namespace { | |||
| std::string GetCNodeTarget(const AnfNodePtr &node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->device_target(); | |||
| if (!node->isa<CNode>()) { | |||
| return default_target; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto attr_input = cnode->input(kAnfPrimitiveIndex); | |||
| if (attr_input == nullptr) { | |||
| return default_target; | |||
| } | |||
| auto value_node = attr_input->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| return default_target; | |||
| } | |||
| auto value = value_node->value(); | |||
| if (value == nullptr) { | |||
| return default_target; | |||
| } | |||
| if (!value->isa<Primitive>()) { | |||
| return default_target; | |||
| } | |||
| auto primitive = value->cast<PrimitivePtr>(); | |||
| ValuePtr att_target = primitive->GetAttr("target"); | |||
| if (att_target != nullptr) { | |||
| std::string target = GetValue<std::string>(att_target); | |||
| return target; | |||
| } | |||
| return default_target; | |||
| } | |||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string last_target = context_ptr->device_target(); | |||
| for (auto &node : nodes) { | |||
| if (node->isa<CNode>()) { | |||
| std::string cur_target = GetCNodeTarget(node); | |||
| if (last_target != cur_target) { | |||
| return true; | |||
| } | |||
| last_target = cur_target; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) { | |||
| std::queue<AnfNodePtr> queue; | |||
| queue.push(graph->get_return()); | |||
| std::set<AnfNodePtr> visited; | |||
| while (!queue.empty()) { | |||
| auto &node = queue.front(); | |||
| queue.pop(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (auto &input : cnode->inputs()) { | |||
| auto iter = nodes_ref->find(input); | |||
| if (iter != nodes_ref->end()) { | |||
| iter->second++; | |||
| } else { | |||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1)); | |||
| } | |||
| if (visited.find(input) != visited.end()) { | |||
| continue; | |||
| } | |||
| visited.insert(input); | |||
| queue.push(input); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) | |||
| : backend_(backend), cut_list_(cut_list) { | |||
| MS_EXCEPTION_IF_NULL(backend_); | |||
| @@ -98,12 +180,67 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { | |||
| return false; | |||
| } | |||
| std::vector<AnfNodePtr> CompileGraph::SplitSort(const FuncGraphPtr &graph) { | |||
| std::vector<AnfNodePtr> result; | |||
| std::queue<AnfNodePtr> queue; | |||
| std::queue<AnfNodePtr> next_queue; | |||
| std::map<AnfNodePtr, size_t> nodes_ref; | |||
| CalcNodeRefCount(graph, &nodes_ref); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string queue_target = context_ptr->device_target(); | |||
| std::string next_target = ""; | |||
| queue.push(graph->get_return()); | |||
| while (!queue.empty() || !next_queue.empty()) { | |||
| if (queue.empty()) { | |||
| queue.swap(next_queue); | |||
| queue_target = next_target; | |||
| } | |||
| auto &node = queue.front(); | |||
| queue.pop(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| result.emplace_back(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (auto &input : cnode->inputs()) { | |||
| auto iter = nodes_ref.find(input); | |||
| if (iter != nodes_ref.end()) { | |||
| iter->second--; | |||
| if (iter->second != 0) { | |||
| continue; | |||
| } | |||
| } | |||
| if (!input->isa<CNode>()) { | |||
| queue.push(input); | |||
| continue; | |||
| } | |||
| std::string input_target = GetCNodeTarget(input); | |||
| if (input_target == queue_target) { | |||
| queue.push(input); | |||
| } else if (next_queue.empty() || input_target == next_target) { | |||
| next_queue.push(input); | |||
| next_target = input_target; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "only support two different target"; | |||
| } | |||
| } | |||
| } | |||
| std::reverse(result.begin(), result.end()); | |||
| return result; | |||
| } | |||
| VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| VectorRef splits; | |||
| VectorRef split; | |||
| std::vector<AnfNodePtr> nodes = TopoSort(graph->get_return()); | |||
| auto nodes = TopoSort(graph->get_return()); | |||
| if (ContainMultiTarget(nodes)) { | |||
| nodes = SplitSort(graph); | |||
| } | |||
| std::string last_target; | |||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| @@ -114,7 +251,13 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||
| } | |||
| splits.push_back(node); | |||
| split.clear(); | |||
| } else if (!(node->isa<ValueNode>() || node->isa<Parameter>())) { | |||
| } else if (node->isa<CNode>()) { | |||
| std::string cur_target = GetCNodeTarget(node); | |||
| if (cur_target != last_target && !last_target.empty() && split.size() != 0) { | |||
| splits.push_back(split); | |||
| split.clear(); | |||
| } | |||
| last_target = cur_target; | |||
| split.push_back(node); | |||
| MS_LOG(DEBUG) << "Insert node:" << node->DebugString(10) << ", size:" << split.size(); | |||
| } | |||
| @@ -200,14 +343,14 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) { | |||
| } | |||
| } | |||
| int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list) { | |||
| int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, const std::string &target) { | |||
| MS_LOG(DEBUG) << "LinConvert start"; | |||
| LinConvertResult result; | |||
| if (backend_->simu_flag()) { | |||
| result = backend_->GetMultiGraphRun(graph); | |||
| } else { | |||
| result = lin_convert_(node_list); | |||
| result = lin_convert_(node_list, target); | |||
| } | |||
| if (result.run == nullptr) { | |||
| @@ -316,7 +459,12 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { | |||
| auto vec_ref = utils::cast<VectorRef>(split); | |||
| (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args), | |||
| [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); }); | |||
| ret = LinConvert(graph, args); | |||
| if (args.size() > 0) { | |||
| std::string cur_target = GetCNodeTarget(args[0]); | |||
| ret = LinConvert(graph, args, cur_target); | |||
| } else { | |||
| ret = LinConvert(graph, args); | |||
| } | |||
| MS_LOG(DEBUG) << "End a extern LinConvert"; | |||
| if (ret == RET_FAILED) { | |||
| return false; | |||
| @@ -637,6 +785,19 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { | |||
| return rt; | |||
| } | |||
| bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) { | |||
| auto graph_manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(graph_manager); | |||
| FuncGraphSet graphs = graph_manager->func_graphs(); | |||
| for (auto &g : graphs) { | |||
| auto nodes = TopoSort(g->get_return()); | |||
| if (ContainMultiTarget(nodes)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| BackendPtr CreateBackend() { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| @@ -79,8 +79,9 @@ class CompileGraph { | |||
| private: | |||
| void PushParameters(const FuncGraphPtr &func_graph); | |||
| std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph); | |||
| bool SplitGraph(const FuncGraphPtr &func_graph); | |||
| int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list); | |||
| int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); | |||
| int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); | |||
| int AddCall(const FuncGraphPtr &graph, const CNodePtr &node); | |||
| void AddSinkSwitch(const CNodePtr &node); | |||
| @@ -124,6 +125,7 @@ class CompileGraphs { | |||
| void Compile(const FuncGraphPtr &func_graph); | |||
| FinalVMPtr Link(const FuncGraphPtr &func_graph); | |||
| FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); | |||
| bool ContainMixedTarget(const FuncGraphPtr &graph); | |||
| private: | |||
| InstSet insts_; | |||
| @@ -65,7 +65,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) { | |||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||
| } | |||
| auto convertResult = MsVmConvert(anf_list); | |||
| auto convertResult = MsVmConvert(anf_list, ""); | |||
| auto runResult = (*(convertResult.run))(args); | |||
| ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0); | |||
| } | |||
| @@ -89,7 +89,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) { | |||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||
| } | |||
| auto convertResult = MsVmConvert(anf_list); | |||
| auto convertResult = MsVmConvert(anf_list, ""); | |||
| auto runResult = (*(convertResult.run))(args); | |||
| ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0); | |||
| } | |||
| @@ -113,7 +113,7 @@ TEST_F(TestCompileSegmentRunner, test_if) { | |||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||
| } | |||
| auto convertResult = MsVmConvert(anf_list); | |||
| auto convertResult = MsVmConvert(anf_list, ""); | |||
| auto runResult = (*(convertResult.run))(args); | |||
| auto result = py::cast<bool>(BaseRefToPyData(runResult[0])); | |||