|
|
|
@@ -61,6 +61,7 @@ class Model: |
|
|
|
- O0: Do not change. |
|
|
|
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale. |
|
|
|
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'. |
|
|
|
|
|
|
|
O2 is recommended on GPU, O3 is recommended on Ascend. |
|
|
|
|
|
|
|
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else |
|
|
|
@@ -115,7 +116,7 @@ class Model: |
|
|
|
self._build_predict_network() |
|
|
|
|
|
|
|
def _process_amp_args(self, kwargs): |
|
|
|
if self._amp_level == "O0": |
|
|
|
if self._amp_level in ["O0", "O3"]: |
|
|
|
self._keep_bn_fp32 = False |
|
|
|
if 'keep_batchnorm_fp32' in kwargs: |
|
|
|
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] |
|
|
|
|