diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 698105889a..36e9417095 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -99,12 +99,8 @@ class Model:
         self._loss_scale_manager_set = False
         self._keep_bn_fp32 = True
         self._check_kwargs(kwargs)
-        if 'keep_batchnorm_fp32' in kwargs:
-            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
-        if 'loss_scale_manager' in kwargs:
-            self._loss_scale_manager = kwargs['loss_scale_manager']
-            self._loss_scale_manager_set = True
         self._amp_level = amp_level
+        self._process_amp_args(kwargs)
         self._parallel_mode = _get_parallel_mode()
         self._device_number = _get_device_num()
         self._global_rank = _get_global_rank()
@@ -114,6 +110,15 @@ class Model:
         self._build_eval_network(metrics, eval_network, eval_indexes)
         self._build_predict_network()
 
+    def _process_amp_args(self, kwargs):
+        if self._amp_level == "O0":
+            self._keep_bn_fp32 = False
+        if 'keep_batchnorm_fp32' in kwargs:
+            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
+        if 'loss_scale_manager' in kwargs:
+            self._loss_scale_manager = kwargs['loss_scale_manager']
+            self._loss_scale_manager_set = True
+
     def _check_kwargs(self, kwargs):
         for arg in kwargs:
             if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
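
The refactor moves the AMP kwarg handling into _process_amp_args and adds one behavioral change: with amp_level "O0", keep_batchnorm_fp32 now defaults to False instead of True, while an explicit keep_batchnorm_fp32 kwarg still overrides that default. A minimal standalone sketch of that precedence (resolve_keep_bn_fp32 is a hypothetical helper, not MindSpore API):

    # Sketch of the precedence _process_amp_args encodes after this change.
    def resolve_keep_bn_fp32(amp_level, **kwargs):
        keep_bn_fp32 = True                    # class-level default
        if amp_level == "O0":
            keep_bn_fp32 = False               # O0 runs without mixed precision, so no fp32 batchnorm cast
        if 'keep_batchnorm_fp32' in kwargs:
            keep_bn_fp32 = kwargs['keep_batchnorm_fp32']   # explicit user setting still wins
        return keep_bn_fp32

    assert resolve_keep_bn_fp32("O0") is False
    assert resolve_keep_bn_fp32("O0", keep_batchnorm_fp32=True) is True
    assert resolve_keep_bn_fp32("O2") is True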