Browse Source

fix bug of quant function

pull/14340/head
chengxianbin 4 years ago
parent
commit
b64d8c54b3
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      mindspore/compression/quant/qat.py

+ 4
- 2
mindspore/compression/quant/qat.py View File

@@ -56,7 +56,7 @@ def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQ
Default: (QuantDtype.INT8, QuantDtype.INT8) Default: (QuantDtype.INT8, QuantDtype.INT8)
per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True` per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True`
then base on per channel otherwise base on per layer. The first element represents weights then base on per channel otherwise base on per layer. The first element represents weights
and second element represents data flow. Default: (False, False)
and second element represents data flow, and second element must be `False` now. Default: (False, False)
symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then
base on symmetric otherwise base on asymmetric. The first element represents weights and second base on symmetric otherwise base on asymmetric. The first element represents weights and second
element represents data flow. Default: (False, False) element represents data flow. Default: (False, False)
@@ -66,6 +66,8 @@ def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQ
Returns: Returns:
QuantConfig, Contains the observer type of weight and activation. QuantConfig, Contains the observer type of weight and activation.
""" """
if per_channel[-1]:
raise ValueError("Arg 'per_channel' second element must be 'False'.")
weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0], weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
per_channel=per_channel[0], symmetric=symmetric[0], per_channel=per_channel[0], symmetric=symmetric[0],
narrow_range=narrow_range[0]) narrow_range=narrow_range[0])
@@ -130,7 +132,7 @@ class QuantizationAwareTraining(Quantizer):
Default: (QuantDtype.INT8, QuantDtype.INT8) Default: (QuantDtype.INT8, QuantDtype.INT8)
per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True` per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True`
then base on per channel otherwise base on per layer. The first element represents weights then base on per channel otherwise base on per layer. The first element represents weights
and second element represents data flow. Default: (False, False)
and second element represents data flow, and second element must be `False` now. Default: (False, False)
symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then
base on symmetric otherwise base on asymmetric. The first element represents weights and second base on symmetric otherwise base on asymmetric. The first element represents weights and second
element represents data flow. Default: (False, False) element represents data flow. Default: (False, False)


Loading…
Cancel
Save