
!2326 fix the problem of BatchNorm config failure at Amp O3 level and some unexpected indent

Merge pull request !2326 from liangzelang/master
tags/v0.5.0-beta
mindspore-ci-bot (Gitee) committed 5 years ago
commit e4298d7d47
2 changed files with 4 additions and 2 deletions
  1. mindspore/train/amp.py (+2 / -1)
  2. mindspore/train/model.py (+2 / -1)

mindspore/train/amp.py (+2 / -1)

@@ -127,7 +127,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
     - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
       using dynamic loss scale.
     - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
-      O2 is recommended on GPU, O3 is recommemded on Ascend.
+
+    O2 is recommended on GPU, O3 is recommended on Ascend.
 
     cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
         If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
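
The docstring change above only rewords the description of the amp levels; as orientation, here is a minimal usage sketch (my own illustration, not part of this patch; the toy network and hyperparameters are placeholders) of how a caller selects O2 versus O3 when building the train network:

import mindspore.nn as nn
from mindspore.train.amp import build_train_network

# Toy network containing a BatchNorm layer, so the keep_batchnorm_fp32 behaviour is visible.
net = nn.SequentialCell([nn.Dense(16, 8), nn.BatchNorm1d(8), nn.Dense(8, 3)])
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# O2: network cast to float16, batchnorm and loss_fn kept in float32,
# dynamic loss scale (recommended on GPU).
train_net = build_train_network(net, optimizer, loss_fn, level='O2')

# O3: network cast to float16 with keep_batchnorm_fp32=False (recommended on Ascend).
train_net = build_train_network(net, optimizer, loss_fn, level='O3')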


mindspore/train/model.py (+2 / -1)

@@ -61,6 +61,7 @@ class Model:
     - O0: Do not change.
     - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
     - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
+
     O2 is recommended on GPU, O3 is recommended on Ascend.
 
     loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
@@ -115,7 +116,7 @@ class Model:
         self._build_predict_network()
 
     def _process_amp_args(self, kwargs):
-        if self._amp_level == "O0":
+        if self._amp_level in ["O0", "O3"]:
             self._keep_bn_fp32 = False
         if 'keep_batchnorm_fp32' in kwargs:
             self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
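
To show what the one-line change above does in practice, here is a sketch (my own illustration, not part of this patch; the toy network is a placeholder) of constructing a Model at the O3 level. After the fix, _process_amp_args also sets the internal keep_bn_fp32 flag to False for "O3", consistent with the 'keep_batchnorm_fp32=False' property documented for that level, while an explicit kwarg still takes precedence:

import mindspore.nn as nn
from mindspore.train.model import Model

net = nn.SequentialCell([nn.Dense(16, 8), nn.BatchNorm1d(8), nn.Dense(8, 3)])
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# With the fix, amp_level="O3" defaults keep_batchnorm_fp32 to False instead of
# leaving the earlier default in place (the likely source of the BatchNorm config
# failure at O3 mentioned in the commit title).
model = Model(net, loss_fn=loss_fn, optimizer=optimizer, amp_level="O3")

# An explicit keep_batchnorm_fp32 kwarg still overrides the level-derived default.
model = Model(net, loss_fn=loss_fn, optimizer=optimizer, amp_level="O3",
              keep_batchnorm_fp32=False)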

