From b59f378243487a5df3e999b7ad2a74129a989b61 Mon Sep 17 00:00:00 2001 From: ysk Date: Mon, 8 Mar 2021 11:07:00 +0800 Subject: [PATCH] stream --- ge/client/ge_api.cc | 30 +++++ ge/graph/execute/graph_execute.cc | 126 ++++++++++++++---- ge/graph/execute/graph_execute.h | 10 ++ ge/graph/load/graph_loader.cc | 5 +- ge/graph/manager/graph_manager.cc | 98 ++++++++++++++ ge/graph/manager/graph_manager.h | 17 +++ ge/graph/manager/graph_manager_utils.cc | 1 + ge/graph/manager/graph_manager_utils.h | 3 + ge/model/ge_root_model.h | 3 + ge/session/inner_session.cc | 35 +++++ ge/session/inner_session.h | 3 + ge/session/session_manager.cc | 23 ++++ ge/session/session_manager.h | 14 ++ inc/external/ge/ge_api.h | 12 ++ .../error_manager/src/error_manager_stub.cc | 4 + tests/depends/slog/src/slog_stub.cc | 4 + tests/ut/ge/CMakeLists.txt | 2 + tests/ut/ge/graph/ge_executor_unittest.cc | 102 ++++++++++++++ .../ut/ge/graph/manager/run_graph_unittest.cc | 61 +++++++++ tests/ut/ge/session/ge_api_unittest.cc | 59 ++++++++ 20 files changed, 584 insertions(+), 28 deletions(-) create mode 100644 tests/ut/ge/graph/manager/run_graph_unittest.cc create mode 100644 tests/ut/ge/session/ge_api_unittest.cc diff --git a/ge/client/ge_api.cc b/ge/client/ge_api.cc index f0cf9e03..b6381710 100644 --- a/ge/client/ge_api.cc +++ b/ge/client/ge_api.cc @@ -529,6 +529,36 @@ Status Session::RunGraph(uint32_t graph_id, const std::vector &inputs, s return ret; } +Status Session::RunGraphWithStreamAsync(uint32_t graph_id, const std::vector &inputs, + std::vector &outputs, void *stream) { + GELOGT(TRACE_INIT, "Session RunGraphWithStreamAsync start"); + + ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); + std::vector graph_inputs = inputs; + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Session RunGraph failed"); + return FAILED; + } + GELOGT(TRACE_RUNNING, "Running Graph"); + Status ret = instance_ptr->SessionManagerObj().RunGraphWithStreamAsync(sessionId_, graph_id, + graph_inputs, outputs, stream); + // check return status + if (ret != SUCCESS) { + GELOGE(ret, "Session RunGraph failed"); + return FAILED; + } + + // print output + if (!outputs.empty()) { + PrintOutputResult(outputs); + } + + // return + GELOGT(TRACE_STOP, "Session RunGraph finished"); + return ret; +} + Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) { ErrorManager::GetInstance().GenWorkStreamIdDefault(); return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); diff --git a/ge/graph/execute/graph_execute.cc b/ge/graph/execute/graph_execute.cc index 1aee756c..dd7e50ea 100755 --- a/ge/graph/execute/graph_execute.cc +++ b/ge/graph/execute/graph_execute.cc @@ -278,33 +278,10 @@ Status GraphExecutor::SyncExecuteModel(uint32_t model_id, const std::vector outBufTmp(new (std::nothrow) uint8_t[outputDataTmp.length]); - if (outBufTmp == nullptr) { - GELOGE(FAILED, "Failed to allocate memory."); - return FAILED; - } - GE_PRINT_DYNAMIC_MEMORY(new, "the output memory of data on training.", sizeof(uint8_t) * outputDataTmp.length) - rtError_t ret_value = rtMemcpy(outBufTmp.get(), outputDataTmp.length, outputDataTmp.data, outputDataTmp.length, - RT_MEMCPY_HOST_TO_HOST); - CHECK_FALSE_EXEC(ret_value == RT_ERROR_NONE, - GELOGE(GE_GRAPH_EXECUTE_FAILED, "Call rt api rtMemcpy failed, ret: 0x%X", ret); - return GE_GRAPH_EXECUTE_FAILED); - GeTensor outTensor; - 
std::vector shapeDims; - for (const auto &dim : output_desc[i].shape_info.dims) { - shapeDims.push_back(dim); - } - - GeShape outShape(shapeDims); - outTensor.MutableTensorDesc().SetShape(outShape); - outTensor.MutableTensorDesc().SetDataType((DataType)output_desc[i].data_type); - (void)outTensor.SetData(outBufTmp.get(), outputDataTmp.length); - output_tensor.push_back(outTensor); + ret = ProcessOutputData(output_data, output_desc, output_tensor); + if (ret != SUCCESS) { + return ret; + } GELOGI("[GraphExecutor] execute model success, modelId=%u.", model_id); @@ -378,6 +355,67 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & return SUCCESS; } +Status GraphExecutor::ExecuteGraphWithStream(GraphId graph_id, + const GeRootModelPtr &ge_root_model, + const std::vector &input_tensor, + std::vector &output_tensor, + rtStream_t stream) { + GELOGI("[GraphExecutor] Start to execute graph with stream, graph_id=%u", graph_id); + if (graph_id != last_graph_id_) { + auto ret = FreeExecuteMemory(); + if (ret != SUCCESS) { + return ret; + } + } + last_graph_id_ = graph_id; + + if (!init_flag_) { + GELOGE(GE_GRAPH_EXECUTE_NOT_INIT, "[GraphExecutor] AI Core Engine without calling SetCondition!"); + return GE_GRAPH_EXECUTE_NOT_INIT; + } + GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); + auto model_manager = ge::ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + auto model_id = ge_root_model->GetModelId(); + if (model_manager->IsDynamicShape(model_id)) { + GELOGI("[ExecuteGraphWithStream] GetInputOutputDescInfo via dynamic shape model executor, modelId=%u", model_id); + return model_manager->SyncExecuteModel(model_id, input_tensor, output_tensor); + } + + std::vector inputs_desc; + std::vector output_desc; + + GELOGI("[ExecuteGraph] GetInputOutputDescInfo via new ome begin."); + Status ret = GetInputOutputDescInfo(model_id, inputs_desc, output_desc); + if (ret != SUCCESS) { + GELOGE(GE_GRAPH_GET_IN_OUT_FAILED, "[GraphExecutor] GetInputOutputDescInfo failed, modelId=%u.", model_id); + return GE_GRAPH_GET_IN_OUT_FAILED; + } + outputs_desc_.assign(output_desc.begin(), output_desc.end()); + + InputData input_data; + OutputData output_data; + input_data.model_id = model_id; + ret = PrepareInputData(input_tensor, input_data, output_data, output_desc); + if (ret != SUCCESS) { + GELOGE(GE_GRAPH_PREPARE_FAILED, "[GraphExecutor] PrepareInputData failed, modelId=%u.", model_id); + return GE_GRAPH_PREPARE_FAILED; + } + + auto async_mode = true; + std::vector input_ge_desc; + std::vector output_ge_desc; + model_manager->ExecuteModel(model_id, stream, async_mode, input_data, input_ge_desc, output_data, output_ge_desc); + + ret = ProcessOutputData(output_data, output_desc, output_tensor); + if (ret != SUCCESS) { + return ret; + } + GELOGI("[GraphExecutor] execute model success, modelId=%u.", model_id); + + return SUCCESS; +} + Status GraphExecutor::AsyncExecuteModel(uint32_t model_id, const std::vector &inputs) { try { auto model_manager = ge::ModelManager::GetInstance(); @@ -404,6 +442,40 @@ Status GraphExecutor::AsyncExecuteModel(uint32_t model_id, const std::vector &output_desc, + std::vector &output_tensor) { + for (size_t i = 0; i < output_data.blobs.size(); ++i) { + DataBuffer outputDataTmp = output_data.blobs[i]; + CHECK_FALSE_EXEC(outputDataTmp.length != 0, + GELOGE(GE_GRAPH_EXECUTE_FAILED, "Failed to allocate memory, length is 0."); + return GE_GRAPH_EXECUTE_FAILED); + std::unique_ptr outBufTmp(new (std::nothrow) uint8_t[outputDataTmp.length]); + if (outBufTmp 
== nullptr) { + GELOGE(FAILED, "Failed to allocate memory."); + return FAILED; + } + GE_PRINT_DYNAMIC_MEMORY(new, "the output memory of data on training.", sizeof(uint8_t) * outputDataTmp.length) + rtError_t ret_value = rtMemcpy(outBufTmp.get(), outputDataTmp.length, outputDataTmp.data, outputDataTmp.length, + RT_MEMCPY_HOST_TO_HOST); + CHECK_FALSE_EXEC(ret_value == RT_ERROR_NONE, + GELOGE(GE_GRAPH_EXECUTE_FAILED, "Call rt api rtMemcpy failed, ret: 0x%X", ret_value); + return GE_GRAPH_EXECUTE_FAILED); + GeTensor outTensor; + std::vector shapeDims; + for (const auto &dim : output_desc[i].shape_info.dims) { + shapeDims.push_back(dim); + } + + GeShape outShape(shapeDims); + outTensor.MutableTensorDesc().SetShape(outShape); + outTensor.MutableTensorDesc().SetDataType((DataType)output_desc[i].data_type); + (void)outTensor.SetData(outBufTmp.get(), outputDataTmp.length); + output_tensor.push_back(outTensor); + } + return SUCCESS; +} + Status GraphExecutor::DataInput(const InputData &input_data, OutputData &output_data) { try { auto model_manager = ge::ModelManager::GetInstance(); diff --git a/ge/graph/execute/graph_execute.h b/ge/graph/execute/graph_execute.h index d2a92e47..137a6d2e 100755 --- a/ge/graph/execute/graph_execute.h +++ b/ge/graph/execute/graph_execute.h @@ -52,6 +52,12 @@ class GraphExecutor { ge::Status ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, const std::vector &input_tensor); + Status ExecuteGraphWithStream(GraphId graph_id, + const GeRootModelPtr &ge_root_model, + const std::vector &input_tensor, + std::vector &output_tensor, + rtStream_t stream); + Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr listener); Status SetGraphContext(GraphContextPtr graph_context_ptr); @@ -128,6 +134,10 @@ class GraphExecutor { void InitModelIdInfo(std::vector &out_model_id_info, std::vector &sub_graph_vec, uint32_t output_size); + Status ProcessOutputData(const OutputData &output_data, + const std::vector &output_desc, + std::vector &output_tensor); + Status FreeInOutBuffer(); Status MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr); diff --git a/ge/graph/load/graph_loader.cc b/ge/graph/load/graph_loader.cc index 644880ce..d4619889 100755 --- a/ge/graph/load/graph_loader.cc +++ b/ge/graph/load/graph_loader.cc @@ -75,7 +75,10 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptrCheckIsSpecificStream()) { + GELOGI("No need to start a new thread to run model in specific scene"); + return SUCCESS; + } ret = model_manager->Start(model_id); if (ret != SUCCESS) { if (model_manager->Unload(model_id) != SUCCESS) { diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 1cbb3fc8..7bc158ca 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -930,6 +930,7 @@ Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphN GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)); } } + ge_root_model->SetIsSpecificStream(graph_node->CheckIsSpecificStream()); GE_TIMESTAMP_START(LoadGraph); Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, model_listener); GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraph"); @@ -1053,6 +1054,44 @@ Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &grap return SUCCESS; } +Status GraphManager::InnerRunGraphWithStream(GraphNodePtr &graph_node, const GraphId &graph_id, + const std::vector &inputs, std::vector 
&outputs, + rtStream_t stream, ComputeGraphPtr &compute_graph_tmp) { + if (GetTrainFlag()) { + GE_CHK_STATUS_RET(graph_executor_.SetGraphContext(GetGraphContext())); + graph_executor_.SetTrainFlag(options_.train_graph_flag); + } + auto ret = graph_executor_.ExecuteGraphWithStream(graph_id, graph_node->GetGeRootModel(), + inputs, outputs, stream); + graph_node->SetRunFlag(false); + graph_node->SetIsSpecificStream(false); + if (ret != SUCCESS) { + GELOGE(ret, "[RunGraphWithStreamAsync] execute graph failed, graph_id = %u.", graph_id); + return ret; + } + + if (GetTrainFlag()) { + if (compute_graph_tmp->IsSummaryGraph()) { + ret = SummaryHandle(graph_id, outputs); + if (ret != SUCCESS) { + GELOGE(ret, "[RunGraphWithStreamAsync] SummaryHandle failed!"); + } + } + GeRootModelPtr root_model = graph_node->GetGeRootModel(); + if (root_model != nullptr) { + GELOGI("Start CheckpointHandle."); + auto checkPointGraph = root_model->GetRootGraph(); + if (IsCheckpointGraph(checkPointGraph)) { + ret = CheckpointHandle(graph_id, checkPointGraph, outputs); + if (ret != SUCCESS) { + GELOGE(ret, "[RunGraphWithStreamAsync] CheckpointHandle failed!"); + } + } + } + } + return ret; +} + Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector &inputs, std::vector &outputs, uint64_t session_id) { ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther); @@ -1140,6 +1179,65 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector &inputs, + std::vector &outputs, rtStream_t stream, + uint64_t session_id) { + std::lock_guard lock(run_mutex_); + GELOGI("[RunGraphWithStreamAsync] start to run graph, graph_id = %u, is_train_graph: %d", graph_id, GetTrainFlag()); + + if (inputs.empty()) { + GELOGI("[RunGraphWithStreamAsync] initialize sub graph has no inputs"); + } + + // find graph + GraphNodePtr graph_node = nullptr; + Status ret = GetGraphNode(graph_id, graph_node); + if (ret != SUCCESS) { + GELOGE(ret, "[RunGraphWithStreamAsync] graph not exist, graph_id = %u.", graph_id); + return ret; + } + if (graph_node == nullptr) { + GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "[RunGraphWithStreamAsync] graph node is NULL, graph_id = %u.", graph_id); + return GE_GRAPH_GRAPH_NODE_NULL; + } + if (graph_node->GetRunFlag()) { + GELOGE(GE_GRAPH_ALREADY_RUNNING, "[RunGraphWithStreamAsync] graph already running, graph id = %u", graph_id); + return GE_GRAPH_ALREADY_RUNNING; + } + + UpdateLocalOmgContext(graph_id); + + // set graph's run flag + graph_node->SetRunFlag(true); + graph_node->SetIsSpecificStream(true); + ComputeGraphPtr compute_graph_tmp = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); + GE_IF_BOOL_EXEC(GetTrainFlag(), + GE_IF_BOOL_EXEC(compute_graph_tmp == nullptr, + GELOGE(GE_GRAPH_GRAPH_NODE_NULL, + "[RunGraphWithStreamAsync] compute_graph is NULL, graph id = %u.", graph_id); + return GE_GRAPH_GRAPH_NODE_NULL;)) + + // when set incre build, add cache helper map + AddModelCacheHelperToMap(graph_id, session_id, compute_graph_tmp); + if (options_.local_fmk_op_flag) { + GetCompilerStages(graph_id).optimizer.TranFrameOp(compute_graph_tmp); + } + GeRootModelPtr ge_root_model = nullptr; + ret = StartForRunGraph(graph_node, inputs, ge_root_model, session_id); + if (ret != SUCCESS) { + GELOGE(ret, "[RunGraphWithStreamAsync] StartForRunGraph failed!"); + graph_node->SetRunFlag(false); + return ret; + } + ret = InnerRunGraphWithStream(graph_node, graph_id, inputs, outputs, stream, compute_graph_tmp); + if (ret != SUCCESS) { + GELOGE(ret, "[InnerRunGraphWithStream] 
RunGraph failed!"); + return ret; + } + GELOGI("[RunGraphWithStreamAsync] run graph success, graph_id = %u.", graph_id); + return SUCCESS; +} + Status GraphManager::GenerateInfershapeGraph(GraphId &graph_id) { GELOGI("[DumpInfershapeJson] start to DumpInfershapeJson graph, graph_id=%u.", graph_id); // find graph diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index 661cf9d8..98e9b6bf 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -103,6 +103,19 @@ class GraphManager { Status RunGraph(const GraphId &graph_id, const std::vector &inputs, std::vector &outputs, uint64_t session_id = INVALID_SESSION_ID); + /// + /// @ingroup ge_graph + /// @brief run specific graph with specific stream + /// @param [in] graph_id graph id + /// @param [in] inputs input data + /// @param [in] stream specific stream + /// @param [out] outputs output data + /// @return Status result of function + /// + Status RunGraphWithStreamAsync(const GraphId &graph_id, const std::vector &inputs, + std::vector &outputs, rtStream_t stream, + uint64_t session_id = INVALID_SESSION_ID); + /// /// @ingroup ge_graph /// @brief build specific graph @@ -243,6 +256,10 @@ class GraphManager { Status InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, const std::vector &inputs, std::vector &outputs); + + Status InnerRunGraphWithStream(GraphNodePtr &graph_node, const GraphId &graph_id, + const std::vector &inputs, std::vector &outputs, + rtStream_t stream, ComputeGraphPtr &compute_graph_tmp); Status ParseOptions(const std::map &options); diff --git a/ge/graph/manager/graph_manager_utils.cc b/ge/graph/manager/graph_manager_utils.cc index fe7e5b34..f8575112 100644 --- a/ge/graph/manager/graph_manager_utils.cc +++ b/ge/graph/manager/graph_manager_utils.cc @@ -41,6 +41,7 @@ GraphNode::GraphNode(GraphId graph_id) build_flag_(false), load_flag_(false), async_(false), + is_specific_stream_(false), ge_model_(nullptr), sem_(1) { graph_run_async_listener_ = MakeShared(); diff --git a/ge/graph/manager/graph_manager_utils.h b/ge/graph/manager/graph_manager_utils.h index de65c5cb..dc7b9a32 100644 --- a/ge/graph/manager/graph_manager_utils.h +++ b/ge/graph/manager/graph_manager_utils.h @@ -164,6 +164,8 @@ class GraphNode { bool GetLoadFlag() const { return load_flag_; } void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; } void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; } + void SetIsSpecificStream(bool specific_stream) { is_specific_stream_ = specific_stream; } + bool CheckIsSpecificStream() { return is_specific_stream_; } GeModelPtr GetGeModel() const { return ge_model_; } void SetGeRootModel(const GeRootModelPtr &ge_root_model) { ge_root_model_ = ge_root_model; } GeRootModelPtr GetGeRootModel() const { return ge_root_model_; } @@ -186,6 +188,7 @@ class GraphNode { bool build_flag_; bool load_flag_; bool async_; + bool is_specific_stream_; GeModelPtr ge_model_; GeRootModelPtr ge_root_model_; BlockingQueue sem_; diff --git a/ge/model/ge_root_model.h b/ge/model/ge_root_model.h index aa5a4d47..3f98c091 100755 --- a/ge/model/ge_root_model.h +++ b/ge/model/ge_root_model.h @@ -34,6 +34,8 @@ class GeRootModel { const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }; void SetModelId(uint32_t model_id) { model_id_ = model_id; } + void SetIsSpecificStream(bool is_specific_stream) { is_specific_stream_ = is_specific_stream; } + bool CheckIsSpecificStream() {return is_specific_stream_; } uint32_t GetModelId() const { return model_id_; } 
Status CheckIsUnknownShape(bool &is_dynamic_shape); void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } @@ -41,6 +43,7 @@ class GeRootModel { ComputeGraphPtr root_graph_ = nullptr; std::map subgraph_instance_name_to_model_; uint32_t model_id_ = 0; + bool is_specific_stream_ = false; }; } // namespace ge using GeRootModelPtr = std::shared_ptr; diff --git a/ge/session/inner_session.cc b/ge/session/inner_session.cc index d11ba10e..d4f61109 100755 --- a/ge/session/inner_session.cc +++ b/ge/session/inner_session.cc @@ -236,6 +236,41 @@ Status InnerSession::RunGraph(uint32_t graph_id, const std::vector &inpu } } +Status InnerSession::RunGraphWithStreamAsync(uint32_t graph_id, const std::vector &inputs, + std::vector &outputs, rtStream_t stream) { + GELOGI("[InnerSession:%lu] run graph with stream async on session, graph_id=%u.", session_id_, graph_id); + if (mutex_.try_lock()) { + std::lock_guard lock(mutex_, std::adopt_lock); + if (!init_flag_) { + GELOGE(GE_SESS_INIT_FAILED, "[InnerSession:%lu] initialize failed.", session_id_); + return GE_SESS_INIT_FAILED; + } + UpdateThreadContext(graph_id); + vector geInputs; + for (auto &item : inputs) { + geInputs.push_back(TensorAdapter::AsGeTensor(item)); + } + vector geOutputs; + Status ret = graph_manager_.RunGraphWithStreamAsync(graph_id, geInputs, geOutputs, stream, session_id_); + domi::GetContext().out_nodes_map.clear(); + domi::GetContext().user_out_nodes.clear(); + if (ret != SUCCESS) { + GELOGE(ret, "[InnerSession:%lu] run graph failed, graph_id=%u.", session_id_, graph_id); + return ret; + } + outputs.clear(); + for (auto &item : geOutputs) { + outputs.push_back(TensorAdapter::AsTensor(item)); + } + + GELOGI("[InnerSession:%lu] run graph success, graph_id=%u.", session_id_, graph_id); + return SUCCESS; + } else { + GELOGE(GE_SESS_ALREADY_RUNNING, "[InnerSession:%lu] run graph failed, graph_id=%u.", session_id_, graph_id); + return GE_SESS_ALREADY_RUNNING; + } +} + Status InnerSession::RemoveGraph(uint32_t graph_id) { std::lock_guard lock(resource_mutex_); if (!init_flag_) { diff --git a/ge/session/inner_session.h b/ge/session/inner_session.h index 5cab43d8..e0732763 100644 --- a/ge/session/inner_session.h +++ b/ge/session/inner_session.h @@ -41,6 +41,9 @@ class InnerSession { Status RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector &outputs); + Status RunGraphWithStreamAsync(uint32_t graph_id, const std::vector &inputs, + std::vector &outputs, rtStream_t stream); + Status RemoveGraph(uint32_t graph_id); Status BuildGraph(uint32_t graph_id, const std::vector &inputs); diff --git a/ge/session/session_manager.cc b/ge/session/session_manager.cc index 3c531747..6aed59eb 100755 --- a/ge/session/session_manager.cc +++ b/ge/session/session_manager.cc @@ -219,6 +219,29 @@ Status SessionManager::RunGraph(SessionId session_id, uint32_t graph_id, const s return innerSession->RunGraph(graph_id, inputs, outputs); } +Status SessionManager::RunGraphWithStreamAsync(SessionId session_id, + uint32_t graph_id, + const std::vector &inputs, + std::vector &outputs, + rtStream_t stream) { + if (!init_flag_) { + GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized."); + return GE_SESSION_MANAGER_NOT_INIT; + } + SessionPtr innerSession = nullptr; + { + std::lock_guard lock(mutex_); + std::map::iterator it = session_manager_map_.find(session_id); + if (it == session_manager_map_.end()) { + return GE_SESSION_NOT_EXIST; + } else { + innerSession = it->second; + } + } + return 
innerSession->RunGraphWithStreamAsync(graph_id, inputs, + outputs, stream); +} + Status SessionManager::RemoveGraph(SessionId session_id, uint32_t graph_id) { if (!init_flag_) { GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized."); diff --git a/ge/session/session_manager.h b/ge/session/session_manager.h index da23219c..63deb028 100644 --- a/ge/session/session_manager.h +++ b/ge/session/session_manager.h @@ -24,6 +24,7 @@ #include #include "common/ge_inner_error_codes.h" #include "ge/ge_api_types.h" +#include "runtime/base.h" #include "session/inner_session.h" namespace ge { @@ -96,6 +97,19 @@ class SessionManager { Status RunGraph(SessionId session_id, uint32_t graph_id, const std::vector &inputs, std::vector &outputs); + /// + /// @ingroup ge_session + /// @brief run a graph of the session with specific stream + /// @param [in] session_id session id + /// @param [in] graph_id graph id + /// @param [in] inputs input data + /// @param [in] stream specific stream + /// @param [out] outputs output data + /// @return Status result of function + /// + Status RunGraphWithStreamAsync(SessionId session_id, uint32_t graph_id, const std::vector &inputs, + std::vector &outputs, rtStream_t stream); + /// /// @ingroup ge_session /// @brief remove a graph from the session with specific session id diff --git a/inc/external/ge/ge_api.h b/inc/external/ge/ge_api.h index c8b5a8ec..89786e0c 100644 --- a/inc/external/ge/ge_api.h +++ b/inc/external/ge/ge_api.h @@ -121,6 +121,18 @@ class GE_FUNC_VISIBILITY Session { /// Status RunGraph(uint32_t graphId, const std::vector &inputs, std::vector &outputs); + /// + /// @ingroup ge_graph + /// @brief run a graph of the session with specific session id + /// @param [in] graphId graph id + /// @param [in] inputs input data + /// @param [in] stream specific streams + /// @param [out] outputs output data + /// @return Status result of function + /// + Status RunGraphWithStreamAsync(uint32_t graphId, const std::vector &inputs, std::vector &outputs, + void *stream); + /// /// @ingroup ge_graph /// @brief build graph in the session with specific session id diff --git a/tests/depends/error_manager/src/error_manager_stub.cc b/tests/depends/error_manager/src/error_manager_stub.cc index eadc8687..5d305d75 100644 --- a/tests/depends/error_manager/src/error_manager_stub.cc +++ b/tests/depends/error_manager/src/error_manager_stub.cc @@ -94,3 +94,7 @@ using namespace ErrorMessage; void ErrorManager::SetErrorContext(struct Context error_context) {} void ErrorManager::SetStage(const std::string &first_stage, const std::string &second_stage) {} + +std::string ErrorManager::GetErrorMessage() {} + +std::string ErrorManager::GetWarningMessage() {} diff --git a/tests/depends/slog/src/slog_stub.cc b/tests/depends/slog/src/slog_stub.cc index edc245b4..5e70b87c 100644 --- a/tests/depends/slog/src/slog_stub.cc +++ b/tests/depends/slog/src/slog_stub.cc @@ -15,6 +15,7 @@ */ #include "toolchain/slog.h" +#include "toolchain/plog.h" #include #include @@ -46,3 +47,6 @@ int CheckLogLevel(int moduleId, int logLevel) { return 1; } + +DLL_EXPORT int DlogReportInitialize() {return 1;} +DLL_EXPORT int DlogReportFinalize() {return 1;} diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 91b756cc..e9121e95 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -755,8 +755,10 @@ set(MULTI_PARTS_TEST_FILES "graph/build/mem_assigner_unittest.cc" "graph/preprocess/graph_preprocess_unittest.cc" "graph/manager/hcom_util_unittest.cc" + 
"graph/manager/run_graph_unittest.cc" "graph/manager/graph_caching_allocator_unittest.cc" "session/omg_omg_unittest.cc" + "session/ge_api_unittest.cc" ) set(GENERATOR_TEST_FILES diff --git a/tests/ut/ge/graph/ge_executor_unittest.cc b/tests/ut/ge/graph/ge_executor_unittest.cc index e26aa86e..16fd6d08 100644 --- a/tests/ut/ge/graph/ge_executor_unittest.cc +++ b/tests/ut/ge/graph/ge_executor_unittest.cc @@ -33,6 +33,8 @@ #include "common/properties_manager.h" #include "common/types.h" #include "graph/load/graph_loader.h" +#include "graph/execute/graph_execute.h" +#include "common/profiling/profiling_manager.h" #include "graph/load/model_manager/davinci_model.h" #include "graph/load/model_manager/model_manager.h" #include "graph/load/model_manager/task_info/kernel_task_info.h" @@ -190,4 +192,104 @@ TEST_F(UtestGeExecutor, kernel_ex_InitDumpTask) { kernel_ex_task_info.davinci_model_ = &model; kernel_ex_task_info.InitDumpTask(nullptr, op_desc); } + +TEST_F(UtestGeExecutor, execute_graph_with_stream) { + DavinciModel model(0, nullptr); + ComputeGraphPtr graph = make_shared("default"); + ProfilingManager::Instance().is_load_profiling_ = true; + + GeModelPtr ge_model = make_shared(); + ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); + AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, 10240); + AttrUtils::SetInt(ge_model, ATTR_MODEL_STREAM_NUM, 1); + + shared_ptr model_task_def = make_shared(); + ge_model->SetModelTaskDef(model_task_def); + + GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + TensorUtils::SetSize(tensor, 512); + { + OpDescPtr op_desc = CreateOpDesc("data", DATA); + op_desc->AddInputDesc(tensor); + op_desc->AddOutputDesc(tensor); + op_desc->SetInputOffset({1024}); + op_desc->SetOutputOffset({1024}); + NodePtr node = graph->AddNode(op_desc); // op_index = 0 + } + + { + OpDescPtr op_desc = CreateOpDesc("square", "Square"); + op_desc->AddInputDesc(tensor); + op_desc->AddOutputDesc(tensor); + op_desc->SetInputOffset({1024}); + op_desc->SetOutputOffset({1024}); + NodePtr node = graph->AddNode(op_desc); // op_index = 1 + + domi::TaskDef *task_def = model_task_def->add_task(); + task_def->set_stream_id(0); + task_def->set_type(RT_MODEL_TASK_KERNEL); + domi::KernelDef *kernel_def = task_def->mutable_kernel(); + kernel_def->set_stub_func("stub_func"); + kernel_def->set_args_size(64); + string args(64, '1'); + kernel_def->set_args(args.data(), 64); + domi::KernelContext *context = kernel_def->mutable_context(); + context->set_op_index(op_desc->GetId()); + context->set_kernel_type(2); // ccKernelType::TE + uint16_t args_offset[9] = {0}; + context->set_args_offset(args_offset, 9 * sizeof(uint16_t)); + } + + { + OpDescPtr op_desc = CreateOpDesc("memcpy", MEMCPYASYNC); + op_desc->AddInputDesc(tensor); + op_desc->AddOutputDesc(tensor); + op_desc->SetInputOffset({1024}); + op_desc->SetOutputOffset({5120}); + NodePtr node = graph->AddNode(op_desc); // op_index = 2 + + domi::TaskDef *task_def = model_task_def->add_task(); + task_def->set_stream_id(0); + task_def->set_type(RT_MODEL_TASK_MEMCPY_ASYNC); + domi::MemcpyAsyncDef *memcpy_async = task_def->mutable_memcpy_async(); + memcpy_async->set_src(1024); + memcpy_async->set_dst(5120); + memcpy_async->set_dst_max(512); + memcpy_async->set_count(1); + memcpy_async->set_kind(RT_MEMCPY_DEVICE_TO_DEVICE); + memcpy_async->set_op_index(op_desc->GetId()); + } + + { + OpDescPtr op_desc = CreateOpDesc("output", NETOUTPUT); + op_desc->AddInputDesc(tensor); + op_desc->SetInputOffset({5120}); + op_desc->SetSrcName( { "memcpy" } ); + 
op_desc->SetSrcIndex( { 0 } ); + NodePtr node = graph->AddNode(op_desc); // op_index = 3 + } + + EXPECT_EQ(model.Assign(ge_model), SUCCESS); + EXPECT_EQ(model.Init(), SUCCESS); + + EXPECT_EQ(model.input_addrs_list_.size(), 1); + EXPECT_EQ(model.output_addrs_list_.size(), 1); + EXPECT_EQ(model.task_list_.size(), 2); + + OutputData output_data; + vector outputs; + EXPECT_EQ(model.GenOutputTensorInfo(&output_data, outputs), SUCCESS); + + + GraphExecutor graph_executer; + graph_executer.init_flag_ = true; + GeRootModelPtr ge_root_model = make_shared(graph); + std::vector input_tensor; + std::vector output_tensor; + std::vector output_desc; + InputOutputDescInfo desc0; + output_desc.push_back(desc0); + graph_executer.ProcessOutputData(output_data, output_desc, output_tensor); + graph_executer.ExecuteGraphWithStream(0, ge_root_model, input_tensor, output_tensor, nullptr); +} } \ No newline at end of file diff --git a/tests/ut/ge/graph/manager/run_graph_unittest.cc b/tests/ut/ge/graph/manager/run_graph_unittest.cc new file mode 100644 index 00000000..4f1c13de --- /dev/null +++ b/tests/ut/ge/graph/manager/run_graph_unittest.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "graph/anchor.h" +#include "graph/attr_value.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "omg/omg_inner_types.h" + +#define protected public +#define private public +#include"graph/manager/graph_manager_utils.h" +#include "graph/manager/graph_manager.h" +#undef protected +#undef private + +using namespace std; +using namespace testing; +using namespace ge; +using domi::GetContext; + +class UtestGraphRunTest : public testing::Test { + protected: + void SetUp() {} + + void TearDown() { GetContext().out_nodes_map.clear(); } +}; + +TEST_F(UtestGraphRunTest, RunGraphWithStreamAsync) { + GraphManager graph_manager; + GeTensor input0, input1; + std::vector inputs{input0, input1}; + std::vector outputs; + GraphNodePtr graph_node = std::make_shared(1); + graph_manager.AddGraphNode(1, graph_node); + GraphPtr graph = std::make_shared("test"); + graph_node->SetGraph(graph); + graph_node->SetRunFlag(false); + graph_node->SetBuildFlag(true); + auto ret = graph_manager.RunGraphWithStreamAsync(1, inputs, outputs, nullptr, 0); + +} \ No newline at end of file diff --git a/tests/ut/ge/session/ge_api_unittest.cc b/tests/ut/ge/session/ge_api_unittest.cc new file mode 100644 index 00000000..124a7d54 --- /dev/null +++ b/tests/ut/ge/session/ge_api_unittest.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include + +#define protected public +#define private public +#include "common/ge/ge_util.h" +#include "proto/ge_ir.pb.h" +#include "inc/external/ge/ge_api.h" +#include "session/session_manager.h" +#undef protected +#undef private + + +using namespace std; + +namespace ge { +class UtestGeApi : public testing::Test { + protected: + void SetUp() override {} + + void TearDown() override {} +}; + +TEST_F(UtestGeApi, run_graph_with_stream) { + vector inputs; + vector outputs; + std::map options; + Session session(options); + auto ret = session.RunGraphWithStreamAsync(10, inputs, outputs, nullptr); + ASSERT_NE(ret, SUCCESS); + SessionManager session_manager; + session_manager.init_flag_ = true; + ret = session_manager.RunGraphWithStreamAsync(10, 10, inputs, outputs, nullptr); + ASSERT_NE(ret, SUCCESS); + InnerSession inner_session(1, options); + inner_session.init_flag_ = true; + ret = inner_session.RunGraphWithStreamAsync(10, inputs, outputs, nullptr); + ASSERT_NE(ret, SUCCESS); +} +} // namespace ge
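
Note (not part of the patch): a minimal caller-side sketch of the Session::RunGraphWithStreamAsync API this change introduces. The wrapper name RunOnUserStream, the include paths, and the use of the runtime stream helpers rtStreamCreate / rtStreamSynchronize / rtStreamDestroy are illustrative assumptions; error handling is trimmed.

// Usage sketch (illustrative): run an added graph on a caller-managed stream.
#include <vector>

#include "ge/ge_api.h"     // ge::Session, ge::Tensor, ge::Status (assumed external include path)
#include "runtime/rt.h"    // rtStreamCreate / rtStreamSynchronize / rtStreamDestroy (assumed header name)

ge::Status RunOnUserStream(ge::Session &session, uint32_t graph_id,
                           const std::vector<ge::Tensor> &inputs,
                           std::vector<ge::Tensor> &outputs) {
  rtStream_t stream = nullptr;
  if (rtStreamCreate(&stream, 0) != RT_ERROR_NONE) {  // priority 0
    return ge::FAILED;
  }
  // Dispatch the graph onto the caller-provided stream; on the GE side this goes
  // through SessionManager / InnerSession into GraphManager::RunGraphWithStreamAsync
  // and GraphExecutor::ExecuteGraphWithStream added by this patch.
  ge::Status ret = session.RunGraphWithStreamAsync(graph_id, inputs, outputs, stream);
  if (ret == ge::SUCCESS) {
    // The device work is queued on `stream`; synchronize before consuming the
    // host-side output tensors.
    (void)rtStreamSynchronize(stream);
  }
  (void)rtStreamDestroy(stream);
  return ret;
}

Because GraphLoader::LoadModelOnline now skips model_manager->Start() when the specific-stream flag is set, there is no dedicated run thread for this path; waiting on the caller-owned stream appears to be the intended way to observe completion.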