Browse Source

!3738 fix bug of cast dtype when using mix_presion in pynative mode

Merge pull request !3738 from jinyaohui/master
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
2883f9366d
2 changed files with 18 additions and 7 deletions
  1. +9
    -1
      mindspore/_extends/builtin_operations.py
  2. +9
    -6
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

+ 9
- 1
mindspore/_extends/builtin_operations.py View File

@@ -17,6 +17,7 @@ import numpy as np
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype


@@ -115,6 +116,7 @@ def bool_or(x, y):
"""Implement `bool_or`."""
return x or y


def vm_compare(*args):
"""Implement `vm_compare` for tensor."""
obj_str = args[-1]
@@ -143,10 +145,12 @@ def list_len(x):
"""Implement `list_len`."""
return len(x)


def Depend(value, expr):
"""Implement `Depend`."""
return value


# only used in PyNative mode
def make_ref(key, value, ref):
return value
@@ -177,8 +181,12 @@ def stop_gradient(x):

hyper_map = C.HyperMap()


def mixed_precision_cast(dst_type, x):
"""Implement `mixed_precision_cast`."""
def cast_inner(data):
return F.cast(data, dst_type)
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16):
return F.cast(data, dst_type)
return data

return hyper_map(cast_inner, x)

+ 9
- 6
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -745,13 +745,16 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
return err_ret;
}

auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
if (cnode != nullptr) {
cnode->set_abstract(op_exec_info->abstract);
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
if (op_exec_info->op_name != prim::kPrimMixedPrecisionCast->name()) {
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
if (cnode != nullptr) {
cnode->set_abstract(op_exec_info->abstract);
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
}
PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result);
MS_LOG(DEBUG) << "RunOp end";
}
PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result);
MS_LOG(DEBUG) << "RunOp end";

return result;
}



Loading…
Cancel
Save