|
|
|
@@ -261,6 +261,17 @@ class Cell(Cell_): |
|
|
|
object.__delattr__(self, name) |
|
|
|
self._attr_synced = False |
|
|
|
|
|
|
|
def _cast_mixed_precision_inputs(self, inputs, dst_type): |
|
|
|
res = list() |
|
|
|
for item in inputs: |
|
|
|
if isinstance(item, tuple): |
|
|
|
res.append(self._cast_mixed_precision_inputs(item, dst_type)) |
|
|
|
elif item.dtype in {mstype.float16, mstype.float32}: |
|
|
|
res.append(cast(item, dst_type)) |
|
|
|
else: |
|
|
|
res.append(item) |
|
|
|
return tuple(res) |
|
|
|
|
|
|
|
def cast_inputs(self, inputs, dst_type): |
|
|
|
res = list() |
|
|
|
for item in inputs: |
|
|
|
@@ -299,9 +310,9 @@ class Cell(Cell_): |
|
|
|
cast_inputs = list() |
|
|
|
if hasattr(self, "_mindspore_flags"): |
|
|
|
if self._mindspore_flags.get('fp16'): |
|
|
|
cast_inputs = self.cast_inputs(inputs, mstype.float16) |
|
|
|
cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16) |
|
|
|
if self._mindspore_flags.get('fp32'): |
|
|
|
cast_inputs = self.cast_inputs(inputs, mstype.float32) |
|
|
|
cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32) |
|
|
|
if not cast_inputs: |
|
|
|
cast_inputs = inputs |
|
|
|
if self.enable_hook: |
|
|
|
|