|
|
|
@@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype |
|
|
|
|
|
|
|
|
|
|
|
def scalar_add(x, y): |
|
|
|
"""Implement `scalar_add`.""" |
|
|
|
return x + y |
|
|
|
@@ -164,8 +165,9 @@ hyper_map = C.HyperMap() |
|
|
|
|
|
|
|
def mixed_precision_cast(dst_type, x): |
|
|
|
"""Implement `mixed_precision_cast`.""" |
|
|
|
|
|
|
|
def cast_inner(data): |
|
|
|
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16): |
|
|
|
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16, mstype.float64): |
|
|
|
return F.cast(data, dst_type) |
|
|
|
return data |
|
|
|
|
|
|
|
|