|
|
|
@@ -17,6 +17,7 @@ import numpy as np |
|
|
|
from mindspore.ops import functional as F |
|
|
|
from mindspore.ops import composite as C |
|
|
|
from mindspore.common.tensor import Tensor |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -115,6 +116,7 @@ def bool_or(x, y): |
|
|
|
"""Implement `bool_or`.""" |
|
|
|
return x or y |
|
|
|
|
|
|
|
|
|
|
|
def vm_compare(*args): |
|
|
|
"""Implement `vm_compare` for tensor.""" |
|
|
|
obj_str = args[-1] |
|
|
|
@@ -143,10 +145,12 @@ def list_len(x): |
|
|
|
"""Implement `list_len`.""" |
|
|
|
return len(x) |
|
|
|
|
|
|
|
|
|
|
|
def Depend(value, expr): |
|
|
|
"""Implement `Depend`.""" |
|
|
|
return value |
|
|
|
|
|
|
|
|
|
|
|
# only used in PyNative mode |
|
|
|
def make_ref(key, value, ref): |
|
|
|
return value |
|
|
|
@@ -177,8 +181,12 @@ def stop_gradient(x): |
|
|
|
|
|
|
|
hyper_map = C.HyperMap() |
|
|
|
|
|
|
|
|
|
|
|
def mixed_precision_cast(dst_type, x): |
|
|
|
"""Implement `mixed_precision_cast`.""" |
|
|
|
def cast_inner(data): |
|
|
|
return F.cast(data, dst_type) |
|
|
|
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16): |
|
|
|
return F.cast(data, dst_type) |
|
|
|
return data |
|
|
|
|
|
|
|
return hyper_map(cast_inner, x) |