@@ -65,7 +65,11 @@ _config_level = {
     "O2": {
         "keep_batchnorm_fp32": True,
         "cast_model_type": mstype.float16,
-        "loss_scale_manager": DynamicLossScaleManager()}}
+        "loss_scale_manager": DynamicLossScaleManager()},
+    "O3": {
+        "keep_batchnorm_fp32": False,
+        "cast_model_type": mstype.float16,
+        "loss_scale_manager": None}}
 
 
 def _check_kwargs(key_words):
@@ -117,11 +121,13 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
         loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
             Default: None.
         optimizer (Optimizer): Optimizer to update the Parameter.
-        level (str): Supports [O0, O2]. Default: "O0".
+        level (str): Supports [O0, O2, O3]. Default: "O0".
 
             - O0: Do not change.
             - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
               using dynamic loss scale.
+            - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
+            O2 is recommended on GPU, O3 is recommended on Ascend.
 
         cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
             If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
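
For reference, a minimal usage sketch of the new O3 level, assuming the function is imported from mindspore.train.amp (the module this patch touches); the network, loss, and optimizer below are illustrative placeholders, not part of this change:

import mindspore.nn as nn
from mindspore.train.amp import build_train_network

# Illustrative network, loss, and optimizer; any Cell/Optimizer pair works here.
net = nn.Dense(64, 10)
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# level="O3": cast the network to float16 with keep_batchnorm_fp32=False
# and no loss scale manager, per the _config_level entry added above.
train_net = build_train_network(net, opt, loss, level="O3")
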
@@ -131,7 +137,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
     """
     validator.check_value_type('network', network, nn.Cell, None)
     validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
-    validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
+    validator.check('level', level, "", ['O0', 'O2', 'O3'], Rel.IN, None)
     _check_kwargs(kwargs)
     config = dict(_config_level[level], **kwargs)
     config = edict(config)
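
A small sketch of how the per-level preset combines with user kwargs in `config = dict(_config_level[level], **kwargs)` above; the dict literal here just mirrors the new O3 entry for illustration:

# Stand-in for _config_level["O3"] as added in this patch (values copied;
# cast_model_type shown as a plain string purely for illustration).
o3_defaults = {
    "keep_batchnorm_fp32": False,
    "cast_model_type": "float16",
    "loss_scale_manager": None,
}

# An explicit keyword argument overwrites the level preset, matching the
# docstring note "If set, overwrite the level setting".
user_kwargs = {"keep_batchnorm_fp32": True}
config = dict(o3_defaults, **user_kwargs)
print(config["keep_batchnorm_fp32"])  # True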