From 2739b6d5e59283a41b2c88d4811f973c9b17e91a Mon Sep 17 00:00:00 2001 From: buxue Date: Mon, 28 Dec 2020 10:34:55 +0800 Subject: [PATCH] fix the bug for op when return a single element tuple in PyNative mode --- .../pipeline/pynative/pynative_execute.cc | 38 +++++++++---------- .../pipeline/pynative/pynative_execute.h | 4 +- mindspore/ops/primitive.py | 4 -- 3 files changed, 20 insertions(+), 26 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 0f539cc6e9..b9499a0378 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -523,7 +523,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); } -py::tuple RunOp(const py::args &args) { +py::object RunOp(const py::args &args) { auto executor = PynativeExecutor::GetInstance(); MS_EXCEPTION_IF_NULL(executor); OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args); @@ -555,10 +555,10 @@ py::tuple RunOp(const py::args &args) { } } -py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { +py::object PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { MS_EXCEPTION_IF_NULL(op_exec_info); if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { - return RunOpWithInitBackendPolicy(op_exec_info); + return RunOpWithInitBackendPolicy(op_exec_info)[0]; } // make cnode for building grad graph if grad flag is set. abstract::AbstractBasePtrList args_spec_list; @@ -574,14 +574,10 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { MS_EXCEPTION_IF_NULL(prim); py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); if (!output["value"].is_none()) { - py::tuple value_ret(1); - value_ret[0] = output["value"]; - return value_ret; + return output["value"]; } if (prim->is_const_prim()) { - py::tuple value_ret(1); - value_ret[0] = ""; - return value_ret; + return py::cast(""); } // add output abstract info into cache if (!is_find && !op_exec_info->is_dynamic_shape) { @@ -593,10 +589,12 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { } // run op with selected backend auto result = RunOpWithInitBackendPolicy(op_exec_info); - py::object out_real = result; - if (result.size() == 1) { - MS_LOG(DEBUG) << "Output size is 1"; + py::object out_real; + if (result.size() == 1 && op_exec_info->abstract != nullptr && + !op_exec_info->abstract->isa()) { out_real = result[0]; + } else { + out_real = result; } // update output abstract for cnode if (cnode != nullptr) { @@ -609,7 +607,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { SaveAllResult(op_exec_info, cnode, out_real); // Update the abstract and device address of value node with tensor in grad graph UpdateAbstractAndDeviceAddress(op_exec_info, out_real); - return result; + return out_real; } OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { @@ -788,7 +786,7 @@ py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &typ op_exec->is_mixed_precision_cast = true; op_exec->next_op_name = op_name; op_exec->next_input_index = index; - return RunOpInner(op_exec)[0]; + return RunOpInner(op_exec); } py::object PynativeExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name, @@ -1249,11 +1247,9 @@ py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_e auto backend_policy = InitEnv(op_exec_info); PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE; // returns a null py::tuple on error - py::tuple err_ret(0); py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status); if (status != PYNATIVE_SUCCESS) { - MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name; - return err_ret; + MS_LOG(EXCEPTION) << "Failed to run " << op_exec_info->op_name; } MS_LOG(DEBUG) << "RunOp end"; @@ -1361,16 +1357,18 @@ py::object PynativeExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, Pynati auto primitive = op_exec_info->py_primitive; MS_EXCEPTION_IF_NULL(primitive); auto result = primitive->RunPyComputeFunction(op_inputs); + MS_LOG(INFO) << "RunOpInVM end"; if (py::isinstance(result)) { MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func"; *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; py::tuple err_ret(0); return std::move(err_ret); } - // execute op - py::tuple tuple_result = py::make_tuple(result); *status = PYNATIVE_SUCCESS; - MS_LOG(INFO) << "RunOpInVM end"; + if (py::isinstance(result)) { + return result; + } + py::tuple tuple_result = py::make_tuple(result); return std::move(tuple_result); } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index b74a851096..12ead9dc35 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -55,7 +55,7 @@ using AbstractListMap = std::unordered_map>; using TensorIdWithTensor = std::unordered_map>; -py::tuple RunOp(const py::args &args); +py::object RunOp(const py::args &args); void ClearPyNativeSession(); @@ -114,7 +114,7 @@ class PynativeExecutor : public std::enable_shared_from_this { void EnterConstruct(const py::object &cell); void LeaveConstruct(const py::object &cell); - py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); + py::object RunOpInner(const OpExecInfoPtr &op_exec_info); OpExecInfoPtr GenerateOpExecInfo(const py::args &args); void NewGraph(const py::object &cell, const py::args &args); py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase); diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 0c84074f6f..096b4baf30 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -510,8 +510,4 @@ def constexpr(fn=None, get_instance=True, name=None): def _run_op(obj, op_name, args): """Single op execution function supported by ge in PyNative mode.""" output = real_run_op(obj, op_name, args) - if not output: - raise RuntimeError("Pynative run op %s failed!" % op_name) - if len(output) == 1: - output = output[0] return output