Browse Source

add float64 of mixed_precision_cast

tags/v1.1.0
chenfei 5 years ago
parent
commit
369ee9ef9f
1 changed files with 3 additions and 1 deletions
  1. +3
    -1
      mindspore/_extends/builtin_operations.py

+ 3
- 1
mindspore/_extends/builtin_operations.py View File

@@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype



def scalar_add(x, y): def scalar_add(x, y):
"""Implement `scalar_add`.""" """Implement `scalar_add`."""
return x + y return x + y
@@ -164,8 +165,9 @@ hyper_map = C.HyperMap()


def mixed_precision_cast(dst_type, x): def mixed_precision_cast(dst_type, x):
"""Implement `mixed_precision_cast`.""" """Implement `mixed_precision_cast`."""

def cast_inner(data): def cast_inner(data):
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16):
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16, mstype.float64):
return F.cast(data, dst_type) return F.cast(data, dst_type)
return data return data




Loading…
Cancel
Save