|
|
@@ -647,7 +647,7 @@ class Cell(Cell_): |
|
|
param.set_cast_dtype(mstype.float32) |
|
|
param.set_cast_dtype(mstype.float32) |
|
|
elif self._mindspore_flags.get('fp16'): |
|
|
elif self._mindspore_flags.get('fp16'): |
|
|
param.set_cast_dtype(mstype.float16) |
|
|
param.set_cast_dtype(mstype.float16) |
|
|
else: |
|
|
|
|
|
|
|
|
elif hasattr(param, "set_cast_dtype"): |
|
|
# retest dtype |
|
|
# retest dtype |
|
|
param.set_cast_dtype() |
|
|
param.set_cast_dtype() |
|
|
return param |
|
|
return param |
|
|
|