Browse Source

amp add `best_choice` level

tags/v1.0.0
Wei Luning 5 years ago
parent
commit
5e1cba77f0
1 changed files with 15 additions and 2 deletions
  1. +15
    -2
      mindspore/train/amp.py

+ 15
- 2
mindspore/train/amp.py View File

@@ -121,12 +121,15 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
Default: None.
optimizer (Optimizer): Optimizer to update the Parameter.
level (str): Supports [O0, O2, O3]. Default: "O0".
level (str): Supports ["O0", "O2", "O3", "auto"]. Default: "O0".

- O0: Do not change.
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
- auto: Set to level to recommended level in different devices. Set level to O2 on GPU, Set
level to O3 Ascend. The recommended level is choose by the export experience, cannot
always generalize. User should specify the level for special network.

O2 is recommended on GPU, O3 is recommended on Ascend.

@@ -139,7 +142,17 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
"""
validator.check_value_type('network', network, nn.Cell, None)
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
validator.check('level', level, "", ['O0', 'O2', 'O3'], Rel.IN, None)
validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN, None)

if level == "auto":
device_target = context.get_context('device_target')
if device_target == "GPU":
level = "O2"
elif device_target == "Ascend":
level = "O3"
else:
raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")

_check_kwargs(kwargs)
config = dict(_config_level[level], **kwargs)
config = edict(config)


Loading…
Cancel
Save