From: @xutianchun Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -39,6 +39,7 @@ set(LITE_SRC | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/dequant.cc | |||
| ) | |||
| if (SUPPORT_GPU) | |||
| @@ -14,9 +14,9 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <cmath> | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| #include "src/dequant.h" | |||
| namespace mindspore::kernel { | |||
| namespace mindspore::lite { | |||
| float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | |||
| MS_ASSERT(input_tensor != nullptr); | |||
| if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) { | |||
| @@ -35,6 +35,8 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | |||
| } | |||
| void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { | |||
| MS_ASSERT(input_tensor != nullptr); | |||
| MS_ASSERT(unpack_int_data != nullptr); | |||
| auto quant_params = input_tensor->quantParams(); | |||
| if (quant_params == nullptr) { | |||
| MS_LOG(ERROR) << "low bits quantparams is empty."; | |||
| @@ -47,4 +49,41 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i | |||
| UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data); | |||
| } | |||
| } | |||
| } // namespace mindspore::kernel | |||
| std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors, | |||
| TypeId data_type) { | |||
| std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data; | |||
| if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) { | |||
| for (auto weight_tensor : in_tensors) { | |||
| MS_ASSERT(weight_tensor != nullptr); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && | |||
| restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return tensor_origin_data; | |||
| } | |||
| weight_tensor->set_data(dequant_weight); | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| tensor_origin_data[weight_tensor] = {restore_type, restore_data}; | |||
| } | |||
| } | |||
| } | |||
| return tensor_origin_data; | |||
| } | |||
| void DequantUtil::RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map) { | |||
| for (auto &kv : tensor_origin_data_map) { | |||
| auto *tensor = kv.first; | |||
| auto type_id = kv.second.first; | |||
| auto data = kv.second.second; | |||
| tensor->FreeData(); | |||
| tensor->set_data_type(type_id); | |||
| tensor->set_data(data); | |||
| } | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -17,6 +17,8 @@ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | |||
| #include <map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <queue> | |||
| #include <cmath> | |||
| @@ -24,13 +26,18 @@ | |||
| #include "src/common/utils.h" | |||
| #include "src/tensor.h" | |||
| namespace mindspore::kernel { | |||
| namespace mindspore::lite { | |||
| class DequantUtil { | |||
| public: | |||
| static float *DequantWeight(lite::Tensor *input_tensor); | |||
| static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); | |||
| static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors, | |||
| TypeId data_type); | |||
| static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map); | |||
| template <typename ST, typename DT = float> | |||
| static DT *DequantData(lite::Tensor *input_tensor) { | |||
| const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData()); | |||
| @@ -108,7 +115,7 @@ class DequantUtil { | |||
| static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int, | |||
| size_t *count, bool is_last) { | |||
| T2 uint_result = 0; | |||
| T1 result = 0; | |||
| T1 result; | |||
| UnPackFromUintToOrigin<T2>(packed_data, unpack_bit_data); | |||
| while (static_cast<int>(unpack_bit_data->size()) >= origin_bit) { | |||
| for (int k = 0; k < origin_bit; k++) { | |||
| @@ -163,6 +170,6 @@ class DequantUtil { | |||
| } | |||
| } | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | |||
| @@ -27,7 +27,7 @@ | |||
| #include "src/common/graph_util.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/lite_model.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| #include "src/dequant.h" | |||
| #if SUPPORT_NPU | |||
| #include "src/runtime/agent/npu/npu_manager.h" | |||
| #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" | |||
| @@ -120,7 +120,7 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde | |||
| MS_LOG(ERROR) << "Malloc data for tensor failed "; | |||
| return RET_ERROR; | |||
| } | |||
| kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData()); | |||
| DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData()); | |||
| copyed_tensor_idxes_.emplace_back(tensor_index); | |||
| } else { | |||
| dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data())); | |||
| @@ -25,7 +25,6 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -359,22 +358,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| if (conv_param->group_ == 1) { | |||
| @@ -385,11 +368,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||
| if (kernel == nullptr) { | |||
| MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -398,20 +376,9 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||
| if (ret != RET_OK) { | |||
| MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ | |||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator) | |||
| @@ -22,7 +22,6 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -138,22 +137,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||
| kernel::LiteKernel *kernel; | |||
| if (conv_param->input_channel_ < 32) { | |||
| @@ -164,11 +147,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| } | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -176,19 +154,9 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -20,7 +20,6 @@ | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| #include "nnacl/fp16/conv_fp16.h" | |||
| #include "nnacl/fp16/matmul_fp16.h" | |||
| #include "nnacl/fp16/cast_fp16.h" | |||
| @@ -20,7 +20,6 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -212,30 +211,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| auto dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -243,19 +221,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -17,7 +17,6 @@ | |||
| #include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h" | |||
| #include "src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -220,22 +219,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| auto dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| kernel::LiteKernel *kernel; | |||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||
| if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) && | |||
| @@ -247,11 +230,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -259,19 +237,9 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator) | |||
| @@ -234,30 +234,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| OpParameter *opParameter, const lite::InnerContext *ctx, | |||
| const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| // data of second tensor of fc may be nullptr | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -265,19 +244,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -24,7 +24,6 @@ | |||
| #include "nnacl/fp16/matmul_fp16.h" | |||
| #include "nnacl/fp16/cast_fp16.h" | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| namespace mindspore::kernel { | |||
| class FullconnectionFP16CPUKernel : public LiteKernel { | |||
| @@ -20,7 +20,6 @@ | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -330,29 +329,9 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -361,18 +340,8 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -22,7 +22,6 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -356,22 +355,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); | |||
| MS_ASSERT(desc.data_type == kNumberTypeFloat32); | |||
| // if get quantized weight, dequantize it to float32 type data. | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(op_parameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| if (conv_param->group_ == 1) { | |||
| @@ -382,11 +365,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(op_parameter); | |||
| return nullptr; | |||
| } | |||
| @@ -395,20 +373,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||
| if (ret != RET_OK && ret != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -21,7 +21,6 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -126,19 +125,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| if (primitive != nullptr && primitive->infer_flag()) { | |||
| @@ -162,11 +148,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| } | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -174,21 +155,10 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| if (ret != RET_OK && ret != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -19,7 +19,6 @@ | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| #include "nnacl/fp32/conv_fp32.h" | |||
| #include "nnacl/fp32/matmul_fp32.h" | |||
| @@ -19,7 +19,6 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -202,29 +201,10 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| auto kernel = | |||
| new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -232,19 +212,9 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -17,7 +17,6 @@ | |||
| #include "src/runtime/kernel/arm/fp32/deconvolution_fp32.h" | |||
| #include "src/runtime/kernel/arm/fp32/deconvolution_winograd_fp32.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -240,20 +239,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| kernel::LiteKernel *kernel; | |||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||
| @@ -266,11 +251,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -278,21 +258,9 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -228,28 +228,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| // data of second tensor of fc may be nullptr | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (!kernel) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -257,19 +238,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -22,7 +22,6 @@ | |||
| #include "include/errorcode.h" | |||
| #include "nnacl/fp32/matmul_fp32.h" | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::lite::InnerContext; | |||
| namespace mindspore::kernel { | |||
| @@ -19,7 +19,6 @@ | |||
| #include "nnacl/fp32/matmul_fp32.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_INPUT_TENSOR_ERROR; | |||
| @@ -417,30 +416,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_MatMul); | |||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| auto restore_type = weight_tensor->data_type(); | |||
| bool dequant_flag = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag) { | |||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||
| if (dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| auto kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| @@ -448,21 +426,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| if (dequant_flag) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data(restore_data); | |||
| weight_tensor->set_data_type(restore_type); | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| #include "mindspore/lite/src/dequant.h" | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| @@ -263,10 +263,10 @@ int OpenCLKernel::DequantWeight() { | |||
| if (is_fp16) { | |||
| #ifdef ENABLE_ARM64 | |||
| if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { | |||
| dequant_weight = kernel::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor); | |||
| dequant_weight = lite::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor); | |||
| weight_tensor->set_data_type(kNumberTypeFloat16); | |||
| } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { | |||
| dequant_weight = kernel::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor); | |||
| dequant_weight = lite::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor); | |||
| weight_tensor->set_data_type(kNumberTypeFloat16); | |||
| } else { | |||
| set_flag = false; | |||
| @@ -276,10 +276,10 @@ int OpenCLKernel::DequantWeight() { | |||
| #endif | |||
| } else { | |||
| if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { | |||
| dequant_weight = kernel::DequantUtil::DequantData<int8_t, float>(weight_tensor); | |||
| dequant_weight = lite::DequantUtil::DequantData<int8_t, float>(weight_tensor); | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { | |||
| dequant_weight = kernel::DequantUtil::DequantData<int16_t, float>(weight_tensor); | |||
| dequant_weight = lite::DequantUtil::DequantData<int16_t, float>(weight_tensor); | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| } else { | |||
| set_flag = false; | |||
| @@ -25,7 +25,7 @@ | |||
| #include "src/lite_kernel.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| #include "mindspore/lite/src/dequant.h" | |||
| #include "src/runtime/kernel/opencl/utils.h" | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -27,6 +27,7 @@ | |||
| #include "src/common/utils.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/sub_graph_kernel.h" | |||
| #include "src/dequant.h" | |||
| #if SUPPORT_GPU | |||
| #include "src/runtime/kernel/opencl/opencl_subgraph.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| @@ -213,8 +214,10 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||
| if (mindspore::lite::IsSupportFloat16() && | |||
| ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { | |||
| kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; | |||
| auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type); | |||
| auto *kernel = | |||
| KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); | |||
| DequantUtil::RestoreTensorData(tensor_origin_data_map); | |||
| if (kernel != nullptr) { | |||
| MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " " | |||
| << node->name_; | |||
| @@ -225,7 +228,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||
| MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; | |||
| desc.data_type = kNumberTypeFloat32; | |||
| } | |||
| auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type); | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||
| DequantUtil::RestoreTensorData(tensor_origin_data_map); | |||
| if (kernel != nullptr) { | |||
| return kernel; | |||
| } | |||
| @@ -126,6 +126,7 @@ set(TEST_LITE_SRC | |||
| ${LITE_DIR}/src/kernel_registry.cc | |||
| ${LITE_DIR}/src/lite_kernel.cc | |||
| ${LITE_DIR}/src/lite_session.cc | |||
| ${LITE_DIR}/src/dequant.cc | |||
| ${LITE_DIR}/src/sub_graph_kernel.cc | |||
| ${LITE_DIR}/src/lite_model.cc | |||
| ${LITE_DIR}/src/scheduler.cc | |||
| @@ -95,6 +95,7 @@ set(LITE_SRC | |||
| ${SRC_DIR}/executor.cc | |||
| ${SRC_DIR}/lite_model.cc | |||
| ${SRC_DIR}/errorcode.cc | |||
| ${SRC_DIR}/dequant.cc | |||
| ) | |||
| if (SUPPORT_TRAIN) | |||
| set(LITE_SRC | |||
| @@ -782,4 +782,27 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) { | |||
| return new_func_graph; | |||
| } | |||
| void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(param_node != nullptr); | |||
| MS_ASSERT(param_value != nullptr); | |||
| auto op_name = node->fullname_with_scope(); | |||
| *param_node = node->cast<ParameterPtr>(); | |||
| if (*param_node == nullptr) { | |||
| MS_LOG(INFO) << op_name << " can not cast to ParameterPtr"; | |||
| return; | |||
| } | |||
| if (!(*param_node)->has_default()) { | |||
| MS_LOG(INFO) << op_name << " not has_default"; | |||
| return; | |||
| } | |||
| *param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param()); | |||
| if (*param_value == nullptr) { | |||
| MS_LOG(INFO) << "default_param can not cast to ParamValueLite"; | |||
| return; | |||
| } | |||
| } | |||
| } // namespace mindspore::lite::quant | |||
| @@ -75,9 +75,10 @@ class QuantStrategy { | |||
| bool CanMulOpQuantized(const CNodePtr &node) const; | |||
| bool CanOpPostQuantized(AnfNodePtr &node) const; | |||
| private: | |||
| size_t mWeightSize; | |||
| size_t mConvWeightQuantChannelThreshold; | |||
| private: | |||
| static const std::vector<schema::PrimitiveType> conv_types; | |||
| static const std::vector<schema::PrimitiveType> mul_types; | |||
| }; | |||
| @@ -356,5 +357,8 @@ STATUS CopyInputDataToTensor(size_t input_index, size_t image_index, | |||
| const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor); | |||
| FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &); | |||
| void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value); | |||
| } // namespace mindspore::lite::quant | |||
| #endif | |||
| @@ -20,7 +20,6 @@ | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "src/common/common.h" | |||
| #include "ir/dtype/type_id.h" | |||
| using std::string; | |||
| using std::vector; | |||
| @@ -73,13 +72,13 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_f | |||
| this->bit_num_ = static_cast<size_t>(std::stoull(bitNum)); | |||
| auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold)); | |||
| quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); | |||
| quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; | |||
| quant_min = -(1 << (unsigned int)(this->bit_num_ - 1)); | |||
| // parse type_id | |||
| quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; | |||
| quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1)); | |||
| // parse type_id_ | |||
| if (this->bit_num_ > 0 && this->bit_num_ <= 8) { | |||
| type_id = kNumberTypeInt8; | |||
| type_id_ = kNumberTypeInt8; | |||
| } else if (this->bit_num_ <= 16) { | |||
| type_id = kNumberTypeInt16; | |||
| type_id_ = kNumberTypeInt16; | |||
| } else { | |||
| MS_LOG(ERROR) << "invalid input bits"; | |||
| } | |||
| @@ -90,7 +89,7 @@ WeightQuantizer::~WeightQuantizer() { delete fp32_session_; } | |||
| STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, | |||
| std::shared_ptr<PrimitiveC> primitive_c) { | |||
| // set dtype | |||
| param_value->set_tensor_type(type_id); | |||
| param_value->set_tensor_type(type_id_); | |||
| auto abstract_base = param_node->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||
| @@ -101,49 +100,158 @@ STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||
| abstract_tensor->element()->set_type(TypeIdToType(type_id)); | |||
| abstract_tensor->element()->set_type(TypeIdToType(type_id_)); | |||
| primitive_c->set_quant_type(schema::QuantType_WeightQuant); | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||
| for (auto &cnode : nodes) { | |||
| if (!quant_strategy_->CanConvOpQuantized(cnode)) { | |||
| continue; | |||
| } | |||
| STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) { | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto input_node = cnode->input(2); | |||
| if (!input_node->isa<Parameter>()) { | |||
| return RET_ERROR; | |||
| } | |||
| auto input_node = cnode->input(2); | |||
| if (!input_node->isa<Parameter>()) { | |||
| return RET_ERROR; | |||
| } | |||
| ParameterPtr param_node; | |||
| ParamValueLitePtr param_value; | |||
| auto param_node = input_node->cast<ParameterPtr>(); | |||
| if (!param_node->has_default()) { | |||
| return RET_ERROR; | |||
| GetLiteParameter(input_node, ¶m_node, ¶m_value); | |||
| if (param_node == nullptr || param_value == nullptr) { | |||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { | |||
| MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type(); | |||
| return RET_ERROR; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { | |||
| auto already_quant = false; | |||
| ParamValueLitePtr param_value = nullptr; | |||
| ParameterPtr param_node = nullptr; | |||
| for (size_t i = 1; i < cnode->size(); i++) { | |||
| auto inputNode = cnode->input(i); | |||
| if (inputNode->isa<Parameter>()) { | |||
| param_node = inputNode->cast<ParameterPtr>(); | |||
| if ((param_node != nullptr) && param_node->has_default()) { | |||
| param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||
| if ((param_value == nullptr) || (param_value->tensor_size() == 0) || (param_value->tensor_addr() == nullptr)) { | |||
| param_value = nullptr; | |||
| continue; | |||
| } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 || | |||
| param_value->tensor_type() == mindspore::kNumberTypeInt16) { | |||
| MS_LOG(INFO) << "the node: " << cnode->fullname_with_scope() << " input_i: " << i << "has been " | |||
| << " quantized"; | |||
| already_quant = true; | |||
| break; | |||
| } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { | |||
| param_value = nullptr; | |||
| continue; | |||
| } else { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||
| if (param_value == nullptr) { | |||
| if (already_quant) { | |||
| return RET_OK; | |||
| } | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "No valid input param node !"; | |||
| return RET_ERROR; | |||
| } | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto op_name = cnode->fullname_with_scope(); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| MS_ASSERT(primitive_c != nullptr); | |||
| if (cnode->inputs().size() < 4) { | |||
| MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size(); | |||
| return RET_ERROR; | |||
| } | |||
| { | |||
| auto weight_i = cnode->input(2); | |||
| ParameterPtr param_node; | |||
| ParamValueLitePtr param_value; | |||
| GetLiteParameter(weight_i, ¶m_node, ¶m_value); | |||
| if (param_node == nullptr || param_value == nullptr) { | |||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { | |||
| MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type(); | |||
| return RET_ERROR; | |||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; | |||
| return RET_OK; | |||
| } | |||
| if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) { | |||
| MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < " | |||
| << quant_strategy_->mWeightSize; | |||
| return RET_OK; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id == kNumberTypeInt8) { | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||
| } else if (type_id == kNumberTypeInt16) { | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| @@ -155,65 +263,26 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| for (auto &node : nodes) { | |||
| if (!quant_strategy_->CanMulOpQuantized(node)) { | |||
| continue; | |||
| } | |||
| auto already_quant = false; | |||
| ParamValueLitePtr param_value = nullptr; | |||
| ParameterPtr param_node = nullptr; | |||
| for (size_t i = 1; i < node->size(); i++) { | |||
| auto inputNode = node->input(i); | |||
| if (inputNode->isa<Parameter>()) { | |||
| param_node = inputNode->cast<ParameterPtr>(); | |||
| if ((param_node != nullptr) && param_node->has_default()) { | |||
| param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||
| if ((param_value == nullptr) || (param_value->tensor_size() == 0) || | |||
| (param_value->tensor_addr() == nullptr)) { | |||
| param_value = nullptr; | |||
| continue; | |||
| } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 || | |||
| param_value->tensor_type() == mindspore::kNumberTypeInt16) { | |||
| MS_LOG(INFO) << "the node: " << node->fullname_with_scope() << " input_i: " << i << "has been " | |||
| << " quantized"; | |||
| already_quant = true; | |||
| break; | |||
| } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { | |||
| param_value = nullptr; | |||
| continue; | |||
| } else { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (already_quant) { | |||
| continue; | |||
| } | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "No valid input param node !"; | |||
| { | |||
| auto weight_h = cnode->input(3); | |||
| ParameterPtr param_node; | |||
| ParamValueLitePtr param_value; | |||
| GetLiteParameter(weight_h, ¶m_node, ¶m_value); | |||
| if (param_node == nullptr || param_value == nullptr) { | |||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||
| return RET_ERROR; | |||
| } | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; | |||
| return RET_ERROR; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id == kNumberTypeInt8) { | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||
| } else if (type_id == kNumberTypeInt16) { | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| @@ -225,7 +294,78 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| { | |||
| if (cnode->inputs().size() > 4) { | |||
| auto bias = cnode->input(4); | |||
| ParameterPtr param_node; | |||
| ParamValueLitePtr param_value; | |||
| GetLiteParameter(bias, ¶m_node, ¶m_value); | |||
| if (param_node == nullptr || param_value == nullptr) { | |||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; | |||
| return RET_ERROR; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| false); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) { | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| MS_ASSERT(primitive_c != nullptr); | |||
| auto weight_h = cnode->input(1); | |||
| ParameterPtr param_node; | |||
| ParamValueLitePtr param_value; | |||
| GetLiteParameter(weight_h, ¶m_node, ¶m_value); | |||
| if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight"; | |||
| return RET_OK; | |||
| } | |||
| if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) { | |||
| MS_LOG(INFO) << cnode->fullname_with_scope() << " param cnt: " << param_value->tensor_size() / 4 << " < " | |||
| << quant_strategy_->mWeightSize; | |||
| return RET_OK; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -315,6 +455,23 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) { | |||
| } | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| for (auto &cnode : cnodes) { | |||
| auto op_type = NodePrimitiveType(cnode); | |||
| if (op_type == schema::PrimitiveType_Lstm) { | |||
| status = DoLstmQuntize(cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoLstmQuntize error"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (op_type == schema::PrimitiveType_Gather) { | |||
| status = DoGatherQuntize(cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoGatherQuntize error"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| for (auto iter = cnodes.end(); iter != cnodes.begin();) { | |||
| auto cnode = *(--iter); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| @@ -357,18 +514,18 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) { | |||
| } | |||
| // 1. try quant | |||
| for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { | |||
| type_id = TypeId::kNumberTypeInt8; | |||
| type_id_ = TypeId::kNumberTypeInt8; | |||
| int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; | |||
| int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); | |||
| if (type_id == TypeId::kNumberTypeInt8) { | |||
| if (type_id_ == TypeId::kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | |||
| quant_min_t, bit_num_t, true); | |||
| } else if (type_id == TypeId::kNumberTypeInt16) { | |||
| } else if (type_id_ == TypeId::kNumberTypeInt16) { | |||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | |||
| quant_min_t, bit_num_t, true); | |||
| } else { | |||
| MS_LOG(ERROR) << "unexpected type_id: " << type_id; | |||
| MS_LOG(ERROR) << "unexpected type_id_: " << type_id_; | |||
| return RET_ERROR; | |||
| } | |||
| if (status != RET_OK) { | |||
| @@ -456,13 +613,53 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) { | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| for (auto &cnode : func_graph->GetOrderedCnodes()) { | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto op_name = cnode->fullname_with_scope(); | |||
| auto op_type = (schema::PrimitiveType)primitive_c->Type(); | |||
| if (quant_strategy_->CanConvOpQuantized(cnode)) { | |||
| auto status = DoConvQuantize(cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoConvQuantize error"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (quant_strategy_->CanMulOpQuantized(cnode)) { | |||
| auto status = DoMulQuantize(cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoMulQuantize error"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (op_type == schema::PrimitiveType_Lstm) { | |||
| auto status = DoLstmQuntize(cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoLstmQuntize error"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (op_type == schema::PrimitiveType_Gather) { | |||
| auto status = DoGatherQuntize(cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoGatherQuntize error"; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << op_name << " of type: " << schema::EnumNamePrimitiveType(op_type) << " no need quant"; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| STATUS ret; | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| if (!config_file_.empty()) { | |||
| ret = ParseConfigFile(config_file_, &config_param_); | |||
| auto ret = ParseConfigFile(config_file_, &config_param_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ReadConfig error."; | |||
| return RET_ERROR; | |||
| @@ -470,20 +667,14 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||
| } | |||
| if (config_param_.mixed) { | |||
| bit_num_ = 8; | |||
| quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; | |||
| quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1)); | |||
| type_id_ = kNumberTypeInt8; | |||
| MS_LOG(INFO) << "Do mixed bit quantization"; | |||
| return DoMiexedQuant(func_graph); | |||
| } | |||
| ret = DoConvQuantize(cnodes); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "DoConvQuantize failed :" << ret; | |||
| return ret; | |||
| } | |||
| ret = DoMulQuantize(cnodes); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "DoMulQuantize failed :" << ret; | |||
| return ret; | |||
| } | |||
| return ret; | |||
| return DoFixedQuant(func_graph); | |||
| } | |||
| } // namespace mindspore::lite::quant | |||
| @@ -41,19 +41,21 @@ class WeightQuantizer : public Quantizer { | |||
| ~WeightQuantizer(); | |||
| STATUS DoQuantize(FuncGraphPtr func_graph) override; | |||
| STATUS DoConvQuantize(const std::list<CNodePtr> &nodes); | |||
| STATUS DoMulQuantize(const std::list<CNodePtr> &nodes); | |||
| STATUS DoConvQuantize(CNodePtr); | |||
| STATUS DoMulQuantize(CNodePtr); | |||
| STATUS DoLstmQuntize(CNodePtr cnode); | |||
| STATUS DoGatherQuntize(CNodePtr cnode); | |||
| static STATUS WeightQuantInputCheck(const converter::Flags *config); | |||
| static bool IsPosNum(const std::string &str); | |||
| int quant_max; | |||
| int quant_min; | |||
| TypeId type_id{kTypeUnknown}; | |||
| int quant_max_{127}; | |||
| int quant_min_{-128}; | |||
| TypeId type_id_{kNumberTypeInt8}; | |||
| std::map<std::string, int> opname_bit_; | |||
| private: | |||
| std::unique_ptr<QuantStrategy> quant_strategy_; | |||
| size_t bit_num_; | |||
| size_t bit_num_{8}; | |||
| std::string config_file_; | |||
| PostQuantConfig config_param_; | |||
| std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...] | |||
| @@ -61,6 +63,7 @@ class WeightQuantizer : public Quantizer { | |||
| STATUS DoMiexedQuant(FuncGraphPtr); | |||
| STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c); | |||
| STATUS DoFixedQuant(FuncGraphPtr); | |||
| }; | |||
| } // namespace mindspore::lite::quant | |||
| #endif | |||