# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization functions."""

from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import nn


class QuantizeWeightCell(nn.Cell):
    """
    Fake quant op for weight: ternary when ``num_bits`` is 2, uniform otherwise.

    Args:
        num_bits (int): The bit number of quantization, supporting 2 to 8 bits. Default: 8.
        compute_type (:class:`mindspore.dtype`): Compute type in QuantizeWeightCell. Default: mstype.float32.
        clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.

    Inputs:
        - **weight** (Parameter) - Parameter of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Parameter of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
    """

    def __init__(self, num_bits=8, compute_type=mstype.float32, clip_value=1.0, per_channel=False):
        super(QuantizeWeightCell, self).__init__()
        self.num_bits = num_bits
        self.compute_type = compute_type
        self.clip_value = clip_value
        self.per_channel = per_channel

        self.clamp = C.clip_by_value
        self.abs = P.Abs()
        self.sum = P.ReduceSum()
        self.nelement = F.size
        self.div = P.Div()
        self.cast = P.Cast()
        self.max = P.ReduceMax()
        self.min = P.ReduceMin()
        self.round = P.Round()
        self.reshape = P.Reshape()

    def construct(self, weight):
        """quantize weight cell"""
        tensor = self.clamp(weight, -self.clip_value, self.clip_value)
        if self.num_bits == 2:
            if self.per_channel:
                # Per-channel ternarization: threshold and scale are computed per output channel.
                n = self.nelement(tensor[0])
                m = self.div(self.sum(self.abs(tensor), 1), n)
                thres = 0.7 * m
                pos = self.cast(tensor[:] > thres[0], self.compute_type)
                neg = self.cast(tensor[:] < -thres[0], self.compute_type)
                mask = self.cast(self.abs(tensor)[:] > thres[0], self.compute_type)
                alpha = self.reshape(self.sum(self.abs(mask * tensor), 1) / self.sum(mask, 1), (-1, 1))
                output = alpha * pos - alpha * neg
            else:
                # Per-layer ternarization: a single threshold and scale for the whole tensor.
                n = self.nelement(tensor)
                m = self.div(self.sum(self.abs(tensor)), n)
                thres = 0.7 * m
                pos = self.cast(tensor > thres, self.compute_type)
                neg = self.cast(tensor < -thres, self.compute_type)
                mask = self.cast(self.abs(tensor) > thres, self.compute_type)
                alpha = self.sum(self.abs(mask * self.cast(tensor, self.compute_type))) / self.sum(mask)
                output = alpha * pos - alpha * neg
        else:
            # Uniform quantization with round-to-nearest over 2**num_bits levels.
            tensor_max = self.cast(self.max(tensor), self.compute_type)
            tensor_min = self.cast(self.min(tensor), self.compute_type)
            s = (tensor_max - tensor_min) / (2 ** self.cast(self.num_bits, self.compute_type) - 1)
            output = self.round(self.div(tensor - tensor_min, s)) * s + tensor_min
        return output
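# Note on the ternary branch (num_bits == 2) used above and below: it follows a
# TWN-style rule. With the threshold thres = 0.7 * mean(|w|), the quantized
# weight is +alpha where w > thres, -alpha where w < -thres, and 0 elsewhere,
# where alpha is the mean magnitude of the weights whose absolute value exceeds
# the threshold. The per_channel variant computes thres and alpha per output
# channel (reduction along axis 1) and reshapes alpha to (-1, 1) so it
# broadcasts back over the channel rows.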
class QuantizeWeight:
    """
    Quantize weight into the specified number of bits.

    Args:
        num_bits (int): The bit number of quantization, supporting 2 to 8 bits. Default: 2.
        compute_type (:class:`mindspore.dtype`): Compute type in QuantizeWeight. Default: mstype.float32.
        clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.

    Inputs:
        - **weight** (Parameter) - Parameter of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Parameter of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
    """

    def __init__(self, num_bits=2, compute_type=mstype.float32, clip_value=1.0, per_channel=False):
        self.num_bits = num_bits
        self.compute_type = compute_type
        self.clip_value = clip_value
        self.per_channel = per_channel

        self.clamp = C.clip_by_value
        self.abs = P.Abs()
        self.sum = P.ReduceSum()
        self.nelement = F.size
        self.div = P.Div()
        self.cast = P.Cast()
        self.max = P.ReduceMax()
        self.min = P.ReduceMin()
        self.floor = P.Floor()
        self.reshape = P.Reshape()

    def construct(self, weight):
        """quantize weight"""
        tensor = self.clamp(weight, -self.clip_value, self.clip_value)
        if self.num_bits == 2:
            if self.per_channel:
                # Per-channel ternarization.
                n = self.nelement(tensor[0])
                m = self.div(self.sum(self.abs(tensor), 1), n)
                thres = 0.7 * m
                pos = self.cast(tensor[:] > thres[0], self.compute_type)
                neg = self.cast(tensor[:] < -thres[0], self.compute_type)
                mask = self.cast(self.abs(tensor)[:] > thres[0], self.compute_type)
                alpha = self.reshape(self.sum(self.abs(mask * tensor), 1) / self.sum(mask, 1), (-1, 1))
                output = alpha * pos - alpha * neg
            else:
                # Per-layer ternarization.
                n = self.nelement(tensor)
                m = self.div(self.sum(self.abs(tensor)), n)
                thres = 0.7 * m
                pos = self.cast(tensor > thres, self.compute_type)
                neg = self.cast(tensor < -thres, self.compute_type)
                mask = self.cast(self.abs(tensor) > thres, self.compute_type)
                alpha = self.sum(self.abs(mask * tensor)) / self.sum(mask)
                output = alpha * pos - alpha * neg
        else:
            # Uniform quantization; floor(x + 0.5) is used as round-to-nearest.
            tensor_max = self.max(tensor)
            tensor_min = self.min(tensor)
            s = (tensor_max - tensor_min) / (2 ** self.num_bits - 1)
            output = self.floor(self.div(tensor - tensor_min, s) + 0.5) * s + tensor_min
        return output


def convert_network(network, embedding_bits=2, weight_bits=2, clip_value=1.0):
    """Quantize the embedding tables and dense weights of the network in place."""
    quantize_embedding = QuantizeWeight(num_bits=embedding_bits, clip_value=clip_value)
    quantize_weight = QuantizeWeight(num_bits=weight_bits, clip_value=clip_value)
    for name, param in network.parameters_and_names():
        # Quantize embedding tables, skipping the fake-quant min/max parameters.
        if 'bert_embedding_lookup' in name and 'min' not in name and 'max' not in name:
            quantized_param = quantize_embedding.construct(param)
            param.set_data(quantized_param)
        # Quantize all other weights except those of dense_1.
        elif 'weight' in name and 'dense_1' not in name:
            quantized_param = quantize_weight.construct(param)
            param.set_data(quantized_param)


def save_params(network):
    """Snapshot the current (full-precision) parameters of the network."""
    return {name: Parameter(param, 'saved_params') for name, param in network.parameters_and_names()}


def restore_params(network, params_dict):
    """Restore the parameters saved by save_params back into the network."""
    for name, param in network.parameters_and_names():
        param.set_data(params_dict[name])
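if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original training pipeline):
    # quantize a random weight tensor to ternary and to 8 bits and report how
    # many distinct values remain. The shape and seed below are arbitrary
    # illustrations; behaviour of clip_by_value with scalar bounds may depend
    # on the MindSpore version.
    import numpy as np
    from mindspore import Tensor

    np.random.seed(0)
    weight = Tensor(np.random.uniform(-1.0, 1.0, (4, 16)).astype(np.float32))

    ternary = QuantizeWeightCell(num_bits=2, per_channel=False)(weight)
    eight_bit = QuantizeWeight(num_bits=8).construct(weight)

    print("ternary levels:", np.unique(ternary.asnumpy()).size)    # at most 3: {-alpha, 0, +alpha}
    print("8-bit levels:  ", np.unique(eight_bit.asnumpy()).size)  # at most 256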