|
|
|
@@ -62,6 +62,7 @@ class Model: |
|
|
|
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else |
|
|
|
scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a keyword argument. |
|
|
|
e.g. Use `loss_scale_manager=None` to set the value. |
|
|
|
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
@@ -96,7 +97,10 @@ class Model: |
|
|
|
self._optimizer = optimizer |
|
|
|
self._loss_scale_manager = None |
|
|
|
self._loss_scale_manager_set = False |
|
|
|
self._keep_bn_fp32 = True |
|
|
|
self._check_kwargs(kwargs) |
|
|
|
if 'keep_batchnorm_fp32' in kwargs: |
|
|
|
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] |
|
|
|
if 'loss_scale_manager' in kwargs: |
|
|
|
self._loss_scale_manager = kwargs['loss_scale_manager'] |
|
|
|
self._loss_scale_manager_set = True |
|
|
|
@@ -112,7 +116,7 @@ class Model: |
|
|
|
|
|
|
|
def _check_kwargs(self, kwargs): |
|
|
|
for arg in kwargs: |
|
|
|
if arg not in ['loss_scale_manager']: |
|
|
|
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: |
|
|
|
raise ValueError(f"Unsupport arg '{arg}'") |
|
|
|
|
|
|
|
def _build_train_network(self): |
|
|
|
@@ -124,12 +128,14 @@ class Model: |
|
|
|
self._optimizer, |
|
|
|
self._loss_fn, |
|
|
|
level=self._amp_level, |
|
|
|
loss_scale_manager=self._loss_scale_manager) |
|
|
|
loss_scale_manager=self._loss_scale_manager, |
|
|
|
keep_batchnorm_fp32=self._keep_bn_fp32) |
|
|
|
else: |
|
|
|
network = amp.build_train_network(network, |
|
|
|
self._optimizer, |
|
|
|
self._loss_fn, |
|
|
|
level=self._amp_level) |
|
|
|
level=self._amp_level, |
|
|
|
keep_batchnorm_fp32=self._keep_bn_fp32) |
|
|
|
elif self._loss_fn: |
|
|
|
network = nn.WithLossCell(network, self._loss_fn) |
|
|
|
# If need to check if loss_fn is not None, but optimizer is None |
|
|
|
|