|
|
|
@@ -263,11 +263,14 @@ class Cell(Cell_): |
|
|
|
self._attr_synced = False |
|
|
|
|
|
|
|
def _cast_mixed_precision_inputs(self, inputs, dst_type): |
|
|
|
"""Cast input for mixed precision""" |
|
|
|
res = list() |
|
|
|
for item in inputs: |
|
|
|
if isinstance(item, tuple): |
|
|
|
res.append(self._cast_mixed_precision_inputs(item, dst_type)) |
|
|
|
elif item.dtype in {mstype.float16, mstype.float32}: |
|
|
|
elif isinstance(item, float): |
|
|
|
res.append(cast(item, dst_type)) |
|
|
|
elif hasattr(item, "dtype") in {mstype.float16, mstype.float32, mstype.float64}: |
|
|
|
res.append(cast(item, dst_type)) |
|
|
|
else: |
|
|
|
res.append(item) |
|
|
|
|