Browse Source

add float64 cast for pynative auto mixed precision

tags/v1.1.0
LianLiguang 5 years ago
parent
commit
c6c0a45c8a
1 changed files with 4 additions and 1 deletions
  1. +4
    -1
      mindspore/nn/cell.py

+ 4
- 1
mindspore/nn/cell.py View File

@@ -263,11 +263,14 @@ class Cell(Cell_):
self._attr_synced = False

def _cast_mixed_precision_inputs(self, inputs, dst_type):
"""Cast input for mixed precision"""
res = list()
for item in inputs:
if isinstance(item, tuple):
res.append(self._cast_mixed_precision_inputs(item, dst_type))
elif item.dtype in {mstype.float16, mstype.float32}:
elif isinstance(item, float):
res.append(cast(item, dst_type))
elif hasattr(item, "dtype") in {mstype.float16, mstype.float32, mstype.float64}:
res.append(cast(item, dst_type))
else:
res.append(item)


Loading…
Cancel
Save