From 17082f0487325231ab9dc6c4e2f1bd314e0d607a Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Thu, 22 Oct 2020 20:40:15 +0800 Subject: [PATCH] add value check for QuantizationAwareTraining's param optimize_option --- mindspore/compression/quant/quantizer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mindspore/compression/quant/quantizer.py b/mindspore/compression/quant/quantizer.py index 36a513c7c1..a8bd8c72c0 100644 --- a/mindspore/compression/quant/quantizer.py +++ b/mindspore/compression/quant/quantizer.py @@ -17,6 +17,8 @@ from abc import ABC, abstractmethod from enum import Enum +from ..._checkparam import Validator + __all__ = ["OptimizeOption", "Quantizer"] @@ -39,12 +41,15 @@ class Quantizer(ABC): This class is an abstract class. Args: - optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options. Default: None. + optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options. Default: + OptimizeOption.QAT. """ def __init__(self, - optimize_option=None): + optimize_option=OptimizeOption.QAT): if not isinstance(optimize_option, list) and not isinstance(optimize_option, tuple): optimize_option = [optimize_option] + for option in optimize_option: + option = Validator.check_isinstance("optimize_option", option, OptimizeOption) self.optimize_option = optimize_option @abstractmethod