
!3606 [bug][auto_mixed_precision]fix amp doc and eval network build

Merge pull request !3606 from vlne-v1/amp_doc
Tag: v0.7.0-beta
Committed by mindspore-ci-bot on Gitee, 5 years ago
Commit: f96efbfe19
2 changed files with 2 additions and 1 deletion:

  1. mindspore/train/amp.py   (+1 -0)
  2. mindspore/train/model.py (+1 -1)

mindspore/train/amp.py (+1 -0)

@@ -133,6 +133,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
         cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
             If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
         keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
+            Only when `cast_model_type` is `float16` will `keep_batchnorm_fp32` take effect.
         loss_scale_manager (Union[None, LossScaleManager]): If None, the loss is not scaled; otherwise
             the loss is scaled by LossScaleManager. If set, overwrite the level setting.
     """


mindspore/train/model.py (+1 -1)

@@ -174,7 +174,7 @@ class Model:
         else:
             if self._loss_fn is None:
                 raise ValueError("loss_fn can not be None.")
-            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
+            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3"])
             self._eval_indexes = [0, 1, 2]

         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
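
A sketch of what the corrected predicate selects. Under O2/O3 the network is cast to `float16`, so `WithEvalCell` must cast its outputs back to `float32` before the loss and metrics see them; `net` and `loss` below are illustrative placeholders.

    import mindspore.nn as nn

    net = nn.Dense(16, 10)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)

    amp_level = "O2"                           # as passed to Model(..., amp_level=...)
    add_cast_fp32 = amp_level in ["O2", "O3"]  # the corrected condition
    eval_net = nn.WithEvalCell(net, loss, add_cast_fp32)
    # eval_net(data, label) returns (loss, outputs, label),
    # which is why self._eval_indexes is [0, 1, 2].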

