| @@ -55,8 +55,8 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k | |||
| } | |||
| } | |||
| GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| auto graph_id = AscendSession::CompileGraph(func_graph); | |||
| GraphId AscendInferenceSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { | |||
| auto graph_id = AscendSession::CompileGraphImpl(func_graph); | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| // load weight data to device | |||
| @@ -38,7 +38,6 @@ class AscendInferenceSession : public AscendSession { | |||
| ~AscendInferenceSession() = default; | |||
| void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| const std::vector<tensor::TensorPtr> &inputs_const) const; | |||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override; | |||
| bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| std::string *error_msg) const override; | |||
| bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr ¶meter) const; | |||
| @@ -46,6 +45,9 @@ class AscendInferenceSession : public AscendSession { | |||
| std::string PrintInputShape(std::vector<T> shape) const; | |||
| std::string InputsInfo(const std::vector<ParameterPtr> ¶s, const std::vector<tensor::TensorPtr> &inputs) const; | |||
| void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const override; | |||
| protected: | |||
| GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; | |||
| }; | |||
| MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); | |||
| } // namespace session | |||
| @@ -114,7 +114,7 @@ void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) { | |||
| } | |||
| } // namespace | |||
| GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| MS_LOG(INFO) << "Start"; | |||
| // construct graph, if successfully, graph_sum_ + 1 | |||
| auto graph = ConstructKernelGraph(lst, outputs); | |||
| @@ -123,7 +123,7 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL | |||
| return graph_id; | |||
| } | |||
| GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { | |||
| MS_LOG(INFO) << "Start"; | |||
| std::vector<KernelGraphPtr> all_graphs; | |||
| auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); | |||
| @@ -205,7 +205,7 @@ void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> | |||
| kernel_graph->set_summary_node_exist(false); | |||
| } | |||
| void AscendSession::BuildGraph(GraphId graph_id) { | |||
| void AscendSession::BuildGraphImpl(GraphId graph_id) { | |||
| MS_LOG(INFO) << "Start"; | |||
| auto graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -301,8 +301,8 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { | |||
| runtime_instance->AssignStaticMemoryValueNode(child_graph.get()); | |||
| } | |||
| void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *const outputs) { | |||
| void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *const outputs) { | |||
| MS_LOG(INFO) << "Start"; | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| @@ -350,8 +350,9 @@ bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const { | |||
| return run_op_graphs_.find(graph_info) != run_op_graphs_.end(); | |||
| } | |||
| void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) { | |||
| void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int> &tensors_mask) { | |||
| MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !"; | |||
| if (GraphCacheExist(graph_info)) { | |||
| MS_LOG(INFO) << "Build op " << op_run_info.op_name << " graph cache has existed !"; | |||
| @@ -375,8 +376,8 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph | |||
| MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !"; | |||
| } | |||
| void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) { | |||
| void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) { | |||
| auto graph = run_op_graphs_[graph_info]; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; | |||
| @@ -1049,9 +1050,9 @@ void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph, | |||
| } | |||
| } | |||
| GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph, const vector<tensor::TensorPtr> &inputs) { | |||
| GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const vector<tensor::TensorPtr> &inputs) { | |||
| RunInfer(func_graph, inputs); | |||
| return CompileGraph(func_graph); | |||
| return CompileGraphImpl(func_graph); | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -40,7 +40,16 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, | |||
| class AscendSession : public SessionBasic { | |||
| public: | |||
| AscendSession() { final_graph_id_ = kInvalidGraphId; } | |||
| ~AscendSession() override = default; | |||
| ~AscendSession() { | |||
| if (rt_context_ != nullptr) { | |||
| auto ret = rtCtxDestroy(rt_context_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]"; | |||
| } | |||
| rt_context_ = nullptr; | |||
| } | |||
| } | |||
| void Init(uint32_t device_id) override { | |||
| InitDevice(kAscendDevice, device_id); | |||
| auto ret = rtCtxCreate(&rt_context_, 0, device_id); | |||
| @@ -52,24 +61,26 @@ class AscendSession : public SessionBasic { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; | |||
| } | |||
| } | |||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override; | |||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override; | |||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||
| void BuildGraph(GraphId) override; | |||
| void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) override; | |||
| void RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override; | |||
| // get graph id in child graphs by ME front anf node pointer | |||
| GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; | |||
| // get graph id of final graph | |||
| GraphId GetFinalRunGraph() const override { return final_graph_id_; } | |||
| // compile child graph when session have multiple child graphs | |||
| void CompileChildGraph(const KernelGraphPtr &child_graph); | |||
| protected: | |||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; | |||
| GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override; | |||
| void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||
| void BuildGraphImpl(GraphId) override; | |||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) override; | |||
| void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override; | |||
| private: | |||
| // compile child graph when session have multiple child graphs | |||
| void CompileChildGraph(const KernelGraphPtr &child_graph); | |||
| void RecurseSetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary); | |||
| void SetSummaryNodes(KernelGraph *graph) override; | |||
| void InitRuntimeResource(); | |||
| @@ -61,7 +61,7 @@ void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| } | |||
| GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| auto graph_id = graph_sum_; | |||
| auto graph = ConstructKernelGraph(lst, outputs); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -85,14 +85,16 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector< | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) { | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_LOG(INFO) << "Bind input output address"; | |||
| runtime_.BindInputOutput(kernel_graph.get(), input_tensors, outputs); | |||
| return; | |||
| runtime_.CreateOutputTensors(kernel_graph.get(), input_tensors, outputs); | |||
| } | |||
| void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { | |||
| void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) { | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_LOG(INFO) << "Bind input output address"; | |||
| runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| InitPSParamAndOptim(kernel_graph, inputs); | |||
| #endif | |||
| @@ -30,13 +30,12 @@ class CPUSession : public SessionBasic { | |||
| CPUSession() = default; | |||
| ~CPUSession() override = default; | |||
| void Init(uint32_t device_id) override { InitDevice(kCPUDevice, device_id); } | |||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||
| protected: | |||
| void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) override; | |||
| protected: | |||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override; | |||
| void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph); | |||
| @@ -32,7 +32,6 @@ void UpdateOutputTensors(const VectorRef *outputs, | |||
| } else if (utils::isa<tensor::TensorPtr>(item)) { | |||
| auto tensor = utils::cast<tensor::TensorPtr>(item); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| tensor->SetNeedWait(false); | |||
| auto iter = tensor_to_node.find(tensor); | |||
| if (iter != tensor_to_node.end()) { | |||
| auto &node = iter->second.first; | |||
| @@ -41,44 +40,67 @@ void UpdateOutputTensors(const VectorRef *outputs, | |||
| tensor->set_device_address(address); | |||
| } | |||
| if (tensor->NeedSyncDeviceToHostImmediately()) { | |||
| tensor->data_sync(); | |||
| auto tensor_address = tensor->device_address(); | |||
| MS_EXCEPTION_IF_NULL(tensor_address); | |||
| tensor_address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c()); | |||
| tensor->set_device_address(nullptr); | |||
| tensor->set_sync_status(kNeedSyncHostToDevice); | |||
| } | |||
| tensor->SetNeedWait(false); | |||
| } | |||
| } | |||
| } | |||
| bool TensorInVector(const VectorRef *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| for (auto item : *outputs) { | |||
| if (utils::isa<VectorRefPtr>(item)) { | |||
| auto vector_ref = utils::cast<VectorRef>(item); | |||
| if (TensorInVector(&vector_ref)) { | |||
| return true; | |||
| } | |||
| } else if (utils::isa<tensor::TensorPtr>(item)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| void CompileNodesTask::Run() { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| graph_id_ = session_->CompileGraph(nodes_, output_nodes_); | |||
| graph_id_ = session_->CompileGraphImpl(nodes_, output_nodes_); | |||
| } | |||
| void CompileGraphTask::Run() { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| graph_id_ = session_->CompileGraph(NOT_NULL(func_graph_)); | |||
| graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_)); | |||
| } | |||
| void BuildGraphTask::Run() { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| session_->BuildGraph(graph_id_); | |||
| session_->BuildGraphImpl(graph_id_); | |||
| } | |||
| void RunGraphTask::Run() { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| session_->RunGraph(graph_id_, input_tensors_, &outputs_); | |||
| session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); | |||
| UpdateOutputTensors(&outputs_, tensor_to_node_); | |||
| for (auto &tensor : input_need_lock_tensors_) { | |||
| tensor->SetNeedWait(false); | |||
| } | |||
| ExecutorManager::Instance().OnRunGraphFinished(); | |||
| } | |||
| void BuildOpTask::Run() { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| session_->BuildOp(*op_run_info_, graph_info_, input_tensors_, tensors_mask_); | |||
| session_->BuildOpImpl(*op_run_info_, graph_info_, input_tensors_, tensors_mask_); | |||
| } | |||
| void RunOpTask::Run() { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_); | |||
| session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_); | |||
| } | |||
| void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); } | |||
| @@ -132,8 +154,12 @@ void Executor::WorkerLoop() { | |||
| } catch (const std::exception &e) { | |||
| exception_ptr_ = std::current_exception(); | |||
| } | |||
| task = nullptr; | |||
| sync_cond_var_.notify_all(); | |||
| if (task->type_ != kRunGraph || task->sync_run_) { | |||
| task = nullptr; | |||
| sync_cond_var_.notify_all(); | |||
| } else { | |||
| task = nullptr; | |||
| } | |||
| } | |||
| } | |||
| @@ -142,7 +168,7 @@ std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() { | |||
| std::unique_lock<std::mutex> lock(pending_task_mutex_); | |||
| for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) { | |||
| auto task = *iter; | |||
| if (IsAllInputsReady(task->input_tensors_)) { | |||
| if (IsTaskReady(task)) { | |||
| new_ready_tasks.emplace_back(task); | |||
| pending_tasks_.erase(iter++); | |||
| } else { | |||
| @@ -163,8 +189,9 @@ void Executor::OnRunGraphFinished() { | |||
| } | |||
| } | |||
| bool Executor::IsAllInputsReady(const std::vector<tensor::TensorPtr> &inputs) { | |||
| for (auto &input : inputs) { | |||
| bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) { | |||
| MS_EXCEPTION_IF_NULL(task); | |||
| for (auto &input : task->input_need_wait_tensors_) { | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if (input->NeedWait()) { | |||
| return false; | |||
| @@ -173,8 +200,7 @@ bool Executor::IsAllInputsReady(const std::vector<tensor::TensorPtr> &inputs) { | |||
| return true; | |||
| } | |||
| GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrList &lst, | |||
| const AnfNodePtrList &outputs) { | |||
| GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| CheckException(); | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| auto task = std::make_shared<CompileNodesTask>(); | |||
| @@ -188,7 +214,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL | |||
| return task->graph_id_; | |||
| } | |||
| GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) { | |||
| GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) { | |||
| CheckException(); | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| auto task = std::make_shared<CompileGraphTask>(); | |||
| @@ -201,7 +227,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraph | |||
| return task->graph_id_; | |||
| } | |||
| void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) { | |||
| void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) { | |||
| CheckException(); | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| auto task = std::make_shared<BuildGraphTask>(); | |||
| @@ -213,19 +239,62 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) { | |||
| CheckException(); | |||
| } | |||
| void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id, | |||
| const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { | |||
| CheckException(); | |||
| MS_EXCEPTION_IF_NULL(session); | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| auto task = std::make_shared<RunGraphTask>(); | |||
| task->session_ = session; | |||
| task->graph_id_ = graph_id; | |||
| task->input_tensors_ = inputs; | |||
| session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_); | |||
| task->outputs_ = *outputs; | |||
| task->sync_run_ = true; | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| ready_tasks_.push(task); | |||
| task_cond_var_.notify_all(); | |||
| mindspore::ScopedLongRunning long_running; | |||
| sync_cond_var_.wait(lock); | |||
| CheckException(); | |||
| } | |||
| void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, | |||
| const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { | |||
| CheckException(); | |||
| MS_EXCEPTION_IF_NULL(session); | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| auto task = std::make_shared<RunGraphTask>(); | |||
| task->session_ = session; | |||
| task->graph_id_ = graph_id; | |||
| task->input_tensors_ = inputs; | |||
| MS_EXCEPTION_IF_NULL(session); | |||
| // lock inputs | |||
| for (auto &tensor : inputs) { | |||
| if (tensor->NeedWait()) { | |||
| task->input_need_wait_tensors_.emplace_back(tensor); | |||
| } | |||
| } | |||
| task->input_need_lock_tensors_ = session->GetNeedLockInputTensors(graph_id, inputs); | |||
| for (auto &tensor : task->input_need_lock_tensors_) { | |||
| tensor->SetNeedWait(true); | |||
| } | |||
| session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_); | |||
| // maintain a copy of output vector | |||
| task->outputs_ = *outputs; | |||
| bool ready = IsAllInputsReady(inputs); | |||
| // sync run graph without output tensor(int dataset graph) | |||
| if (!TensorInVector(outputs)) { | |||
| task->sync_run_ = true; | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| ready_tasks_.push(task); | |||
| task_cond_var_.notify_all(); | |||
| mindspore::ScopedLongRunning long_running; | |||
| sync_cond_var_.wait(lock); | |||
| CheckException(); | |||
| return; | |||
| } | |||
| bool ready = IsTaskReady(task); | |||
| if (!ready) { | |||
| std::unique_lock<std::mutex> lock(pending_task_mutex_); | |||
| pending_tasks_.push_back(task); | |||
| @@ -234,13 +303,10 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| ready_tasks_.push(task); | |||
| task_cond_var_.notify_all(); | |||
| mindspore::ScopedLongRunning long_running; | |||
| sync_cond_var_.wait(lock); | |||
| CheckException(); | |||
| } | |||
| void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) { | |||
| void Executor::BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) { | |||
| CheckException(); | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| auto task = std::make_shared<BuildOpTask>(); | |||
| @@ -255,8 +321,8 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c | |||
| CheckException(); | |||
| } | |||
| void Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) { | |||
| void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) { | |||
| CheckException(); | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| auto task = std::make_shared<RunOpTask>(); | |||
| @@ -55,6 +55,7 @@ class Task { | |||
| virtual ~Task() = default; | |||
| SessionPtr session_{nullptr}; | |||
| TaskType type_{kUnKnown}; | |||
| bool sync_run_{false}; | |||
| virtual void Run() {} | |||
| }; | |||
| @@ -91,6 +92,8 @@ class RunGraphTask : public Task { | |||
| ~RunGraphTask() override = default; | |||
| void Run() override; | |||
| std::vector<tensor::TensorPtr> input_tensors_; | |||
| std::vector<tensor::TensorPtr> input_need_wait_tensors_; | |||
| std::vector<tensor::TensorPtr> input_need_lock_tensors_; | |||
| VectorRef outputs_; | |||
| GraphId graph_id_{0}; | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_; | |||
| @@ -149,15 +152,17 @@ class Executor { | |||
| ~Executor(); | |||
| void WorkerLoop(); | |||
| void WorkerJoin(); | |||
| GraphId CompileGraphAsync(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | |||
| GraphId CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph); | |||
| void BuildGraphAsync(const SessionPtr &session, GraphId graphId); | |||
| GraphId CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | |||
| GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph); | |||
| void BuildGraph(const SessionPtr &session, GraphId graphId); | |||
| void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs); | |||
| void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs); | |||
| void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask); | |||
| void RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs); | |||
| void BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask); | |||
| void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs); | |||
| void OnRunGraphFinished(); | |||
| bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks); | |||
| bool DestroyCommGroup(const std::string &group_name); | |||
| @@ -166,7 +171,7 @@ class Executor { | |||
| void UpdateOutputTensors(VectorRef *outputs, | |||
| const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node); | |||
| std::vector<std::shared_ptr<RunGraphTask>> GetNewReadyTasks(); | |||
| bool IsAllInputsReady(const std::vector<tensor::TensorPtr> &inputs); | |||
| bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task); | |||
| void CheckException(); | |||
| void OnWorkerExit(); | |||
| @@ -218,7 +218,7 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const | |||
| } | |||
| } | |||
| GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| // Construct graph, if successfully, graph_sum_ + 1 | |||
| auto graph_id = graph_sum_; | |||
| auto graph = ConstructKernelGraph(lst, outputs); | |||
| @@ -277,7 +277,8 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList | |||
| return graph_id; | |||
| } | |||
| void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { | |||
| void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) { | |||
| auto &kernel_graph = graphs_[graph_id]; | |||
| // Load input data from user input | |||
| LoadInputData(kernel_graph, inputs); | |||
| @@ -298,8 +299,9 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten | |||
| PostIterationDbg(kernel_graph); | |||
| } | |||
| void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) { | |||
| void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int> &tensors_mask) { | |||
| // Check if the graph cache exists. | |||
| if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { | |||
| return; | |||
| @@ -315,8 +317,8 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in | |||
| run_op_graphs_[graph_info] = kernel_graph; | |||
| } | |||
| void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) { | |||
| void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) { | |||
| auto kernel_graph = run_op_graphs_[graph_info]; | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| // Remove NopOp from execution graph | |||
| @@ -31,16 +31,15 @@ class GPUSession : public SessionBasic { | |||
| public: | |||
| GPUSession() = default; | |||
| ~GPUSession() override = default; | |||
| void Init(uint32_t device_id) override { InitDevice(kGPUDevice, device_id); } | |||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||
| void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) override; | |||
| void RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override; | |||
| protected: | |||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) override; | |||
| void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override; | |||
| private: | |||
| void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| @@ -307,7 +307,7 @@ void MSInferSession::RegAllOp() { | |||
| Status MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| try { | |||
| auto graph_id = session_impl_->CompileGraphAsync(NOT_NULL(funcGraphPtr)); | |||
| auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); | |||
| py::gil_scoped_release gil_release; | |||
| model_id = graph_id; | |||
| return SUCCESS; | |||
| @@ -321,8 +321,7 @@ std::vector<tensor::TensorPtr> MSInferSession::RunGraph(uint32_t graph_id, | |||
| const std::vector<tensor::TensorPtr> &inputs) { | |||
| try { | |||
| VectorRef outputs; | |||
| session_impl_->RunGraphAsync(graph_id, inputs, &outputs); | |||
| session_impl_->RunGraph(graph_id, inputs, &outputs); | |||
| return TransformVectorRefToMultiTensor(outputs); | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference Rungraph failed"; | |||
| @@ -100,9 +100,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o | |||
| } else { | |||
| tensor->set_sync_status(kNeedSyncDeviceToHost); | |||
| } | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| tensor->SetNeedWait(true); | |||
| } | |||
| tensor->SetNeedWait(true); | |||
| return tensor; | |||
| } | |||
| @@ -953,8 +951,12 @@ bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor | |||
| if (tensor->NeedSyncHostToDevice()) { | |||
| return true; | |||
| } | |||
| if (tensor->device_address() != device_address) { | |||
| (void)tensor->data_sync(); | |||
| auto tensor_address = tensor->device_address(); | |||
| if (tensor_address != device_address) { | |||
| if (tensor_address != nullptr) { | |||
| tensor_address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c()); | |||
| } | |||
| return true; | |||
| } | |||
| return false; | |||
| @@ -1025,6 +1027,30 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| } | |||
| } | |||
| std::vector<tensor::TensorPtr> SessionBasic::GetNeedLockInputTensors(const GraphId &graph_id, | |||
| const std::vector<tensor::TensorPtr> &inputs) { | |||
| auto graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| bool has_optimizer = false; | |||
| for (const auto &cnode : graph->execution_order()) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(cnode)) != kOptOperatorSet.end()) { | |||
| has_optimizer = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!has_optimizer) { | |||
| return {}; | |||
| } | |||
| std::vector<tensor::TensorPtr> result; | |||
| for (auto &tensor : inputs) { | |||
| if (!tensor->NeedWait()) { | |||
| result.emplace_back(tensor); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| VectorRef *outputs, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) { | |||
| @@ -1341,32 +1367,36 @@ AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::ve | |||
| return nullptr; | |||
| } | |||
| GraphId SessionBasic::CompileGraphAsync(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| GraphId SessionBasic::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| return executor_->CompileGraphAsync(shared_from_this(), lst, outputs); | |||
| return executor_->CompileGraph(shared_from_this(), lst, outputs); | |||
| } | |||
| GraphId SessionBasic::CompileGraphAsync(NotNull<FuncGraphPtr> func_graph) { | |||
| GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| return executor_->CompileGraphAsync(shared_from_this(), func_graph); | |||
| return executor_->CompileGraph(shared_from_this(), func_graph); | |||
| } | |||
| void SessionBasic::BuildGraphAsync(GraphId graph_id) { | |||
| void SessionBasic::BuildGraph(GraphId graph_id) { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| executor_->BuildGraphAsync(shared_from_this(), graph_id); | |||
| executor_->BuildGraph(shared_from_this(), graph_id); | |||
| } | |||
| void SessionBasic::BuildOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int> &tensors_mask) { | |||
| void SessionBasic::BuildOp(OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| executor_->BuildOp(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask); | |||
| } | |||
| void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask); | |||
| executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs); | |||
| } | |||
| void SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) { | |||
| void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, outputs); | |||
| executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs); | |||
| } | |||
| void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| @@ -65,36 +65,16 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| void InitDevice(const std::string &device_name, uint32_t device_id); | |||
| virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| VectorRef *outputs, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node); | |||
| virtual ~SessionBasic() { summary_callback_ = nullptr; } | |||
| virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; | |||
| virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; } | |||
| virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) { | |||
| MS_EXCEPTION(NotExistsError) << "Call an empty function"; | |||
| } | |||
| // build graph, used to handle multiple child graphs | |||
| virtual void BuildGraph(GraphId) {} | |||
| virtual void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) = 0; | |||
| virtual void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {} | |||
| virtual void RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {} | |||
| GraphId CompileGraphAsync(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | |||
| GraphId CompileGraphAsync(NotNull<FuncGraphPtr> func_graph); | |||
| void BuildGraphAsync(GraphId graphId); | |||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | |||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph); | |||
| void BuildGraph(GraphId graphId); | |||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | |||
| void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | |||
| void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int> &tensors_mask); | |||
| void RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| VectorRef *outputs); | |||
| void BuildOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int> &tensors_mask); | |||
| void RunOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs); | |||
| virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); | |||
| @@ -118,7 +98,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| return true; | |||
| } | |||
| virtual void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {} | |||
| std::vector<tensor::TensorPtr> GetNeedLockInputTensors(const GraphId &graph_id, | |||
| const std::vector<tensor::TensorPtr> &inputs); | |||
| #ifdef ENABLE_DEBUGGER | |||
| // set debugger | |||
| void SetDebugger() { | |||
| @@ -140,6 +121,28 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); | |||
| protected: | |||
| friend class Executor; | |||
| friend class CompileNodesTask; | |||
| friend class CompileGraphTask; | |||
| friend class BuildGraphTask; | |||
| friend class RunGraphTask; | |||
| friend class BuildOpTask; | |||
| friend class RunOpTask; | |||
| virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| VectorRef *outputs, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node); | |||
| virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; | |||
| virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; } | |||
| virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) { | |||
| MS_EXCEPTION(NotExistsError) << "Call an empty function"; | |||
| } | |||
| virtual void BuildGraphImpl(GraphId) {} | |||
| virtual void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) = 0; | |||
| virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {} | |||
| virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {} | |||
| void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs); | |||
| // Get graph by graph id ,if not exist return null ptr | |||
| KernelGraphPtr GetGraph(GraphId graph_id) const; | |||
| @@ -669,10 +669,10 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||
| std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); | |||
| session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive, op_exec_info->abstract, | |||
| op_exec_info->value}; | |||
| session->BuildOpAsync(&op_run_info, graph_info, input_tensors, tensors_mask); | |||
| session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask); | |||
| EraseValueNodeTensor(tensors_mask, &input_tensors); | |||
| VectorRef outputs; | |||
| session->RunOpAsync(&op_run_info, graph_info, input_tensors, &outputs); | |||
| session->RunOp(&op_run_info, graph_info, input_tensors, &outputs); | |||
| auto result = BaseRefToPyData(outputs); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | |||
| *status = PYNATIVE_SUCCESS; | |||
| @@ -97,6 +97,12 @@ const std::set<std::string> kOpNeedTransFormat = {kOpFormat_NHWC, kOpForm | |||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; | |||
| void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| runtime_instance->SetContext(); | |||
| auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind); | |||
| if (ret_rt_memcpy != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "rtMemcpy failed"; | |||
| @@ -21,6 +21,7 @@ | |||
| #include <utility> | |||
| #include <exception> | |||
| #include <algorithm> | |||
| #include <thread> | |||
| #include "runtime/device/ascend/ascend_device_address.h" | |||
| #include "runtime/device/cpu/mpi/mpi_interface.h" | |||
| #include "utils/ms_context.h" | |||
| @@ -61,6 +62,7 @@ namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| static const size_t PRAMATER_OUTPUT_INDEX = 0; | |||
| static thread_local rtContext_t thread_local_rt_context{nullptr}; | |||
| namespace { | |||
| std::string GetRankId() { | |||
| std::string rank_id_str; | |||
| @@ -97,6 +99,20 @@ std::string GetRankId() { | |||
| AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); } | |||
| void AscendKernelRuntime::SetContext() { | |||
| if (rt_context_ == nullptr) { | |||
| return; | |||
| } | |||
| if (thread_local_rt_context == rt_context_) { | |||
| return; | |||
| } | |||
| auto ret = rtCtxSetCurrent(rt_context_); | |||
| thread_local_rt_context = rt_context_; | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; | |||
| } | |||
| } | |||
| void AscendKernelRuntime::InnerSetContext() { | |||
| if (rt_context_ == nullptr) { | |||
| return; | |||
| } | |||
| @@ -107,7 +123,7 @@ void AscendKernelRuntime::SetContext() { | |||
| } | |||
| void AscendKernelRuntime::ClearGraphModelMap() { | |||
| SetContext(); | |||
| InnerSetContext(); | |||
| for (auto &iter : graph_data_dumper_) { | |||
| MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first; | |||
| auto &data_dumper = iter.second; | |||
| @@ -131,7 +147,7 @@ void AscendKernelRuntime::ClearGraphModelMap() { | |||
| void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &, | |||
| const std::unordered_set<ValueNodePtr> &, | |||
| const std::vector<CNodePtr> &) { | |||
| SetContext(); | |||
| InnerSetContext(); | |||
| MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper"; | |||
| if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) { | |||
| MS_LOG(DEBUG) << "Unload dump info " << graph_id; | |||
| @@ -189,7 +205,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() { | |||
| if (!initialized_) { | |||
| return; | |||
| } | |||
| SetContext(); | |||
| InnerSetContext(); | |||
| // release ge runtime | |||
| ClearGraphModelMap(); | |||
| @@ -214,7 +230,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() { | |||
| bool AscendKernelRuntime::Init() { | |||
| if (initialized_) { | |||
| SetContext(); | |||
| InnerSetContext(); | |||
| return true; | |||
| } | |||
| // Start up profiling before rtSetDevice | |||
| @@ -336,7 +352,7 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { | |||
| } | |||
| bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { | |||
| SetContext(); | |||
| InnerSetContext(); | |||
| if (graph == nullptr) { | |||
| MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!"; | |||
| } | |||
| @@ -390,7 +406,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { | |||
| } | |||
| bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { | |||
| SetContext(); | |||
| InnerSetContext(); | |||
| if (graph == nullptr) { | |||
| MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; | |||
| } | |||
| @@ -505,7 +521,7 @@ bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, De | |||
| } | |||
| bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { | |||
| SetContext(); | |||
| InnerSetContext(); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id(); | |||
| @@ -533,7 +549,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { | |||
| } | |||
| bool AscendKernelRuntime::SyncStream() { | |||
| SetContext(); | |||
| InnerSetContext(); | |||
| if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { // o for switch stream | |||
| MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; | |||
| return false; | |||
| @@ -570,7 +586,7 @@ bool AscendKernelRuntime::InitDevice() { | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]"; | |||
| } | |||
| SetContext(); | |||
| InnerSetContext(); | |||
| ret = rtStreamCreate(&stream_, 0); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]"; | |||
| @@ -580,7 +596,7 @@ bool AscendKernelRuntime::InitDevice() { | |||
| } | |||
| bool AscendKernelRuntime::ResetDevice() { | |||
| SetContext(); | |||
| InnerSetContext(); | |||
| if (stream_ != nullptr) { | |||
| auto ret = rtStreamDestroy(stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| @@ -49,6 +49,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||
| const std::vector<CNodePtr> &execution_order) override; | |||
| void ClearGlobalIdleMem() override; | |||
| bool SyncStream() override; | |||
| void SetContext() override; | |||
| protected: | |||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| @@ -62,7 +63,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||
| bool HcclInit(); | |||
| bool NeedDestroyHccl(); | |||
| bool DestroyHccl(); | |||
| void SetContext(); | |||
| void InnerSetContext(); | |||
| void ClearGraphModelMap(); | |||
| void ReleaseDeviceRes() override; | |||
| @@ -147,7 +147,6 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k | |||
| } | |||
| auto address = AnfAlgo::GetMutableOutputAddr(node, index); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| TypeId infer_type_id = AnfAlgo::GetOutputInferDataType(node, index); | |||
| TypeId device_type_id = AnfAlgo::GetOutputDeviceDataType(node, index); | |||
| tensor::TensorPtr tensor = kernel_graph->GetInternalOutputTensor(node, index); | |||
| @@ -161,8 +160,8 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k | |||
| kernel_graph->AddInternalOutputTensor(node, index, tensor); | |||
| } | |||
| } | |||
| tensor->set_device_address(address); | |||
| if (bound_addresses_.find(address) != bound_addresses_.end()) { | |||
| tensor->set_device_address(address); | |||
| tensor->set_sync_status(kNeedSyncDeviceToHostImmediately); | |||
| } else { | |||
| if (infer_type_id != device_type_id) { | |||
| @@ -170,17 +169,13 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k | |||
| ShapeVector data_shape = tensor->shape(); | |||
| size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>()); | |||
| address->ptr_ = resource_manager_.MemMalloc(tensor_size); | |||
| tensor->set_device_address(address); | |||
| tensor->set_sync_status(kNeedSyncDeviceToHostImmediately); | |||
| } else { | |||
| tensor->set_device_address(nullptr); | |||
| address->ptr_ = tensor->data_c(); | |||
| tensor->set_sync_status(kNoNeedSync); | |||
| } | |||
| address->ref_count_ = INIT_NODE_REF; | |||
| (void)bound_addresses_.insert(address); | |||
| } | |||
| tensor->SetNeedWait(true); | |||
| return tensor; | |||
| } | |||
| @@ -214,11 +209,11 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_grap | |||
| } | |||
| return BaseRef(); | |||
| } | |||
| void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) { | |||
| void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph, | |||
| const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| // bind input ptr | |||
| auto &input_nodes = kernel_graph->inputs(); | |||
| if (input_nodes.size() != inputs.size()) { | |||
| MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; | |||
| @@ -228,6 +223,27 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const | |||
| for (auto &item : input_nodes) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| input_param_tensor_map_[item] = inputs[input_idx]; | |||
| input_idx++; | |||
| } | |||
| bound_addresses_.clear(); | |||
| auto output_nodes = kernel_graph->outputs(); | |||
| for (const auto &item : output_nodes) { | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); | |||
| auto out = CreatTensorForOutput(kernel_graph, item_with_index); | |||
| outputs->push_back(std::move(out)); | |||
| } | |||
| } | |||
| void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &kernel_graph, | |||
| const std::vector<tensor::TensorPtr> &inputs) { | |||
| auto &input_nodes = kernel_graph.inputs(); | |||
| if (input_nodes.size() != inputs.size()) { | |||
| MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; | |||
| } | |||
| size_t input_idx = 0; | |||
| for (auto &item : input_nodes) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| if (item->isa<Parameter>()) { | |||
| auto address = AnfAlgo::GetMutableOutputAddr(item, 0); | |||
| auto tensor = inputs[input_idx]; | |||
| @@ -235,7 +251,8 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| if (tensor_address != nullptr && tensor_address != address) { | |||
| (void)tensor->data_sync(); | |||
| tensor_address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c()); | |||
| } | |||
| if (tensor->data_type() == address->type_id_ || tensor->data_type() == kNumberTypeFloat32 || | |||
| tensor->data_type() == kNumberTypeInt32 || tensor->data_type() == kNumberTypeInt64) { | |||
| @@ -255,16 +272,37 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const | |||
| } | |||
| input_idx++; | |||
| } | |||
| // new output and bind ptr | |||
| bound_addresses_.clear(); | |||
| auto output_nodes = kernel_graph->outputs(); | |||
| for (const auto &item : output_nodes) { | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); | |||
| auto out = CreatTensorForOutput(kernel_graph, item_with_index); | |||
| outputs->push_back(std::move(out)); | |||
| } | |||
| void CPUKernelRuntime::BindOutputTensorAddressPtr(const VectorRef *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| for (auto item : *outputs) { | |||
| if (utils::isa<VectorRefPtr>(item)) { | |||
| auto vector_ref = utils::cast<VectorRef>(item); | |||
| BindOutputTensorAddressPtr(&vector_ref); | |||
| } else if (utils::isa<tensor::TensorPtr>(item)) { | |||
| auto tensor = utils::cast<tensor::TensorPtr>(item); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto address = tensor->device_address(); | |||
| if (address == nullptr) { | |||
| continue; | |||
| } | |||
| auto address_ptr = std::dynamic_pointer_cast<device::DeviceAddress>(address); | |||
| if (tensor->sync_status() == kNoNeedSync) { | |||
| address_ptr->ptr_ = tensor->data_c(); | |||
| } | |||
| address_ptr->ref_count_ = INIT_NODE_REF; | |||
| } | |||
| } | |||
| } | |||
| void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| BindInputTensorAddressPtr(*kernel_graph, inputs); | |||
| BindOutputTensorAddressPtr(outputs); | |||
| } | |||
| void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector<kernel::AddressPtr> *input_list) { | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| MS_EXCEPTION_IF_NULL(input_list); | |||
| @@ -38,6 +38,8 @@ class CPUKernelRuntime : public KernelRuntime { | |||
| bool Init() override { return true; } | |||
| bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override; | |||
| void AssignKernelAddress(session::KernelGraph *kernel_graph); | |||
| void CreateOutputTensors(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs); | |||
| void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs); | |||
| void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | |||
| @@ -50,8 +52,9 @@ class CPUKernelRuntime : public KernelRuntime { | |||
| private: | |||
| tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index); | |||
| BaseRef CreatTensorForOutput(session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index); | |||
| void BindInputTensorAddressPtr(const session::KernelGraph &graph, const std::vector<tensor::TensorPtr> &inputs); | |||
| void BindOutputTensorAddressPtr(const VectorRef *outputs); | |||
| void AssignValueNodeAddress(session::KernelGraph *kernel_graph); | |||
| void AssignInputNodeAddress(const session::KernelGraph *kernel_graph); | |||
| void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph); | |||
| @@ -72,6 +72,7 @@ class KernelRuntime { | |||
| const std::vector<CNodePtr> &execution_order); | |||
| virtual bool SyncStream() = 0; | |||
| virtual void ClearGlobalIdleMem() {} | |||
| virtual void SetContext() {} | |||
| // for GPU and D to impl | |||
| virtual void ReleaseDeviceRes() {} | |||
| @@ -39,6 +39,17 @@ py::object BuiltinsToPyData(const BaseRef &value); | |||
| py::object VectorToPyData(const Any &value); | |||
| py::object VectorRefToPyData(const VectorRef &value); | |||
| py::object TensorToPyData(const tensor::TensorPtr &tensor) { | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| if (tensor->NeedWait()) { | |||
| py::gil_scoped_release release; | |||
| tensor->Wait(); | |||
| } | |||
| py::tuple v(1); | |||
| v[0] = tensor; | |||
| return v[0]; | |||
| } | |||
| py::object ValuePtrToPyData(const ValuePtr &value) { | |||
| if (value == nullptr) { | |||
| MS_LOG(EXCEPTION) << "value is null"; | |||
| @@ -94,9 +105,8 @@ py::object ValuePtrToPyData(const ValuePtr &value) { | |||
| ret = v; | |||
| } else if (value->isa<tensor::Tensor>()) { | |||
| MS_LOG(DEBUG) << "tensor"; | |||
| py::tuple v(1); | |||
| v[0] = value->cast<tensor::TensorPtr>(); | |||
| ret = v[0]; | |||
| auto tensor_ptr = value->cast<tensor::TensorPtr>(); | |||
| ret = TensorToPyData(tensor_ptr); | |||
| } else if (value->isa<tensor::MetaTensor>()) { | |||
| MS_LOG(DEBUG) << "MetaTensor"; | |||
| py::tuple v(1); | |||
| @@ -166,9 +176,8 @@ py::object AnyToPyData(const Any &value) { | |||
| ret = ValuePtrToPyData(v); | |||
| } else if (value.is<tensor::TensorPtr>()) { | |||
| MS_LOG(DEBUG) << "tensor"; | |||
| py::tuple v(1); | |||
| v[0] = value.cast<tensor::TensorPtr>(); | |||
| ret = v[0]; | |||
| auto tensor_ptr = value.cast<tensor::TensorPtr>(); | |||
| ret = TensorToPyData(tensor_ptr); | |||
| } else if (value.is<py::object>()) { | |||
| MS_LOG(DEBUG) << "py obj"; | |||
| ret = value.cast<py::object>(); | |||
| @@ -210,9 +219,8 @@ py::object BaseRefToPyData(const BaseRef &value) { | |||
| ret = ValuePtrToPyData(v); | |||
| } else if (utils::isa<tensor::TensorPtr>(value)) { | |||
| MS_LOG(DEBUG) << "tensor"; | |||
| py::tuple v(1); | |||
| v[0] = utils::cast<tensor::TensorPtr>(value); | |||
| ret = v[0]; | |||
| auto tensor_ptr = utils::cast<tensor::TensorPtr>(value); | |||
| ret = TensorToPyData(tensor_ptr); | |||
| } else if (utils::isa<PyObjectRef>(value)) { | |||
| MS_LOG(DEBUG) << "py obj"; | |||
| PyObjectRef py_ref = utils::cast<PyObjectRef>(value); | |||
| @@ -55,9 +55,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri | |||
| GraphId graph_id = kInvalidGraphId; | |||
| if (target != target_device_ && !target.empty()) { | |||
| CreateOtherSession(target); | |||
| graph_id = other_sess_->CompileGraphAsync(lst, outputs); | |||
| graph_id = other_sess_->CompileGraph(lst, outputs); | |||
| } else { | |||
| graph_id = target_sess_->CompileGraphAsync(lst, outputs); | |||
| graph_id = target_sess_->CompileGraph(lst, outputs); | |||
| } | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) { | |||
| @@ -65,9 +65,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri | |||
| return result; | |||
| } | |||
| if (target != target_device_ && !target.empty()) { | |||
| other_sess_->BuildGraphAsync(graph_id); | |||
| other_sess_->BuildGraph(graph_id); | |||
| } else if (!is_multi_graph_sink_) { | |||
| target_sess_->BuildGraphAsync(graph_id); | |||
| target_sess_->BuildGraph(graph_id); | |||
| } | |||
| result.run = std::make_shared<RunFunc>( | |||
| [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); }); | |||
| @@ -151,7 +151,7 @@ void MsBackend::Link(GraphId graph_id) { | |||
| if (graph_id == kInvalidGraphId) { | |||
| graph_id = target_sess_->GetFinalRunGraph(); | |||
| } | |||
| target_sess_->BuildGraphAsync(graph_id); | |||
| target_sess_->BuildGraph(graph_id); | |||
| } | |||
| Backend::Backend(const std::string &name) : name_(name) { | |||
| @@ -187,7 +187,7 @@ void MsBackend::CreateOtherSession(const std::string &target) { | |||
| other_device_ = target; | |||
| } | |||
| GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraphAsync(fg); } | |||
| GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); } | |||
| VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); } | |||
| @@ -541,7 +541,7 @@ std::string Tensor::ToStringInternal(int limit_size) const { | |||
| std::ostringstream buf; | |||
| auto dtype = Dtype(); | |||
| MS_EXCEPTION_IF_NULL(dtype); | |||
| data_sync(); | |||
| data_sync(false); | |||
| buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", value="; | |||
| if (limit_size <= 0 || DataSize() < limit_size) { | |||
| // Only print data for small tensor. | |||
| @@ -567,14 +567,16 @@ std::string Tensor::ToStringRepr() const { | |||
| std::ostringstream buf; | |||
| auto dtype = Dtype(); | |||
| MS_EXCEPTION_IF_NULL(dtype); | |||
| data_sync(); | |||
| data_sync(false); | |||
| buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() | |||
| << ", value=" << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, true) << ')'; | |||
| return buf.str(); | |||
| } | |||
| void Tensor::data_sync() const { | |||
| Wait(); | |||
| void Tensor::data_sync(bool need_wait) const { | |||
| if (need_wait) { | |||
| Wait(); | |||
| } | |||
| if (device_sync_ == nullptr) { | |||
| return; | |||
| } | |||
| @@ -229,8 +229,8 @@ class Tensor : public MetaTensor { | |||
| void *data_c() const { return data_->data(); } | |||
| // brief Sync data with device. | |||
| void data_sync() const; | |||
| // brief Sync data with device, need wait data valid. | |||
| void data_sync(bool need_wait = true) const; | |||
| // brief Get the internal data object. | |||
| // | |||