From 5d1e8b766db99aac0de03977274f9f2a763de6f0 Mon Sep 17 00:00:00 2001 From: zhaozhixuan Date: Tue, 13 Apr 2021 10:44:39 +0800 Subject: [PATCH] fix bug. --- ge/graph/execute/graph_execute.cc | 15 +++++++-------- ge/graph/execute/graph_execute.h | 3 +++ ge/graph/manager/graph_manager.cc | 8 ++++---- ge/session/inner_session.cc | 8 ++++---- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/ge/graph/execute/graph_execute.cc b/ge/graph/execute/graph_execute.cc index 8dddca06..53562bc8 100755 --- a/ge/graph/execute/graph_execute.cc +++ b/ge/graph/execute/graph_execute.cc @@ -404,22 +404,21 @@ void GraphExecutor::GetInputOutputData(const std::vector &input_tensor std::vector &output_tensor, InputData &inputs, OutputData &outputs) { - graph_input_data.index = 0; - graph_input_data.timeout = 0; - graph_input_data.timestamp = 0; - - for (const auto &tensor : input_tensor) { + inputs.index = 0; + inputs.timeout = 0; + inputs.timestamp = 0; + for (auto &tensor : input_tensor) { DataBuffer in_data_buf; - in_data_buf.data = reinterpret_cast(tensor.GetData().data()); + in_data_buf.data = const_cast(tensor.GetData().data()); in_data_buf.length = tensor.GetData().size(); in_data_buf.isDataSupportMemShare = false; inputs.blobs.emplace_back(in_data_buf); } outputs.index = 0; - for (const auto &tensor : output_tensor) { + for (auto &tensor : output_tensor) { DataBuffer out_data_buf; - out_data_buf.data = reinterpret_cast(tensor.GetData().data()); + out_data_buf.data = const_cast(tensor.GetData().data()); out_data_buf.length = tensor.GetData().size(); out_data_buf.isDataSupportMemShare = false; outputs.blobs.emplace_back(out_data_buf); diff --git a/ge/graph/execute/graph_execute.h b/ge/graph/execute/graph_execute.h index 14c4c858..ec9ac7be 100755 --- a/ge/graph/execute/graph_execute.h +++ b/ge/graph/execute/graph_execute.h @@ -126,6 +126,9 @@ class GraphExecutor { Status PrepareInputData(const std::vector &input_tensor, InputData &graph_input_data, OutputData &graph_output_data, std::vector &output_desc); + void GetInputOutputData(const std::vector &input_tensor, std::vector &output_tensor, + InputData &inputs, OutputData &outputs); + Status SyncExecuteModel(uint32_t model_id, const std::vector &input_tensor, std::vector &output_tensor); diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 02a02fa8..eaaaf862 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -1128,8 +1128,8 @@ Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &grap return SUCCESS; } -Status RunGraphWithStreamAsync(const GraphId &graph_id, const std::vector &inputs, - std::vector &outputs, rtStream_t stream, uint64_t session_id) { +Status GraphManager::RunGraphWithStreamAsync(const GraphId &graph_id, const std::vector &inputs, + std::vector &outputs, rtStream_t stream, uint64_t session_id) { ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther); std::lock_guard lock(run_mutex_); GELOGI("[RunGraphWithStreamAsync] start to run graph, graph_id = %u, is_train_graph: %d", graph_id, GetTrainFlag()); @@ -1173,8 +1173,8 @@ Status RunGraphWithStreamAsync(const GraphId &graph_id, const std::vectorGetGeRootModel(), - inputs, outputs, stream); + ret = graph_executor_.ExecuteGraphWithStream(graph_id, graph_node->GetGeRootModel(), + inputs, outputs, stream); graph_node->SetRunFlag(false); graph_node->SetIsSpecificStream(false); if (ret != SUCCESS) { diff --git a/ge/session/inner_session.cc b/ge/session/inner_session.cc index 49e3a079..30b2e29d 100755 --- a/ge/session/inner_session.cc +++ b/ge/session/inner_session.cc @@ -272,15 +272,15 @@ Status InnerSession::RunGraphWithStreamAsync(uint32_t graph_id, const std::vecto return GE_SESS_INIT_FAILED; } UpdateThreadContext(graph_id); - vector geInputs; + vector ge_inputs; for (auto &item : inputs) { - geInputs.emplace_back(TensorAdapter::AsGeTensor(item)); + ge_inputs.emplace_back(TensorAdapter::AsGeTensor(item)); } - vector geOutputs; + vector ge_outputs; for (auto &item : outputs) { ge_outputs.emplace_back(TensorAdapter::AsGeTensor(item)); } - Status ret = graph_manager_.RunGraphWithStreamAsync(graph_id, geInputs, geOutputs, stream, session_id_); + Status ret = graph_manager_.RunGraphWithStreamAsync(graph_id, ge_inputs, ge_outputs, stream, session_id_); domi::GetContext().out_nodes_map.clear(); domi::GetContext().user_out_nodes.clear(); if (ret != SUCCESS) {