|
|
|
@@ -24,13 +24,12 @@ from mindspore.ops import functional as F |
|
|
|
from mindspore.common.parameter import Parameter |
|
|
|
from mindspore.common.initializer import initializer |
|
|
|
from mindspore.common.tensor import Tensor |
|
|
|
from mindspore._checkparam import Validator, Rel, twice |
|
|
|
from mindspore._checkparam import Validator, twice |
|
|
|
from mindspore.compression.common import QuantDtype |
|
|
|
import mindspore.context as context |
|
|
|
from .normalization import BatchNorm2d |
|
|
|
from .activation import get_activation, ReLU |
|
|
|
from ..cell import Cell |
|
|
|
from ... import nn |
|
|
|
from ...ops.operations import _quant_ops as Q |
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
@@ -381,10 +380,6 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): |
|
|
|
num_channels=num_channels) |
|
|
|
Validator.check_value_type("min_init", min_init, [int, float, list], type(self).__name__) |
|
|
|
Validator.check_value_type("max_init", max_init, [int, float, list], type(self).__name__) |
|
|
|
if isinstance(max_init, (int, float)) and isinstance(min_init, (int, float)): |
|
|
|
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) |
|
|
|
elif not np.greater(max_init, min_init).all(): |
|
|
|
raise ValueError("`min_init` is not less than `max_init`, please reset the initial value.") |
|
|
|
Validator.check_non_negative_int(quant_delay, 'quant_delay') |
|
|
|
self.min_init = min_init |
|
|
|
self.max_init = max_init |
|
|
|
@@ -405,7 +400,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): |
|
|
|
|
|
|
|
min_array = self._get_init_array(self.min_init) |
|
|
|
max_array = self._get_init_array(self.max_init) |
|
|
|
|
|
|
|
if not np.greater(max_array, min_array).all(): |
|
|
|
raise ValueError("`min_init` is not less than `max_init`, please reset the initial value.") |
|
|
|
if self.mode == "DEFAULT": |
|
|
|
# init tensor min and max for fake quantized operation |
|
|
|
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) |
|
|
|
@@ -441,7 +437,9 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): |
|
|
|
raise ValueError("The 'LEARNED_SCALE' mode only support symmetric quant, please set symmetric to True.") |
|
|
|
if self.neg_trunc: |
|
|
|
min_array = self._get_init_array(0) |
|
|
|
self.narrow_range = False |
|
|
|
if self.narrow_range: |
|
|
|
raise ValueError("The 'LEARNED_SCALE' mode only support the combination of " |
|
|
|
"neg_trunc=True and narrow_range=False config scenario.") |
|
|
|
elif not self.narrow_range: |
|
|
|
raise ValueError("The 'LEARNED_SCALE' mode only support narrow_range=True config, " |
|
|
|
"except for neg_trunc=True scenario.") |
|
|
|
@@ -483,6 +481,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): |
|
|
|
self.max_init = max_init |
|
|
|
min_array = self._get_init_array(self.min_init) |
|
|
|
max_array = self._get_init_array(self.max_init) |
|
|
|
if not np.greater(max_array, min_array).all(): |
|
|
|
raise ValueError("`min_init` is not less than `max_init`, please reset the initial value.") |
|
|
|
self.minq.set_data(Tensor(min_array)) |
|
|
|
self.maxq.set_data(Tensor(max_array)) |
|
|
|
self.quant_max.set_data(Tensor(np.array([self._quant_max]).astype(np.float32))) |
|
|
|
@@ -494,10 +494,10 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): |
|
|
|
Convert the initial value to array. |
|
|
|
""" |
|
|
|
if isinstance(init_date, list) and self.per_channel and len(init_date) != self.num_channels: |
|
|
|
raise ValueError("The length of the min_init/max_init list shuold be equal to num_channels for " |
|
|
|
raise ValueError("The length of the min_init/max_init list should be equal to num_channels for " |
|
|
|
"perchannel quant scenario, but get {}".format(len(init_date))) |
|
|
|
if isinstance(init_date, list) and not self.per_channel and len(init_date) != 1: |
|
|
|
raise ValueError("The length of the min_init/max_init list shuold be 1 for perlayer quant " |
|
|
|
raise ValueError("The length of the min_init/max_init list should be 1 for perlayer quant " |
|
|
|
"scenario, but get {}".format(len(init_date))) |
|
|
|
|
|
|
|
if isinstance(init_date, list): |
|
|
|
@@ -1343,8 +1343,6 @@ class ActQuant(_QuantActivation): |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8): |
|
|
|
super(ActQuant, self).__init__() |
|
|
|
act_class = activation.__class__ |
|
|
|
act_list = [nn.ReLU, nn.ReLU6] |
|
|
|
self.act = Validator.check_isinstance("activation", activation, Cell) |
|
|
|
self.fake_before = Validator.check_bool(fake_before, "fake_before") |
|
|
|
if self.fake_before: |
|
|
|
@@ -1353,14 +1351,11 @@ class ActQuant(_QuantActivation): |
|
|
|
ema=ema, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
|
|
|
|
neg_trunc = bool(act_class in act_list) |
|
|
|
self.fake_quant_act = quant_config.activation(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=ema, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype, |
|
|
|
neg_trunc=neg_trunc) |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
if self.fake_before: |
|
|
|
|