
add QuantDtype and Observer

tags/v1.1.0
Author: yuchaojie
Commit: 025ea2f392
5 changed files with 345 additions and 2 deletions
1. +1 -0 cmake/package.cmake
2. +17 -0 mindspore/compression/__init__.py
3. +19 -0 mindspore/compression/common/__init__.py
4. +85 -0 mindspore/compression/common/constant.py
5. +223 -2 mindspore/nn/layer/quant.py

+ 1
- 0
cmake/package.cmake

@@ -248,6 +248,7 @@ install(
${CMAKE_SOURCE_DIR}/mindspore/ops
${CMAKE_SOURCE_DIR}/mindspore/communication
${CMAKE_SOURCE_DIR}/mindspore/profiler
+${CMAKE_SOURCE_DIR}/mindspore/compression
DESTINATION ${INSTALL_PY_DIR}
COMPONENT mindspore
)


+ 17
- 0
mindspore/compression/__init__.py

@@ -0,0 +1,17 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MindSpore compression module.
"""

+ 19
- 0
mindspore/compression/common/__init__.py

@@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Compression common module.
"""

from .constant import *

+ 85
- 0
mindspore/compression/common/constant.py

@@ -0,0 +1,85 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Constant module for compression"""
import enum
import re
from types import DynamicClassAttribute


__all__ = ["QuantDtype"]


@enum.unique
class QuantDtype(enum.Enum):
"""
Enum of quantization data types, covering signed and unsigned integers from 2 to 8 bits as well as float16 and float32.
"""
INT2 = "INT2"
INT3 = "INT3"
INT4 = "INT4"
INT5 = "INT5"
INT6 = "INT6"
INT7 = "INT7"
INT8 = "INT8"

UINT2 = "UINT2"
UINT3 = "UINT3"
UINT4 = "UINT4"
UINT5 = "UINT5"
UINT6 = "UINT6"
UINT7 = "UINT7"
UINT8 = "UINT8"

FLOAT16 = "FLOAT16"
FLOAT32 = "FLOAT32"

def __str__(self):
return f"{self.name}"

@staticmethod
def is_signed(dtype):
"""Return True if `dtype` is a signed integer quantization type."""
return dtype in [QuantDtype.INT2, QuantDtype.INT3, QuantDtype.INT4, QuantDtype.INT5,
QuantDtype.INT6, QuantDtype.INT7, QuantDtype.INT8]

@staticmethod
def switch_signed(dtype):
"""switch signed"""
type_map = {
QuantDtype.INT2: QuantDtype.UINT2,
QuantDtype.INT3: QuantDtype.UINT3,
QuantDtype.INT4: QuantDtype.UINT4,
QuantDtype.INT5: QuantDtype.UINT5,
QuantDtype.INT6: QuantDtype.UINT6,
QuantDtype.INT7: QuantDtype.UINT7,
QuantDtype.INT8: QuantDtype.UINT8,
QuantDtype.UINT2: QuantDtype.INT2,
QuantDtype.UINT3: QuantDtype.INT3,
QuantDtype.UINT4: QuantDtype.INT4,
QuantDtype.UINT5: QuantDtype.INT5,
QuantDtype.UINT6: QuantDtype.INT6,
QuantDtype.UINT7: QuantDtype.INT7,
QuantDtype.UINT8: QuantDtype.INT8
}
return type_map[dtype]

@DynamicClassAttribute
def value(self):
"""The value of the Enum member."""
return int(re.search(r"(\d+)", self._value_).group(1))

@DynamicClassAttribute
def num_bits(self):
"""The num_bits of the Enum member."""
return self.value
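
For orientation, a minimal usage sketch of the new enum, based only on the methods defined above:

from mindspore.compression.common import QuantDtype

dtype = QuantDtype.INT8
print(dtype)                            # "INT8", via the custom __str__
print(dtype.num_bits)                   # 8, parsed from the member value
print(QuantDtype.is_signed(dtype))      # True
print(QuantDtype.switch_signed(dtype))  # "UINT8"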

+ 223
- 2
mindspore/nn/layer/quant.py

@@ -24,6 +24,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator, Rel, twice
+from mindspore.compression.common import QuantDtype
import mindspore.context as context
from .normalization import BatchNorm2d, BatchNorm1d
from .activation import get_activation, ReLU, LeakyReLU
@@ -277,13 +278,233 @@ class BatchNormFoldCell(Cell):
return batch_mean, batch_std, running_mean, running_std




def _partial_init(cls_or_self, **kwargs):
"""
Wrapper that allows creation of class factories.

This can be useful when there is a need to create classes with the same
constructor arguments, but different instances.

Example::
>>> Foo.partial_init = classmethod(_partial_init)
>>> foo_builder = Foo.partial_init(a=3, b=4).partial_init(answer=42)
>>> foo_instance1 = foo_builder()
>>> foo_instance2 = foo_builder()
>>> id(foo_instance1) == id(foo_instance2)
False
"""

class _PartialWrapper:
r"""
class of wrapper that allows creation of class factories.
"""

def __init__(self, p):
self.p = p

def __call__(self, *args, **keywords):
return self.p(*args, **keywords)

def __repr__(self):
return self.p.__repr__()

partial_init = _partial_init

r = _PartialWrapper(partial(cls_or_self, **kwargs))
return r


class Observer(Cell):
"""
Base class of Observer. An Observer is used to calculate statistics of a specific layer.

Note:
This class is an abstract class.

Args:
quant_dtype (QuantDtype): The type of FakeQuant data.
"""

def __init__(self, quant_dtype):
super(Observer, self).__init__()
self.quant_dtype = quant_dtype

def extend_repr(self):
s = f"dtype={self.dtype}"
return s

def construct(self):
pass

partial_init = classmethod(_partial_init)
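
To illustrate the partial_init hook attached above, a small hypothetical sketch; MyObserver is not part of this commit and only demonstrates the factory pattern:

class MyObserver(Observer):
    """Hypothetical Observer subclass, used only to show partial_init."""
    def __init__(self, quant_dtype=QuantDtype.INT8, per_channel=False):
        super(MyObserver, self).__init__(quant_dtype)
        self.per_channel = per_channel

# Freeze constructor arguments once, then build independent instances on demand.
builder = MyObserver.partial_init(quant_dtype=QuantDtype.INT4, per_channel=True)
obs_a = builder()
obs_b = builder()
assert obs_a is not obs_b  # distinct objects sharing the same configuration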


class UniformQuantObserver(Observer):
"""
The base class of Uniform Quantization Observer.

Args:
quant_dtype (QuantDtype): The type of FakeQuant data. Default: QuantDtype.INT8.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
num_channels (int): Declares the channel size of the min and max values. Default: 1.

Returns:
Tensor.
"""

min_max_map = {
QuantDtype.INT2: (-2, 1),
QuantDtype.INT3: (-4, 3),
QuantDtype.INT4: (-8, 7),
QuantDtype.INT5: (-16, 15),
QuantDtype.INT6: (-32, 31),
QuantDtype.INT7: (-64, 63),
QuantDtype.INT8: (-128, 127),

QuantDtype.UINT2: (0, 3),
QuantDtype.UINT3: (0, 7),
QuantDtype.UINT4: (0, 15),
QuantDtype.UINT5: (0, 31),
QuantDtype.UINT6: (0, 63),
QuantDtype.UINT7: (0, 127),
QuantDtype.UINT8: (0, 255)
}

def __init__(self, quant_dtype=QuantDtype.INT8, per_channel=False, symmetric=False, narrow_range=False,
num_channels=1):
super(UniformQuantObserver, self).__init__(quant_dtype)
self.per_channel = per_channel
self.symmetric = symmetric
self.narrow_range = narrow_range
self.num_channels = num_channels
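
The min_max_map above fixes the integer grid for each dtype. A rough, hypothetical sketch of how an observed float range could be mapped to a scale and zero-point with those bounds (this helper is not part of the commit):

def compute_scale_zp(float_min, float_max, quant_dtype=QuantDtype.INT8, narrow_range=False):
    """Hypothetical helper: map an observed float range onto the integer grid."""
    quant_min, quant_max = UniformQuantObserver.min_max_map[quant_dtype]
    if narrow_range:
        quant_min += 1  # e.g. INT8 narrow range keeps [-127, 127]
    scale = (float_max - float_min) / (quant_max - quant_min)
    zero_point = round(quant_min - float_min / scale)
    return scale, zero_point

# An observed range of [-1.0, 1.0] with INT8 gives a scale of about 2 / 255
# and a zero-point close to 0.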


class FakeQuantWithMinMaxObserver(UniformQuantObserver):
r"""
Quantization aware op. This OP provides the fake quantization observer function on data with min and max.

Args:
min_init (int, float): The initialized min value. Default: -6.
max_init (int, float): The initialized max value. Default: 6.
ema (bool): Whether the exponential moving average algorithm is used to update the min and max values. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
channel_axis (int): Quantization by channel axis. Default: 1.
num_channels (int): Declares the channel size of the min and max values. Default: 1.
quant_dtype (QuantDtype): The data type of quantization, supporting 4 and 8 bits. Default: QuantDtype.INT8.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMaxObserver.

Outputs:
Tensor, with the same type and shape as the input `x`.

Examples:
>>> fake_quant = FakeQuantWithMinMaxObserver()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""

def __init__(self,
min_init=-6,
max_init=6,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
num_channels=1,
quant_dtype=QuantDtype.INT8,
symmetric=False,
narrow_range=False,
quant_delay=0):
"""Initialize FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
symmetric=symmetric, narrow_range=narrow_range,
num_channels=num_channels)
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
self.min_init = min_init
self.max_init = max_init
self.quant_dtype = quant_dtype
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.num_channels = num_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.is_ascend = context.get_context('device_target') == "Ascend"

# init tensor min and max for fake quant op
if self.per_channel:
min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
else:
min_array = np.array([self.min_init]).astype(np.float32)
max_array = np.array([self.max_init]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)

# init fake quant relative op
if self.per_channel:
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
else:
quant_fun = Q.FakeQuantPerLayer
ema_fun = Q.MinMaxUpdatePerLayer

self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
if self.is_ascend:
self.fake_quant_train = quant_fun(num_bits=self.quant_dtype.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=self.quant_delay)
self.fake_quant_infer = self.fake_quant_train
else:
quant_fun = partial(quant_fun,
ema=self.ema,
ema_decay=ema_decay,
num_bits=self.quant_dtype.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=self.quant_delay)
self.fake_quant_train = quant_fun(training=True)
self.fake_quant_infer = quant_fun(training=False)

def extend_repr(self):
s = 'quant_dtype={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
'quant_delay={}, min_init={}, max_init={}'.format(self.quant_dtype, self.symmetric, self.narrow_range,
self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.num_channels, self.quant_delay,
self.min_init, self.max_init)
return s

def construct(self, x):
if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant_infer(x, self.minq, self.maxq)
return out
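
For intuition about what the fake-quantization pass computes, a NumPy-only reference sketch that approximates the per-layer behaviour; it is not the FakeQuantPerLayer kernel itself:

import numpy as np

def fake_quant_reference(x, min_val=-6.0, max_val=6.0, num_bits=8):
    """Reference-only approximation: snap values onto the uniform grid in [min_val, max_val]."""
    quant_max = 2 ** num_bits - 1
    scale = (max_val - min_val) / quant_max
    q = np.clip(np.round((x - min_val) / scale), 0, quant_max)
    return (q * scale + min_val).astype(np.float32)

x = np.array([[1, 2, 1], [-2, 0, -1]], dtype=np.float32)
print(fake_quant_reference(x))  # values snapped to the 8-bit grid within [-6, 6]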


class FakeQuantWithMinMax(Cell):
r"""
Quantization aware op. This OP provides the fake quantization observer function on data with min and max.

Args:
-min_init (int, float): The dimension of channel or 1(layer). Default: -6.
-max_init (int, float): The dimension of channel or 1(layer). Default: 6.
+min_init (int, float): The initialized min value. Default: -6.
+max_init (int, float): The initialized max value. Default: 6.
ema (bool): The exponential Moving Average algorithm updates min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.

