diff --git a/mindspore/compression/quant/quantizer.py b/mindspore/compression/quant/quantizer.py index 36a513c7c1..a8bd8c72c0 100644 --- a/mindspore/compression/quant/quantizer.py +++ b/mindspore/compression/quant/quantizer.py @@ -17,6 +17,8 @@ from abc import ABC, abstractmethod from enum import Enum +from ..._checkparam import Validator + __all__ = ["OptimizeOption", "Quantizer"] @@ -39,12 +41,15 @@ class Quantizer(ABC): This class is an abstract class. Args: - optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options. Default: None. + optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options. Default: + OptimizeOption.QAT. """ def __init__(self, - optimize_option=None): + optimize_option=OptimizeOption.QAT): if not isinstance(optimize_option, list) and not isinstance(optimize_option, tuple): optimize_option = [optimize_option] + for option in optimize_option: + option = Validator.check_isinstance("optimize_option", option, OptimizeOption) self.optimize_option = optimize_option @abstractmethod