| @@ -99,12 +99,8 @@ class Model: | |||||
| self._loss_scale_manager_set = False | self._loss_scale_manager_set = False | ||||
| self._keep_bn_fp32 = True | self._keep_bn_fp32 = True | ||||
| self._check_kwargs(kwargs) | self._check_kwargs(kwargs) | ||||
| if 'keep_batchnorm_fp32' in kwargs: | |||||
| self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] | |||||
| if 'loss_scale_manager' in kwargs: | |||||
| self._loss_scale_manager = kwargs['loss_scale_manager'] | |||||
| self._loss_scale_manager_set = True | |||||
| self._amp_level = amp_level | self._amp_level = amp_level | ||||
| self._process_amp_args(kwargs) | |||||
| self._parallel_mode = _get_parallel_mode() | self._parallel_mode = _get_parallel_mode() | ||||
| self._device_number = _get_device_num() | self._device_number = _get_device_num() | ||||
| self._global_rank = _get_global_rank() | self._global_rank = _get_global_rank() | ||||
| @@ -114,6 +110,15 @@ class Model: | |||||
| self._build_eval_network(metrics, eval_network, eval_indexes) | self._build_eval_network(metrics, eval_network, eval_indexes) | ||||
| self._build_predict_network() | self._build_predict_network() | ||||
| def _process_amp_args(self, kwargs): | |||||
| if self._amp_level == "O0": | |||||
| self._keep_bn_fp32 = False | |||||
| if 'keep_batchnorm_fp32' in kwargs: | |||||
| self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] | |||||
| if 'loss_scale_manager' in kwargs: | |||||
| self._loss_scale_manager = kwargs['loss_scale_manager'] | |||||
| self._loss_scale_manager_set = True | |||||
| def _check_kwargs(self, kwargs): | def _check_kwargs(self, kwargs): | ||||
| for arg in kwargs: | for arg in kwargs: | ||||
| if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: | if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: | ||||