| @@ -426,7 +426,12 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr | |||||
| if (op_run_info.value != nullptr) { | if (op_run_info.value != nullptr) { | ||||
| std::vector<tensor::TensorPtr> pre_output_tensors; | std::vector<tensor::TensorPtr> pre_output_tensors; | ||||
| TensorValueToTensor(op_run_info.value, &pre_output_tensors); | TensorValueToTensor(op_run_info.value, &pre_output_tensors); | ||||
| std::copy(pre_output_tensors.begin(), pre_output_tensors.end(), std::back_inserter(outputs)); | |||||
| for (auto &pre_output : pre_output_tensors) { | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape()); | |||||
| tensor->set_device_address(pre_output->device_address()); | |||||
| tensor->set_dirty(false); | |||||
| outputs.emplace_back(tensor); | |||||
| } | |||||
| } else { | } else { | ||||
| UpdateOutputs(graph, &outputs, input_tensors); | UpdateOutputs(graph, &outputs, input_tensors); | ||||
| } | } | ||||
| @@ -300,7 +300,18 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph | |||||
| } | } | ||||
| // Fetch outputs | // Fetch outputs | ||||
| VectorRef outputs; | VectorRef outputs; | ||||
| UpdateOutputs(kernel_graph, &outputs, input_tensors); | |||||
| if (op_run_info.value != nullptr) { | |||||
| std::vector<tensor::TensorPtr> pre_output_tensors; | |||||
| TensorValueToTensor(op_run_info.value, &pre_output_tensors); | |||||
| for (auto &pre_output : pre_output_tensors) { | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape()); | |||||
| tensor->set_device_address(pre_output->device_address()); | |||||
| tensor->set_dirty(false); | |||||
| outputs.emplace_back(tensor); | |||||
| } | |||||
| } else { | |||||
| UpdateOutputs(kernel_graph, &outputs, input_tensors); | |||||
| } | |||||
| // Trans output to tuple | // Trans output to tuple | ||||
| auto output_tensors = TransformBaseRefListToTuple(outputs); | auto output_tensors = TransformBaseRefListToTuple(outputs); | ||||
| if (!utils::isa<PyObjectRef>(output_tensors) || | if (!utils::isa<PyObjectRef>(output_tensors) || | ||||
| @@ -565,9 +565,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||||
| if (session == nullptr) { | if (session == nullptr) { | ||||
| session = session::SessionFactory::Get().Create(device_target); | session = session::SessionFactory::Get().Create(device_target); | ||||
| MS_EXCEPTION_IF_NULL(session); | |||||
| session->Init(ms_context->device_id()); | |||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(session); | |||||
| session->Init(ms_context->device_id()); | |||||
| std::vector<tensor::TensorPtr> input_tensors; | std::vector<tensor::TensorPtr> input_tensors; | ||||
| std::vector<int> tensors_mask; | std::vector<int> tensors_mask; | ||||