Browse Source

quant_mode 2

pull/15960/head
zhang__sss 4 years ago
parent
commit
f354300cda
1 changed files with 3 additions and 1 deletions
  1. +3
    -1
      mindspore/train/serialization.py

+ 3
- 1
mindspore/train/serialization.py View File

@@ -756,9 +756,11 @@ def quant_mode_manage(func):
Inherit the quant_mode in old version.
"""
def warpper(network, *inputs, file_format, **kwargs):
if not kwargs.get('quant_mode', None):
if 'quant_mode' not in kwargs:
return network
quant_mode = kwargs['quant_mode']
if not isinstance(quant_mode, str):
raise TypeError("The type of quant_mode should be str, but got {}.".format(type(quant_mode)))
if quant_mode in ('AUTO', 'MANUAL'):
kwargs['quant_mode'] = 'QUANT'
return func(network, *inputs, file_format=file_format, **kwargs)


Loading…
Cancel
Save