Browse Source

Fix compilation and correctness issues in RunGraphWithStreamAsync: qualify the definition with GraphManager::, reuse the existing `ret` instead of redeclaring it with `auto`, replace `reinterpret_cast` with `const_cast` when stripping constness from tensor data, declare GraphExecutor::GetInputOutputData in the header, and rename geInputs/geOutputs to ge_inputs/ge_outputs.

pull/1506/head
zhaozhixuan 4 years ago
parent
commit
5d1e8b766d
4 changed files with 18 additions and 16 deletions
  1. +7
    -8
      ge/graph/execute/graph_execute.cc
  2. +3
    -0
      ge/graph/execute/graph_execute.h
  3. +4
    -4
      ge/graph/manager/graph_manager.cc
  4. +4
    -4
      ge/session/inner_session.cc

+ 7
- 8
ge/graph/execute/graph_execute.cc View File

@@ -404,22 +404,21 @@ void GraphExecutor::GetInputOutputData(const std::vector<GeTensor> &input_tensor
                                        std::vector<GeTensor> &output_tensor,
                                        InputData &inputs,
                                        OutputData &outputs) {
-  graph_input_data.index = 0;
-  graph_input_data.timeout = 0;
-  graph_input_data.timestamp = 0;
-
-  for (const auto &tensor : input_tensor) {
+  inputs.index = 0;
+  inputs.timeout = 0;
+  inputs.timestamp = 0;
+  for (auto &tensor : input_tensor) {
     DataBuffer in_data_buf;
-    in_data_buf.data = reinterpret_cast<uint8_t *>(tensor.GetData().data());
+    in_data_buf.data = const_cast<uint8_t *>(tensor.GetData().data());
     in_data_buf.length = tensor.GetData().size();
     in_data_buf.isDataSupportMemShare = false;
     inputs.blobs.emplace_back(in_data_buf);
   }
 
   outputs.index = 0;
-  for (const auto &tensor : output_tensor) {
+  for (auto &tensor : output_tensor) {
     DataBuffer out_data_buf;
-    out_data_buf.data = reinterpret_cast<uint8_t *>(tensor.GetData().data());
+    out_data_buf.data = const_cast<uint8_t *>(tensor.GetData().data());
     out_data_buf.length = tensor.GetData().size();
     out_data_buf.isDataSupportMemShare = false;
     outputs.blobs.emplace_back(out_data_buf);


+ 3
- 0
ge/graph/execute/graph_execute.h View File

@@ -126,6 +126,9 @@ class GraphExecutor {
   Status PrepareInputData(const std::vector<GeTensor> &input_tensor, InputData &graph_input_data,
                           OutputData &graph_output_data, std::vector<InputOutputDescInfo> &output_desc);
 
+  void GetInputOutputData(const std::vector<GeTensor> &input_tensor, std::vector<GeTensor> &output_tensor,
+                          InputData &inputs, OutputData &outputs);
+
   Status SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor,
                           std::vector<GeTensor> &output_tensor);
 


+ 4
- 4
ge/graph/manager/graph_manager.cc View File

@@ -1128,8 +1128,8 @@ Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &grap
   return SUCCESS;
 }
 
-Status RunGraphWithStreamAsync(const GraphId &graph_id, const std::vector<GeTensor> &inputs,
-                               std::vector<GeTensor> &outputs, rtStream_t stream, uint64_t session_id) {
+Status GraphManager::RunGraphWithStreamAsync(const GraphId &graph_id, const std::vector<GeTensor> &inputs,
+                                             std::vector<GeTensor> &outputs, rtStream_t stream, uint64_t session_id) {
   ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
   std::lock_guard<std::mutex> lock(run_mutex_);
   GELOGI("[RunGraphWithStreamAsync] start to run graph, graph_id = %u, is_train_graph: %d", graph_id, GetTrainFlag());
@@ -1173,8 +1173,8 @@ Status RunGraphWithStreamAsync(const GraphId &graph_id, const std::vector<GeTens
     return ret;
   }
 
-  auto ret = graph_executor_.ExecuteGraphWithStream(graph_id, graph_node->GetGeRootModel(),
-                                                    inputs, outputs, stream);
+  ret = graph_executor_.ExecuteGraphWithStream(graph_id, graph_node->GetGeRootModel(),
+                                               inputs, outputs, stream);
   graph_node->SetRunFlag(false);
   graph_node->SetIsSpecificStream(false);
   if (ret != SUCCESS) {


+ 4
- 4
ge/session/inner_session.cc View File

@@ -272,15 +272,15 @@ Status InnerSession::RunGraphWithStreamAsync(uint32_t graph_id, const std::vecto
     return GE_SESS_INIT_FAILED;
  }
   UpdateThreadContext(graph_id);
-  vector<GeTensor> geInputs;
+  vector<GeTensor> ge_inputs;
   for (auto &item : inputs) {
-    geInputs.emplace_back(TensorAdapter::AsGeTensor(item));
+    ge_inputs.emplace_back(TensorAdapter::AsGeTensor(item));
   }
-  vector<GeTensor> geOutputs;
+  vector<GeTensor> ge_outputs;
   for (auto &item : outputs) {
     ge_outputs.emplace_back(TensorAdapter::AsGeTensor(item));
   }
-  Status ret = graph_manager_.RunGraphWithStreamAsync(graph_id, geInputs, geOutputs, stream, session_id_);
+  Status ret = graph_manager_.RunGraphWithStreamAsync(graph_id, ge_inputs, ge_outputs, stream, session_id_);
   domi::GetContext().out_nodes_map.clear();
   domi::GetContext().user_out_nodes.clear();
   if (ret != SUCCESS) {


Loading…
Cancel
Save