|
|
@@ -125,34 +125,6 @@ class _AddFakeQuantAfterSubCell(nn.Cell): |
|
|
return output |
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvertToQuantNetwork: |
|
|
|
|
|
""" |
|
|
|
|
|
Convert network to quantization aware network |
|
|
|
|
|
""" |
|
|
|
|
|
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, **kwargs):
    """Validate the quantization settings passed via *kwargs* and store them.

    Expected keys (all mandatory): ``network``, ``quant_delay``,
    ``quant_dtype``, ``per_channel``, ``symmetric``, ``narrow_range``,
    ``bn_fold`` and ``freeze_bn``.  The sequence-valued options hold a
    pair of settings: index 0 applies to weights, index -1 to
    activations (data flow).

    Raises whatever the ``Validator`` helpers raise when a value has the
    wrong type or is out of range.
    """
    # Hoist the repeatedly-read kwargs so each key is looked up once.
    quant_delay = kwargs["quant_delay"]
    quant_dtype = kwargs["quant_dtype"]
    per_channel = kwargs["per_channel"]
    symmetric = kwargs["symmetric"]
    narrow_range = kwargs["narrow_range"]

    self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))

    # Per-tensor-kind settings: [0] -> weights, [-1] -> activations.
    self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
    self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
    self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
    self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
    self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
    self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
    self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
    self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
    self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
    self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
    self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
    self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")

    # Dispatch table mapping fused cells to their conversion routines.
    self._convert_method_map = {
        quant.Conv2dBnAct: self._convert_conv,
        quant.DenseBnAct: self._convert_dense,
    }

    # Shared fake-quant configuration built from the raw (unvalidated
    # sequence) kwargs, mirroring the per-element checks above.
    self.quant_config = get_quant_config(
        quant_delay=quant_delay,
        quant_dtype=quant_dtype,
        per_channel=per_channel,
        symmetric=symmetric,
        narrow_range=narrow_range,
    )
|
|
|
|
|
|
|
|
|
|
|
class QuantizationAwareTraining(Quantizer): |
|
|
class QuantizationAwareTraining(Quantizer): |
|
|
r""" |
|
|
r""" |
|
|
Quantizer for quantization aware training. |
|
|
Quantizer for quantization aware training. |
|
|
@@ -175,6 +147,39 @@ class QuantizationAwareTraining(Quantizer): |
|
|
The first element represents weights and the second element represents data flow. Default: (False, False) |
|
|
The first element represents weights and the second element represents data flow. Default: (False, False) |
|
|
optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only |
|
|
optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only |
|
|
support QAT. Default: OptimizeOption.QAT |
|
|
support QAT. Default: OptimizeOption.QAT |
|
|
|
|
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
>>> class Net(nn.Cell): |
|
|
|
|
|
>>> def __init__(self, num_class=10, channel=1): |
|
|
|
|
|
>>> super(LeNet5, self).__init__() |
|
|
|
|
|
>>> self.type = "fusion" |
|
|
|
|
|
>>> self.num_class = num_class |
|
|
|
|
|
>>> |
|
|
|
|
|
>>> # change `nn.Conv2d` to `nn.Conv2dBnAct` |
|
|
|
|
|
>>> self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') |
|
|
|
|
|
>>> self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') |
|
|
|
|
|
>>> # change `nn.Dense` to `nn.DenseBnAct` |
|
|
|
|
|
>>> self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') |
|
|
|
|
|
>>> self.fc2 = nn.DenseBnAct(120, 84, activation='relu') |
|
|
|
|
|
>>> self.fc3 = nn.DenseBnAct(84, self.num_class) |
|
|
|
|
|
>>> |
|
|
|
|
|
>>> self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) |
|
|
|
|
|
>>> self.flatten = nn.Flatten() |
|
|
|
|
|
>>> |
|
|
|
|
|
>>> def construct(self, x): |
|
|
|
|
|
>>> x = self.conv1(x) |
|
|
|
|
|
>>> x = self.max_pool2d(x) |
|
|
|
|
|
>>> x = self.conv2(x) |
|
|
|
|
|
>>> x = self.max_pool2d(x) |
|
|
|
|
|
>>> x = self.flatten(x) |
|
|
|
|
|
>>> x = self.fc1(x) |
|
|
|
|
|
>>> x = self.fc2(x) |
|
|
|
|
|
>>> x = self.fc3(x) |
|
|
|
|
|
>>> return x |
|
|
|
|
|
>>> |
|
|
|
|
|
>>> net = Net() |
|
|
|
|
|
>>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False]) |
|
|
|
|
|
>>> net_qat = quantizer.quantize(net) |
|
|
""" |
|
|
""" |
|
|
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] |
|
|
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] |
|
|
|
|
|
|
|
|
@@ -230,6 +235,17 @@ class QuantizationAwareTraining(Quantizer): |
|
|
return name_new |
|
|
return name_new |
|
|
|
|
|
|
|
|
def quantize(self, network): |
|
|
def quantize(self, network): |
|
|
|
|
|
""" |
|
|
|
|
|
Quant API to convert input network to a quantization aware training network |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
network (Cell): network to be quantized. |
|
|
|
|
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
>>> net = Net() |
|
|
|
|
|
>>> quantizer = QuantizationAwareTraining() |
|
|
|
|
|
>>> net_qat = quantizer.quantize(net) |
|
|
|
|
|
""" |
|
|
support_device = ["Ascend", "GPU"] |
|
|
support_device = ["Ascend", "GPU"] |
|
|
if context.get_context('device_target') not in support_device: |
|
|
if context.get_context('device_target') not in support_device: |
|
|
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) |
|
|
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) |
|
|
|