Merge pull request !7365 from ghzl/support-int16-weight-quant (tag: v1.1.0)
| @@ -180,58 +180,4 @@ void LiteKernelUtil::InitTensorRefCount(std::vector<kernel::LiteKernel *> &kerne | |||
| } | |||
| int LiteKernelUtil::SetInput(LiteKernel &kernelMod, std::vector<lite::Tensor *> inputs) { return -1; } | |||
| float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) { | |||
| MS_ASSERT(input_tensor != nullptr); | |||
| if (input_tensor->data_type() != kNumberTypeInt8) { | |||
| MS_LOG(ERROR) << "conv weight input type error" << input_tensor->data_type(); | |||
| return nullptr; | |||
| } | |||
| if (input_tensor->GetQuantParams().empty()) { | |||
| MS_LOG(ERROR) << "no quant param"; | |||
| return nullptr; | |||
| } | |||
| const auto *quant_datas = static_cast<const int8_t *>(input_tensor->MutableData()); | |||
| auto *dequant_datas = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float))); | |||
| if (dequant_datas == nullptr) { | |||
| MS_LOG(ERROR) << "malloc faile"; | |||
| return nullptr; | |||
| } | |||
| if (input_tensor->GetQuantParams().size() != kPerTensor) { | |||
| size_t channels = static_cast<size_t>(input_tensor->Batch()); | |||
| if (input_tensor->GetQuantParams().size() != channels) { | |||
| MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels; | |||
| free(dequant_datas); | |||
| return nullptr; | |||
| } | |||
| size_t per_channel_size = input_tensor->ElementsNum() / channels; | |||
| auto quant_param = input_tensor->GetQuantParams(); | |||
| for (size_t i = 0; i < channels; i++) { | |||
| auto param = quant_param.at(i); | |||
| auto scale = param.scale; | |||
| auto zero_point = param.zeroPoint; | |||
| auto var_corr = param.var_corr; | |||
| auto mean_corr = param.mean_corr; | |||
| if (var_corr < 0 || var_corr > 10) { | |||
| MS_LOG(WARNING) << "unexpeted var_corr: " << var_corr; | |||
| var_corr = 1; | |||
| } | |||
| for (size_t j = 0; j < per_channel_size; j++) { | |||
| auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale; | |||
| dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant_data * var_corr + mean_corr); | |||
| } | |||
| } | |||
| } else { | |||
| auto quant_param = input_tensor->GetQuantParams(); | |||
| auto param = quant_param.front(); | |||
| auto scale = param.scale; | |||
| auto zero_point = param.zeroPoint; | |||
| for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) { | |||
| dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale); | |||
| } | |||
| } | |||
| return dequant_datas; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -16,8 +16,8 @@ | |||
| #ifndef MINDSPORE_LITE_SRC_LITE_KERNEL_H_ | |||
| #define MINDSPORE_LITE_SRC_LITE_KERNEL_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/common/utils.h" | |||
| @@ -31,7 +31,6 @@ | |||
| static constexpr int kPerTensor = 1; | |||
| // using mindspore::kernel::AddressPtr; | |||
| namespace mindspore::kernel { | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| @@ -212,8 +211,6 @@ class LiteKernelUtil { | |||
| static void InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels); | |||
| static int SetInput(LiteKernel &kernelMod, std::vector<lite::Tensor *> inputs); | |||
| static float *DequantWeight(lite::Tensor *input_tensor); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| namespace mindspore::kernel { | |||
| float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | |||
| MS_ASSERT(input_tensor != nullptr); | |||
| if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) { | |||
| MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type(); | |||
| return nullptr; | |||
| } | |||
| if (input_tensor->GetQuantParams().empty()) { | |||
| MS_LOG(ERROR) << "No quant param."; | |||
| return nullptr; | |||
| } | |||
| if (input_tensor->data_type() == kNumberTypeInt16) { | |||
| return DequantData<int16_t>(input_tensor); | |||
| } else { | |||
| return DequantData<int8_t>(input_tensor); | |||
| } | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,80 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/tensor.h" | |||
| namespace mindspore::kernel { | |||
| class DequantUtil { | |||
| public: | |||
| static float *DequantWeight(lite::Tensor *input_tensor); | |||
| template <typename T> | |||
| static float *DequantData(lite::Tensor *input_tensor) { | |||
| const auto *quant_datas = static_cast<const T *>(input_tensor->MutableData()); | |||
| if (quant_datas == nullptr) { | |||
| MS_LOG(ERROR) << "Get quant tensor failed."; | |||
| return nullptr; | |||
| } | |||
| auto *dequant_datas = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float))); | |||
| if (dequant_datas == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc failed."; | |||
| return nullptr; | |||
| } | |||
| if (input_tensor->GetQuantParams().size() != kPerTensor) { | |||
| size_t channels = static_cast<size_t>(input_tensor->Batch()); | |||
| if (input_tensor->GetQuantParams().size() != channels) { | |||
| MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels; | |||
| free(dequant_datas); | |||
| return nullptr; | |||
| } | |||
| size_t per_channel_size = input_tensor->ElementsNum() / channels; | |||
| auto quant_param = input_tensor->GetQuantParams(); | |||
| for (size_t i = 0; i < channels; i++) { | |||
| auto param = quant_param.at(i); | |||
| auto scale = param.scale; | |||
| auto zero_point = param.zeroPoint; | |||
| auto var_corr = param.var_corr; | |||
| auto mean_corr = param.mean_corr; | |||
| if (var_corr < 0 || var_corr > 10) { | |||
| MS_LOG(WARNING) << "unexpeted var_corr: " << var_corr; | |||
| var_corr = 1; | |||
| } | |||
| for (size_t j = 0; j < per_channel_size; j++) { | |||
| auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale; | |||
| dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant_data * var_corr + mean_corr); | |||
| } | |||
| } | |||
| } else { | |||
| auto quant_param = input_tensor->GetQuantParams(); | |||
| auto param = quant_param.front(); | |||
| auto scale = param.scale; | |||
| auto zero_point = param.zeroPoint; | |||
| for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) { | |||
| dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale); | |||
| } | |||
| } | |||
| return dequant_datas; | |||
| } | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | |||
| @@ -20,6 +20,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -64,7 +65,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||
| // data of second tensor of fc may be nullptr | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return nullptr; | |||
| @@ -19,6 +19,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -35,9 +36,11 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto is_const_quant_weight = (restore_data != nullptr) && (weight_tensor->data_type() == kNumberTypeInt8); | |||
| auto is_const_quant_weight = | |||
| (restore_data != nullptr) && | |||
| ((weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16)); | |||
| if (is_const_quant_weight) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| @@ -49,7 +52,7 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in | |||
| auto input_tensor = inputs.at(kInputIndex); | |||
| auto data_type = input_tensor->data_type(); | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| if (data_type == kNumberTypeInt8 || data_type == kNumberTypeUInt8) { | |||
| if (data_type == kNumberTypeInt8) { | |||
| kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } else { | |||
| kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| @@ -22,6 +22,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -145,9 +146,10 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| auto dequant_flag = (weight_tensor->data_type() == kNumberTypeInt8) ? true : false; | |||
| auto dequant_flag = | |||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| @@ -169,7 +171,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| free(opParameter); | |||
| @@ -182,14 +183,12 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -27,6 +27,7 @@ | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "nnacl/fp16/winograd_utils_fp16.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -183,9 +184,10 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| auto dequant_flag = (weight_tensor->data_type() == kNumberTypeInt8) ? true : false; | |||
| auto dequant_flag = | |||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| @@ -224,7 +226,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||
| MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| free(opParameter); | |||
| @@ -237,14 +238,12 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -20,6 +20,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -208,9 +209,10 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| auto dequant_flag = (weight_tensor->data_type() == kNumberTypeInt8) ? true : false; | |||
| auto dequant_flag = | |||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| @@ -225,7 +227,6 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| free(opParameter); | |||
| @@ -238,14 +239,12 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -16,6 +16,7 @@ | |||
| #include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -215,9 +216,10 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| auto dequant_flag = (weight_tensor->data_type() == kNumberTypeInt8) ? true : false; | |||
| auto dequant_flag = | |||
| (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) ? true : false; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| @@ -232,7 +234,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| free(opParameter); | |||
| @@ -245,14 +246,12 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -20,6 +20,7 @@ | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -242,7 +243,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| // data of second tensor of fc may be nullptr | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| @@ -256,7 +257,6 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| free(opParameter); | |||
| @@ -269,14 +269,12 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| delete kernel; | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -20,6 +20,7 @@ | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -256,7 +257,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return nullptr; | |||
| @@ -269,7 +270,6 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| free(opParameter); | |||
| @@ -282,14 +282,12 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| delete kernel; | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -23,6 +23,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -186,8 +187,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(op_parameter); | |||
| @@ -207,7 +208,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||
| } | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -219,14 +220,14 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -133,8 +134,8 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| @@ -152,7 +153,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| } | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -164,14 +165,14 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| #include "src/runtime/kernel/arm/fp32/deconvolution.h" | |||
| #include "src/runtime/kernel/arm/fp32/deconvolution_winograd.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -238,8 +239,8 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| @@ -260,7 +261,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -272,14 +273,14 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -19,6 +19,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -201,8 +202,8 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->MutableData(); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| @@ -214,7 +215,7 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||
| new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -226,13 +227,13 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8) { | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| @@ -91,12 +91,12 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { | |||
| const auto numBit = quantParam->numBits; | |||
| const auto narrowRange = quantParam->narrowRange; | |||
| double maxLimitTemp = static_cast<float>((1 << (unsigned int)numBit) - 1); | |||
| const double maxLimit = static_cast<float>(maxLimitTemp - zeroPoint + std::numeric_limits<int8_t>::min()) * scale; | |||
| const double maxLimit = static_cast<float>(maxLimitTemp - zeroPoint + std::numeric_limits<T>::min()) * scale; | |||
| double minLimit; | |||
| if (narrowRange) { | |||
| minLimit = static_cast<float>(std::numeric_limits<int8_t>::min() + 1 - zeroPoint) * scale; | |||
| minLimit = static_cast<float>(std::numeric_limits<T>::min() + 1 - zeroPoint) * scale; | |||
| } else { | |||
| minLimit = static_cast<float>(std::numeric_limits<int8_t>::min() - zeroPoint) * scale; | |||
| minLimit = static_cast<float>(std::numeric_limits<T>::min() - zeroPoint) * scale; | |||
| } | |||
| return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { | |||
| @@ -244,7 +244,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||
| } | |||
| quant_params.emplace_back(quant_param); | |||
| } | |||
| auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); | |||
| auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||
| return RET_ERROR; | |||
| @@ -273,7 +273,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||
| auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min); | |||
| quant_datas[i] = quant_data; | |||
| } | |||
| auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); | |||
| auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||
| return RET_ERROR; | |||
| @@ -48,8 +48,8 @@ STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { | |||
| MS_LOG(ERROR) << "quantSize must be valid pos num."; | |||
| return RET_ERROR; | |||
| } | |||
| if (!WeightQuantizer::IsPosNum(config->bitNum) || config->bitNum != "8") { | |||
| MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 bit weight quant."; | |||
| if (!WeightQuantizer::IsPosNum(config->bitNum) || (config->bitNum != "8" && config->bitNum != "16")) { | |||
| MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 or 16 bit weight quant."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| @@ -61,6 +61,13 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, | |||
| this->bitNum = static_cast<size_t>(std::stoull(bitNum)); | |||
| auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold)); | |||
| mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold)); | |||
| quant_max = (1 << (unsigned int)(this->bitNum - 1)) - 1; | |||
| quant_min = -(1 << (unsigned int)(this->bitNum - 1)); | |||
| if (this->bitNum == 8) { | |||
| type_id = kNumberTypeInt8; | |||
| } else if (this->bitNum == 16) { | |||
| type_id = kNumberTypeInt16; | |||
| } | |||
| } | |||
| STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||
| @@ -96,14 +103,19 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||
| std::vector<schema::QuantParamT> quant_params; | |||
| primitive_c->AddInputQuantParam(quant_params); | |||
| auto status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| auto status = RET_ERROR; | |||
| if (type_id == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| } else if (type_id == kNumberTypeInt16) { | |||
| status = | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| // set dtype | |||
| param_value->set_tensor_type(kNumberTypeInt8); | |||
| param_value->set_tensor_type(type_id); | |||
| auto abstractBase = param_node->abstract(); | |||
| if (abstractBase == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||
| @@ -114,7 +126,7 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||
| return RET_ERROR; | |||
| } | |||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||
| abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt8)); | |||
| abstractTensor->element()->set_type(TypeIdToType(type_id)); | |||
| primitive_c->SetQuantType(schema::QuantType_WeightQuant); | |||
| } | |||
| @@ -159,13 +171,18 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| std::vector<schema::QuantParamT> quant_params; | |||
| primitive_c->AddInputQuantParam(quant_params); | |||
| auto status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| auto status = RET_ERROR; | |||
| if (type_id == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| } else if (type_id == kNumberTypeInt16) { | |||
| status = | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| param_value->set_tensor_type(kNumberTypeInt8); | |||
| param_value->set_tensor_type(type_id); | |||
| // set dtype | |||
| auto abstractBase = param_node->abstract(); | |||
| if (abstractBase == nullptr) { | |||
| @@ -177,7 +194,7 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| return RET_ERROR; | |||
| } | |||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||
| abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt8)); | |||
| abstractTensor->element()->set_type(TypeIdToType(type_id)); | |||
| primitive_c->SetQuantType(schema::QuantType_WeightQuant); | |||
| } | |||
| @@ -43,8 +43,9 @@ class WeightQuantizer : public Quantizer { | |||
| STATUS DoMulQuantize(const std::list<CNodePtr> &nodes); | |||
| static STATUS WeightQuantInputCheck(const converter::Flags *config); | |||
| static bool IsPosNum(const std::string &str); | |||
| int quant_max{INT8_MAX}; | |||
| int quant_min{INT8_MIN}; | |||
| int quant_max; | |||
| int quant_min; | |||
| TypeId type_id{kTypeUnknown}; | |||
| private: | |||
| std::unique_ptr<QuantStrategy> mStrategy; | |||