From 39c251c015c07d2e0a0e3b97fb87cfd19c65ca68 Mon Sep 17 00:00:00 2001 From: jinyaohui Date: Fri, 31 Jul 2020 14:59:49 +0800 Subject: [PATCH] fix bug of cast dtype when using mix_presion in pynative mode --- mindspore/_extends/builtin_operations.py | 10 +++++++++- .../ccsrc/pipeline/pynative/pynative_execute.cc | 15 +++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py index fc0498f342..1eade2d86d 100644 --- a/mindspore/_extends/builtin_operations.py +++ b/mindspore/_extends/builtin_operations.py @@ -17,6 +17,7 @@ import numpy as np from mindspore.ops import functional as F from mindspore.ops import composite as C from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype @@ -115,6 +116,7 @@ def bool_or(x, y): """Implement `bool_or`.""" return x or y + def vm_compare(*args): """Implement `vm_compare` for tensor.""" obj_str = args[-1] @@ -143,10 +145,12 @@ def list_len(x): """Implement `list_len`.""" return len(x) + def Depend(value, expr): """Implement `Depend`.""" return value + # only used in PyNative mode def make_ref(key, value, ref): return value @@ -177,8 +181,12 @@ def stop_gradient(x): hyper_map = C.HyperMap() + def mixed_precision_cast(dst_type, x): """Implement `mixed_precision_cast`.""" def cast_inner(data): - return F.cast(data, dst_type) + if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16): + return F.cast(data, dst_type) + return data + return hyper_map(cast_inner, x) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 34e637cef2..97e9e47a16 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -745,13 +745,16 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) { return err_ret; } - auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result); - if (cnode != nullptr) { - cnode->set_abstract(op_exec_info->abstract); - MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString(); + if (op_exec_info->op_name != prim::kPrimMixedPrecisionCast->name()) { + auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result); + if (cnode != nullptr) { + cnode->set_abstract(op_exec_info->abstract); + MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString(); + } + PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result); + MS_LOG(DEBUG) << "RunOp end"; } - PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result); - MS_LOG(DEBUG) << "RunOp end"; + return result; }