Browse Source

fix bug of mixed precision

tags/v1.1.0
LianLiguang 5 years ago
parent
commit
85bd04ba05
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      mindspore/nn/cell.py

+ 1
- 1
mindspore/nn/cell.py View File

@@ -270,7 +270,7 @@ class Cell(Cell_):
res.append(self._cast_mixed_precision_inputs(item, dst_type))
elif isinstance(item, float):
res.append(cast(item, dst_type))
elif hasattr(item, "dtype") in {mstype.float16, mstype.float32, mstype.float64}:
elif hasattr(item, "dtype") and item.dtype in {mstype.float16, mstype.float32, mstype.float64}:
res.append(cast(item, dst_type))
else:
res.append(item)


Loading…
Cancel
Save