Browse Source

!2209 add function for quant_utils that convert float to int

Merge pull request !2209 from chenzhongming/master
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
b391eb2b29
1 changed files with 33 additions and 1 deletions
  1. +33
    -1
      mindspore/train/quant/quant_utils.py

+ 33
- 1
mindspore/train/quant/quant_utils.py View File

@@ -19,6 +19,7 @@ import numpy as np

def cal_quantization_params(input_min,
input_max,
data_type,
num_bits=8,
symmetric=False,
narrow_range=False):
@@ -28,6 +29,7 @@ def cal_quantization_params(input_min,
Args:
input_min (int, list): The dimension of channel or 1.
input_max (int, list): The dimension of channel or 1.
data_type (numpy type) : Can ben numpy int8, numpy uint8.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@@ -52,7 +54,7 @@ def cal_quantization_params(input_min,
# scale = 1.0, zp = 0.0
return np.ones(input_min.shape), np.zeros(input_min.shape)

if symmetric:
if data_type == np.int8:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1)
else:
@@ -84,3 +86,33 @@ def cal_quantization_params(input_min,
zp = np.floor(zp_double + 0.5)

return scale, zp


def weight2int(data,
scale,
zero_point):
r"""
calculate int8/uint8 weight from fp32. the formula is defined as:

.. math::

int8/uint8 = round(float/scale) + offset

Args:
data (int, list): The dimension of channel or 1. Should be NCHW.
scale (int, list): The dimension of channel or 1.
zero_point (int, list): The dimension of channel or 1.

Outputs:
weight (int, list): The dimension of channel or 1.

Examples:
>>> weight = weight2int([1, 2, 1], 1, 0)
"""
if scale.shape != zero_point.shape:
raise ValueError("scale and zero_point should have the same shape.")
if scale.shape[0] > 0:
scale = scale.reshape(1, -1, 1, 1)
zero_point = zero_point.reshape(1, -1, 1, 1)

return np.round((data/scale) + zero_point)

Loading…
Cancel
Save