Browse Source

optimize updateoutput in gpu

tags/v0.7.0-beta
chujinjin 5 years ago
parent
commit
1cb8d9daf3
3 changed files with 20 additions and 4 deletions
  1. +6
    -1
      mindspore/ccsrc/backend/session/ascend_session.cc
  2. +12
    -1
      mindspore/ccsrc/backend/session/gpu_session.cc
  3. +2
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

+ 6
- 1
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -426,7 +426,12 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors;
TensorValueToTensor(op_run_info.value, &pre_output_tensors);
std::copy(pre_output_tensors.begin(), pre_output_tensors.end(), std::back_inserter(outputs));
for (auto &pre_output : pre_output_tensors) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false);
outputs.emplace_back(tensor);
}
} else {
UpdateOutputs(graph, &outputs, input_tensors);
}


+ 12
- 1
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -300,7 +300,18 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
}
// Fetch outputs
VectorRef outputs;
UpdateOutputs(kernel_graph, &outputs, input_tensors);
if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors;
TensorValueToTensor(op_run_info.value, &pre_output_tensors);
for (auto &pre_output : pre_output_tensors) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
tensor->set_device_address(pre_output->device_address());
tensor->set_dirty(false);
outputs.emplace_back(tensor);
}
} else {
UpdateOutputs(kernel_graph, &outputs, input_tensors);
}
// Trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(outputs);
if (!utils::isa<PyObjectRef>(output_tensors) ||


+ 2
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -565,9 +565,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat

if (session == nullptr) {
session = session::SessionFactory::Get().Create(device_target);
MS_EXCEPTION_IF_NULL(session);
session->Init(ms_context->device_id());
}
MS_EXCEPTION_IF_NULL(session);
session->Init(ms_context->device_id());

std::vector<tensor::TensorPtr> input_tensors;
std::vector<int> tensors_mask;


Loading…
Cancel
Save