|
|
|
@@ -99,12 +99,8 @@ class Model: |
|
|
|
self._loss_scale_manager_set = False |
|
|
|
self._keep_bn_fp32 = True |
|
|
|
self._check_kwargs(kwargs) |
|
|
|
if 'keep_batchnorm_fp32' in kwargs: |
|
|
|
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] |
|
|
|
if 'loss_scale_manager' in kwargs: |
|
|
|
self._loss_scale_manager = kwargs['loss_scale_manager'] |
|
|
|
self._loss_scale_manager_set = True |
|
|
|
self._amp_level = amp_level |
|
|
|
self._process_amp_args(kwargs) |
|
|
|
self._parallel_mode = _get_parallel_mode() |
|
|
|
self._device_number = _get_device_num() |
|
|
|
self._global_rank = _get_global_rank() |
|
|
|
@@ -114,6 +110,15 @@ class Model: |
|
|
|
self._build_eval_network(metrics, eval_network, eval_indexes) |
|
|
|
self._build_predict_network() |
|
|
|
|
|
|
|
def _process_amp_args(self, kwargs): |
|
|
|
if self._amp_level == "O0": |
|
|
|
self._keep_bn_fp32 = False |
|
|
|
if 'keep_batchnorm_fp32' in kwargs: |
|
|
|
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] |
|
|
|
if 'loss_scale_manager' in kwargs: |
|
|
|
self._loss_scale_manager = kwargs['loss_scale_manager'] |
|
|
|
self._loss_scale_manager_set = True |
|
|
|
|
|
|
|
def _check_kwargs(self, kwargs): |
|
|
|
for arg in kwargs: |
|
|
|
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: |
|
|
|
|