Browse Source

!7901 fix bug of pynative's mixprecision

Merge pull request !7901 from lianliguang/master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b27d16c0a5
1 changed files with 13 additions and 2 deletions
  1. +13
    -2
      mindspore/nn/cell.py

+ 13
- 2
mindspore/nn/cell.py View File

@@ -261,6 +261,17 @@ class Cell(Cell_):
object.__delattr__(self, name)
self._attr_synced = False

def _cast_mixed_precision_inputs(self, inputs, dst_type):
res = list()
for item in inputs:
if isinstance(item, tuple):
res.append(self._cast_mixed_precision_inputs(item, dst_type))
elif item.dtype in {mstype.float16, mstype.float32}:
res.append(cast(item, dst_type))
else:
res.append(item)
return tuple(res)

def cast_inputs(self, inputs, dst_type):
res = list()
for item in inputs:
@@ -299,9 +310,9 @@ class Cell(Cell_):
cast_inputs = list()
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
cast_inputs = self.cast_inputs(inputs, mstype.float16)
cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
if self._mindspore_flags.get('fp32'):
cast_inputs = self.cast_inputs(inputs, mstype.float32)
cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
if not cast_inputs:
cast_inputs = inputs
if self.enable_hook:


Loading…
Cancel
Save