
!7365 [MSLITE]Support bit16 weight quant

Merge pull request !7365 from ghzl/support-int16-weight-quant
Tag: v1.1.0
Merged by mindspore-ci-bot on Gitee, 5 years ago
Commit: 2d2e0c90b5
19 changed files with 203 additions and 127 deletions
  1. mindspore/lite/src/lite_kernel.cc  +0 -54
  2. mindspore/lite/src/lite_kernel.h  +1 -4
  3. mindspore/lite/src/runtime/kernel/arm/base/dequant.cc  +35 -0
  4. mindspore/lite/src/runtime/kernel/arm/base/dequant.h  +80 -0
  5. mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc  +2 -1
  6. mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc  +6 -3
  7. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc  +4 -5
  8. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc  +4 -5
  9. mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc  +4 -5
  10. mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc  +4 -5
  11. mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc  +2 -4
  12. mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc  +2 -4
  13. mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc  +6 -5
  14. mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc  +6 -5
  15. mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc  +6 -5
  16. mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc  +6 -5
  17. mindspore/lite/tools/converter/quantizer/quantize_util.h  +5 -5
  18. mindspore/lite/tools/converter/quantizer/weight_quantizer.cc  +27 -10
  19. mindspore/lite/tools/converter/quantizer/weight_quantizer.h  +3 -2

mindspore/lite/src/lite_kernel.cc  +0 -54

@@ -180,58 +180,4 @@ void LiteKernelUtil::InitTensorRefCount(std::vector<kernel::LiteKernel *> &kerne
}

int LiteKernelUtil::SetInput(LiteKernel &kernelMod, std::vector<lite::Tensor *> inputs) { return -1; }

float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
MS_ASSERT(input_tensor != nullptr);
if (input_tensor->data_type() != kNumberTypeInt8) {
MS_LOG(ERROR) << "conv weight input type error" << input_tensor->data_type();
return nullptr;
}
if (input_tensor->GetQuantParams().empty()) {
MS_LOG(ERROR) << "no quant param";
return nullptr;
}
const auto *quant_datas = static_cast<const int8_t *>(input_tensor->MutableData());
auto *dequant_datas = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float)));
if (dequant_datas == nullptr) {
MS_LOG(ERROR) << "malloc faile";
return nullptr;
}

if (input_tensor->GetQuantParams().size() != kPerTensor) {
size_t channels = static_cast<size_t>(input_tensor->Batch());
if (input_tensor->GetQuantParams().size() != channels) {
MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels;
free(dequant_datas);
return nullptr;
}
size_t per_channel_size = input_tensor->ElementsNum() / channels;
auto quant_param = input_tensor->GetQuantParams();
for (size_t i = 0; i < channels; i++) {
auto param = quant_param.at(i);
auto scale = param.scale;
auto zero_point = param.zeroPoint;
auto var_corr = param.var_corr;
auto mean_corr = param.mean_corr;
if (var_corr < 0 || var_corr > 10) {
MS_LOG(WARNING) << "unexpeted var_corr: " << var_corr;
var_corr = 1;
}
for (size_t j = 0; j < per_channel_size; j++) {
auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale;
dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant_data * var_corr + mean_corr);
}
}
} else {
auto quant_param = input_tensor->GetQuantParams();
auto param = quant_param.front();
auto scale = param.scale;
auto zero_point = param.zeroPoint;
for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) {
dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale);
}
}

return dequant_datas;
}
} // namespace mindspore::kernel

mindspore/lite/src/lite_kernel.h  +1 -4

@@ -16,8 +16,8 @@

#ifndef MINDSPORE_LITE_SRC_LITE_KERNEL_H_
#define MINDSPORE_LITE_SRC_LITE_KERNEL_H_
#include <vector>
#include <string>
#include <vector>
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/common/utils.h"
@@ -31,7 +31,6 @@

static constexpr int kPerTensor = 1;

// using mindspore::kernel::AddressPtr;
namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
@@ -212,8 +211,6 @@ class LiteKernelUtil {
static void InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels);

static int SetInput(LiteKernel &kernelMod, std::vector<lite::Tensor *> inputs);

static float *DequantWeight(lite::Tensor *input_tensor);
};
} // namespace mindspore::kernel



mindspore/lite/src/runtime/kernel/arm/base/dequant.cc  +35 -0

@@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/dequant.h"

namespace mindspore::kernel {
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
MS_ASSERT(input_tensor != nullptr);
if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
return nullptr;
}
if (input_tensor->GetQuantParams().empty()) {
MS_LOG(ERROR) << "No quant param.";
return nullptr;
}
if (input_tensor->data_type() == kNumberTypeInt16) {
return DequantData<int16_t>(input_tensor);
} else {
return DequantData<int8_t>(input_tensor);
}
}
} // namespace mindspore::kernel

mindspore/lite/src/runtime/kernel/arm/base/dequant.h  +80 -0

@@ -0,0 +1,80 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_

#include <vector>
#include "src/lite_kernel.h"
#include "src/common/utils.h"
#include "src/tensor.h"

namespace mindspore::kernel {
class DequantUtil {
public:
static float *DequantWeight(lite::Tensor *input_tensor);

template <typename T>
static float *DequantData(lite::Tensor *input_tensor) {
const auto *quant_datas = static_cast<const T *>(input_tensor->MutableData());
if (quant_datas == nullptr) {
MS_LOG(ERROR) << "Get quant tensor failed.";
return nullptr;
}
auto *dequant_datas = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float)));
if (dequant_datas == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return nullptr;
}
if (input_tensor->GetQuantParams().size() != kPerTensor) {
size_t channels = static_cast<size_t>(input_tensor->Batch());
if (input_tensor->GetQuantParams().size() != channels) {
MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels;
free(dequant_datas);
return nullptr;
}
size_t per_channel_size = input_tensor->ElementsNum() / channels;
auto quant_param = input_tensor->GetQuantParams();
for (size_t i = 0; i < channels; i++) {
auto param = quant_param.at(i);
auto scale = param.scale;
auto zero_point = param.zeroPoint;
auto var_corr = param.var_corr;
auto mean_corr = param.mean_corr;
if (var_corr < 0 || var_corr > 10) {
MS_LOG(WARNING) << "unexpeted var_corr: " << var_corr;
var_corr = 1;
}
for (size_t j = 0; j < per_channel_size; j++) {
auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale;
dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant_data * var_corr + mean_corr);
}
}
} else {
auto quant_param = input_tensor->GetQuantParams();
auto param = quant_param.front();
auto scale = param.scale;
auto zero_point = param.zeroPoint;
for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) {
dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale);
}
}
return dequant_datas;
}
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
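
For readers skimming the diff, here is a minimal self-contained sketch of the per-tensor dequantization that DequantData<T> performs for both int8 and int16 weights. The SimpleQuantParam type and DequantPerTensor helper below are illustrative stand-ins only, not the real lite::Tensor / quant-param API:

// Standalone sketch only: SimpleQuantParam stands in for the real quant param
// struct; the arithmetic mirrors the per-tensor branch of DequantData<T>.
#include <cstdint>
#include <vector>

struct SimpleQuantParam {
  double scale;
  int32_t zeroPoint;
};

template <typename T>
std::vector<float> DequantPerTensor(const std::vector<T> &quant_datas, const SimpleQuantParam &param) {
  std::vector<float> dequant_datas(quant_datas.size());
  for (size_t j = 0; j < quant_datas.size(); ++j) {
    // Same formula for int8_t and int16_t weights: (q - zero_point) * scale.
    dequant_datas[j] = static_cast<float>((quant_datas[j] - param.zeroPoint) * param.scale);
  }
  return dequant_datas;
}

int main() {
  SimpleQuantParam param{0.5, 0};
  std::vector<int16_t> weights{-32768, 0, 32767};   // 16-bit quantized weights
  auto floats = DequantPerTensor(weights, param);   // {-16384.0f, 0.0f, 16383.5f}
  return floats.size() == 3 ? 0 : 1;
}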

mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc  +2 -1

@@ -20,6 +20,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@@ -64,7 +65,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
// data of second tensor of fc may be nullptr
auto *restore_data = weight_tensor->data_c();
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;


mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc  +6 -3

@@ -19,6 +19,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@@ -35,9 +36,11 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto is_const_quant_weight = (restore_data != nullptr) && (weight_tensor->data_type() == kNumberTypeInt8);
auto is_const_quant_weight =
(restore_data != nullptr) &&
((weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16));
if (is_const_quant_weight) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
@@ -49,7 +52,7 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
auto input_tensor = inputs.at(kInputIndex);
auto data_type = input_tensor->data_type();
kernel::LiteKernel *kernel = nullptr;
if (data_type == kNumberTypeInt8 || data_type == kNumberTypeUInt8) {
if (data_type == kNumberTypeInt8) {
kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc  +4 -5

@@ -22,6 +22,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -145,9 +146,10 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
auto dequant_flag = (weight_tensor->data_type() == kNumberTypeInt8) ? true : false;
auto dequant_flag =
(weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false;
if (dequant_flag) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
@@ -169,7 +171,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
free(opParameter);
@@ -182,14 +183,12 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc  +4 -5

@@ -27,6 +27,7 @@
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "nnacl/fp16/winograd_utils_fp16.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -183,9 +184,10 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
auto dequant_flag = (weight_tensor->data_type() == kNumberTypeInt8) ? true : false;
auto dequant_flag =
(weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false;
if (dequant_flag) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
@@ -224,7 +226,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
free(opParameter);
@@ -237,14 +238,12 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc  +4 -5

@@ -20,6 +20,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -208,9 +209,10 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
auto dequant_flag = (weight_tensor->data_type() == kNumberTypeInt8) ? true : false;
auto dequant_flag =
(weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false;
if (dequant_flag) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
@@ -225,7 +227,6 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
free(opParameter);
@@ -238,14 +239,12 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc  +4 -5

@@ -16,6 +16,7 @@

#include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -215,9 +216,10 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
auto dequant_flag = (weight_tensor->data_type() == kNumberTypeInt8) ? true : false;
auto dequant_flag =
(weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false;
if (dequant_flag) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
@@ -232,7 +234,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
free(opParameter);
@@ -245,14 +246,12 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc  +2 -4

@@ -20,6 +20,7 @@
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@@ -242,7 +243,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
// data of second tensor of fc may be nullptr
auto *restore_data = weight_tensor->data_c();
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
@@ -256,7 +257,6 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
MS_LOG(ERROR) << "kernel is nullptr.";
if (!weight_tensor->GetQuantParams().empty()) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
free(opParameter);
@@ -269,14 +269,12 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
delete kernel;
if (!weight_tensor->GetQuantParams().empty()) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty()) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc  +2 -4

@@ -20,6 +20,7 @@
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@@ -256,7 +257,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
@@ -269,7 +270,6 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "kernel is nullptr.";
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
free(opParameter);
@@ -282,14 +282,12 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
delete kernel;
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc  +6 -5

@@ -23,6 +23,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -186,8 +187,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(op_parameter);
@@ -207,7 +208,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -219,14 +220,14 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}

if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc  +6 -5

@@ -20,6 +20,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -133,8 +134,8 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
@@ -152,7 +153,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -164,14 +165,14 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}

if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc  +6 -5

@@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/fp32/deconvolution.h"
#include "src/runtime/kernel/arm/fp32/deconvolution_winograd.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -238,8 +239,8 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
@@ -260,7 +261,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>

if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -272,14 +273,14 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}

if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc  +6 -5

@@ -19,6 +19,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -201,8 +202,8 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
@@ -214,7 +215,7 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -226,13 +227,13 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8) {
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


mindspore/lite/tools/converter/quantizer/quantize_util.h  +5 -5

@@ -91,12 +91,12 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
const auto numBit = quantParam->numBits;
const auto narrowRange = quantParam->narrowRange;
double maxLimitTemp = static_cast<float>((1 << (unsigned int)numBit) - 1);
const double maxLimit = static_cast<float>(maxLimitTemp - zeroPoint + std::numeric_limits<int8_t>::min()) * scale;
const double maxLimit = static_cast<float>(maxLimitTemp - zeroPoint + std::numeric_limits<T>::min()) * scale;
double minLimit;
if (narrowRange) {
minLimit = static_cast<float>(std::numeric_limits<int8_t>::min() + 1 - zeroPoint) * scale;
minLimit = static_cast<float>(std::numeric_limits<T>::min() + 1 - zeroPoint) * scale;
} else {
minLimit = static_cast<float>(std::numeric_limits<int8_t>::min() - zeroPoint) * scale;
minLimit = static_cast<float>(std::numeric_limits<T>::min() - zeroPoint) * scale;
}

return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
@@ -244,7 +244,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
}
quant_params.emplace_back(quant_param);
}
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
@@ -273,7 +273,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
quant_datas[i] = quant_data;
}
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
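
One behavioural note on the two memcpy_s changes above, illustrated with a small sketch (the CopyQuantized helper below is hypothetical, not converter code): the byte count must scale with sizeof(T), otherwise a 16-bit quantized buffer would copy only half of its data.

// Sketch: the copy size has to be elem_count * sizeof(T); with int16_t each
// element occupies two bytes, so the old sizeof(int8_t) count would truncate.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

template <typename T>
bool CopyQuantized(void *raw_datas, size_t tensor_size, const std::vector<T> &quant_datas) {
  const size_t bytes = quant_datas.size() * sizeof(T);  // 8 bytes for four int16_t values
  if (bytes > tensor_size) {
    return false;  // destination buffer too small
  }
  std::memcpy(raw_datas, quant_datas.data(), bytes);
  return true;
}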


mindspore/lite/tools/converter/quantizer/weight_quantizer.cc  +27 -10

@@ -48,8 +48,8 @@ STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
MS_LOG(ERROR) << "quantSize must be valid pos num.";
return RET_ERROR;
}
if (!WeightQuantizer::IsPosNum(config->bitNum) || config->bitNum != "8") {
MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 bit weight quant.";
if (!WeightQuantizer::IsPosNum(config->bitNum) || (config->bitNum != "8" && config->bitNum != "16")) {
MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 or 16 bit weight quant.";
return RET_ERROR;
}
return RET_OK;
@@ -61,6 +61,13 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
this->bitNum = static_cast<size_t>(std::stoull(bitNum));
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold));
quant_max = (1 << (unsigned int)(this->bitNum - 1)) - 1;
quant_min = -(1 << (unsigned int)(this->bitNum - 1));
if (this->bitNum == 8) {
type_id = kNumberTypeInt8;
} else if (this->bitNum == 16) {
type_id = kNumberTypeInt16;
}
}

STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
@@ -96,14 +103,19 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {

std::vector<schema::QuantParamT> quant_params;
primitive_c->AddInputQuantParam(quant_params);
auto status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
auto status = RET_ERROR;
if (type_id == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
} else if (type_id == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
// set dtype
param_value->set_tensor_type(kNumberTypeInt8);
param_value->set_tensor_type(type_id);
auto abstractBase = param_node->abstract();
if (abstractBase == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
@@ -114,7 +126,7 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
return RET_ERROR;
}
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt8));
abstractTensor->element()->set_type(TypeIdToType(type_id));
primitive_c->SetQuantType(schema::QuantType_WeightQuant);
}

@@ -159,13 +171,18 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {

std::vector<schema::QuantParamT> quant_params;
primitive_c->AddInputQuantParam(quant_params);
auto status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
auto status = RET_ERROR;
if (type_id == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
} else if (type_id == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
param_value->set_tensor_type(kNumberTypeInt8);
param_value->set_tensor_type(type_id);
// set dtype
auto abstractBase = param_node->abstract();
if (abstractBase == nullptr) {
@@ -177,7 +194,7 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
return RET_ERROR;
}
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt8));
abstractTensor->element()->set_type(TypeIdToType(type_id));
primitive_c->SetQuantType(schema::QuantType_WeightQuant);
}
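
As a quick sanity check on the new constructor logic (a self-contained sketch with assumed values, not the converter class itself), the derived clamp range is [-128, 127] for 8-bit and [-32768, 32767] for 16-bit weight quantization:

// Sketch of the range derivation used above: symmetric two's-complement
// bounds for the configured bit width.
#include <cassert>
#include <initializer_list>

int main() {
  for (unsigned int bit_num : {8u, 16u}) {
    const int quant_max = (1 << (bit_num - 1)) - 1;  // 127 for 8 bit, 32767 for 16 bit
    const int quant_min = -(1 << (bit_num - 1));     // -128 for 8 bit, -32768 for 16 bit
    assert(quant_max == -(quant_min + 1));
  }
  return 0;
}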



mindspore/lite/tools/converter/quantizer/weight_quantizer.h  +3 -2

@@ -43,8 +43,9 @@ class WeightQuantizer : public Quantizer {
STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
static STATUS WeightQuantInputCheck(const converter::Flags *config);
static bool IsPosNum(const std::string &str);
int quant_max{INT8_MAX};
int quant_min{INT8_MIN};
int quant_max;
int quant_min;
TypeId type_id{kTypeUnknown};

private:
std::unique_ptr<QuantStrategy> mStrategy;

