| @@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype | from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype | ||||
| def scalar_add(x, y): | def scalar_add(x, y): | ||||
| """Implement `scalar_add`.""" | """Implement `scalar_add`.""" | ||||
| return x + y | return x + y | ||||
| @@ -164,8 +165,9 @@ hyper_map = C.HyperMap() | |||||
| def mixed_precision_cast(dst_type, x): | def mixed_precision_cast(dst_type, x): | ||||
| """Implement `mixed_precision_cast`.""" | """Implement `mixed_precision_cast`.""" | ||||
| def cast_inner(data): | def cast_inner(data): | ||||
| if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16): | |||||
| if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16, mstype.float64): | |||||
| return F.cast(data, dst_type) | return F.cast(data, dst_type) | ||||
| return data | return data | ||||