diff --git a/mindspore/compression/common/__init__.py b/mindspore/compression/common/__init__.py index 149b97bbd6..fb83535cf6 100644 --- a/mindspore/compression/common/__init__.py +++ b/mindspore/compression/common/__init__.py @@ -17,3 +17,6 @@ Compression common module. """ from .constant import * + +__all__ = [] +__all__.extend(constant.__all__) diff --git a/mindspore/compression/common/constant.py b/mindspore/compression/common/constant.py index d190bd1f7f..a5a3744778 100644 --- a/mindspore/compression/common/constant.py +++ b/mindspore/compression/common/constant.py @@ -24,7 +24,7 @@ __all__ = ["QuantDtype"] @enum.unique class QuantDtype(enum.Enum): """ - For type switch + An enum for quant datatype, contains `INT2`~`INT8`, `UINT2`~`UINT8`. """ INT2 = "INT2" INT3 = "INT3" @@ -42,20 +42,42 @@ class QuantDtype(enum.Enum): UINT7 = "UINT7" UINT8 = "UINT8" - FLOAT16 = "FLOAT16" - FLOAT32 = "FLOAT32" - def __str__(self): return f"{self.name}" @staticmethod def is_signed(dtype): + """ + Get whether the quant datatype is signed. + + Args: + dtype (QuantDtype): quant datatype. + + Returns: + bool, whether the input quant datatype is signed. + + Examples: + >>> quant_dtype = QuantDtype.INT8 + >>> is_signed = QuantDtype.is_signed(quant_dtype) + """ return dtype in [QuantDtype.INT2, QuantDtype.INT3, QuantDtype.INT4, QuantDtype.INT5, QuantDtype.INT6, QuantDtype.INT7, QuantDtype.INT8] @staticmethod def switch_signed(dtype): - """switch signed""" + """ + Switch the signed state of the input quant datatype. + + Args: + dtype (QuantDtype): quant datatype. + + Returns: + QuantDtype, quant datatype with opposite signed state as the input.
+ + Examples: + >>> quant_dtype = QuantDtype.INT8 + >>> quant_dtype = QuantDtype.switch_signed(quant_dtype) + """ type_map = { QuantDtype.INT2: QuantDtype.UINT2, QuantDtype.INT3: QuantDtype.UINT3, @@ -75,11 +97,20 @@ class QuantDtype(enum.Enum): return type_map[dtype] @DynamicClassAttribute - def value(self): + def _value(self): """The value of the Enum member.""" return int(re.search(r"(\d+)", self._value_).group(1)) @DynamicClassAttribute def num_bits(self): - """The num_bits of the Enum member.""" - return self.value + """ + Get the num bits of the QuantDtype member. + + Returns: + int, the num bits of the QuantDtype member + + Examples: + >>> quant_dtype = QuantDtype.INT8 + >>> num_bits = quant_dtype.num_bits + """ + return self._value diff --git a/mindspore/compression/quant/__init__.py b/mindspore/compression/quant/__init__.py index 29d50d9221..233c50f260 100644 --- a/mindspore/compression/quant/__init__.py +++ b/mindspore/compression/quant/__init__.py @@ -19,3 +19,8 @@ Compression quant module. 
from .quantizer import * from .qat import * from .quant_utils import * + +__all__ = [] +__all__.extend(qat.__all__) +__all__.extend(quantizer.__all__) +__all__.extend(quant_utils.__all__) diff --git a/mindspore/compression/quant/qat.py b/mindspore/compression/quant/qat.py index 26a6cc6d17..921529da9a 100644 --- a/mindspore/compression/quant/qat.py +++ b/mindspore/compression/quant/qat.py @@ -125,34 +125,6 @@ class _AddFakeQuantAfterSubCell(nn.Cell): return output -class ConvertToQuantNetwork: - """ - Convert network to quantization aware network - """ - __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] - - def __init__(self, **kwargs): - self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) - self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay") - self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay") - self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold") - self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn") - self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype) - self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype) - self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel") - self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel") - self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric") - self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric") - self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range") - self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range") - self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, - quant.DenseBnAct: self._convert_dense} - self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"], - 
quant_dtype=kwargs["quant_dtype"], - per_channel=kwargs["per_channel"], - symmetric=kwargs["symmetric"], - narrow_range=kwargs["narrow_range"]) - class QuantizationAwareTraining(Quantizer): r""" Quantizer for quantization aware training. @@ -175,6 +147,39 @@ class QuantizationAwareTraining(Quantizer): The first element represents weights and the second element represents data flow. Default: (False, False) optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only support QAT. Default: OptimizeOption.QAT + + Examples: + >>> class Net(nn.Cell): + >>> def __init__(self, num_class=10, channel=1): + >>> super(Net, self).__init__() + >>> self.type = "fusion" + >>> self.num_class = num_class + >>> + >>> # change `nn.Conv2d` to `nn.Conv2dBnAct` + >>> self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') + >>> self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') + >>> # change `nn.Dense` to `nn.DenseBnAct` + >>> self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') + >>> self.fc2 = nn.DenseBnAct(120, 84, activation='relu') + >>> self.fc3 = nn.DenseBnAct(84, self.num_class) + >>> + >>> self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + >>> self.flatten = nn.Flatten() + >>> + >>> def construct(self, x): + >>> x = self.conv1(x) + >>> x = self.max_pool2d(x) + >>> x = self.conv2(x) + >>> x = self.max_pool2d(x) + >>> x = self.flatten(x) + >>> x = self.fc1(x) + >>> x = self.fc2(x) + >>> x = self.fc3(x) + >>> return x + >>> + >>> net = Net() + >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False]) + >>> net_qat = quantizer.quantize(net) """ __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] @@ -230,6 +235,17 @@ class QuantizationAwareTraining(Quantizer): return name_new def quantize(self, network): + """ + Quant API to convert input network to a quantization aware training network + + Args: + network (Cell):
network to be quantized. + + Examples: + >>> net = Net() + >>> quantizer = QuantizationAwareTraining() + >>> net_qat = quantizer.quantize(net) + """ support_device = ["Ascend", "GPU"] if context.get_context('device_target') not in support_device: raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) diff --git a/mindspore/compression/quant/quant_utils.py b/mindspore/compression/quant/quant_utils.py index 8d0894088c..9fe58e2df7 100644 --- a/mindspore/compression/quant/quant_utils.py +++ b/mindspore/compression/quant/quant_utils.py @@ -267,13 +267,13 @@ def without_fold_batchnorm(weight, cell_quant): def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None): - """ - load fp32 model parameters to quantization model. + r""" + Load fp32 model parameters into quantization model. Args: quant_model: quantization model. - params_dict: f32 param. - quant_new_params:parameters that exist in quantative network but not in unquantative network. + params_dict: parameter dict that stores fp32 parameters. + quant_new_params: parameters that exist in the quantized network but not in the non-quantized network. Returns: None diff --git a/mindspore/compression/quant/quantizer.py b/mindspore/compression/quant/quantizer.py index a8bd8c72c0..24dd0d8f39 100644 --- a/mindspore/compression/quant/quantizer.py +++ b/mindspore/compression/quant/quantizer.py @@ -19,12 +19,12 @@ from enum import Enum from ..._checkparam import Validator -__all__ = ["OptimizeOption", "Quantizer"] +__all__ = ["OptimizeOption"] class OptimizeOption(Enum): - """ - An enum for the model quantization optimize option. + r""" + An enum for the model quantization optimize option, currently only supports `QAT`. """ # using quantization aware training QAT = "QAT"