diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index 3ba28a075b..46df806374 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -39,6 +39,7 @@ set(LITE_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/dequant.cc
     )
 if (SUPPORT_GPU)
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/dequant.cc b/mindspore/lite/src/dequant.cc
similarity index 50%
rename from mindspore/lite/src/runtime/kernel/arm/base/dequant.cc
rename to mindspore/lite/src/dequant.cc
index 4fe4d191e4..f13c403b54 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/dequant.cc
+++ b/mindspore/lite/src/dequant.cc
@@ -14,9 +14,9 @@
  * limitations under the License.
  */
 #include <vector>
-#include "src/runtime/kernel/arm/base/dequant.h"
+#include "src/dequant.h"
 
-namespace mindspore::kernel {
+namespace mindspore::lite {
 float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
   MS_ASSERT(input_tensor != nullptr);
   if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
@@ -35,6 +35,8 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
 }
 
 void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
+  MS_ASSERT(input_tensor != nullptr);
+  MS_ASSERT(unpack_int_data != nullptr);
   auto quant_params = input_tensor->quantParams();
   if (quant_params == nullptr) {
     MS_LOG(ERROR) << "low bits quantparams is empty.";
@@ -47,4 +49,41 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
     UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data);
   }
 }
-}  // namespace mindspore::kernel
+
+std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
+                                                                         TypeId data_type) {
+  std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
+  if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
+    for (auto weight_tensor : in_tensors) {
+      MS_ASSERT(weight_tensor != nullptr);
+      auto *restore_data = weight_tensor->data_c();
+      auto restore_type = weight_tensor->data_type();
+      bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
+                          restore_data != nullptr;
+      if (dequant_flag) {
+        auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor);
+        if (dequant_weight == nullptr) {
+          MS_LOG(ERROR) << "dequant data is nullptr.";
+          return tensor_origin_data;
+        }
+        weight_tensor->set_data(dequant_weight);
+        weight_tensor->set_data_type(kNumberTypeFloat32);
+        tensor_origin_data[weight_tensor] = {restore_type, restore_data};
+      }
+    }
+  }
+  return tensor_origin_data;
+}
+
+void DequantUtil::RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map) {
+  for (auto &kv : tensor_origin_data_map) {
+    auto *tensor = kv.first;
+    auto type_id = kv.second.first;
+    auto data = kv.second.second;
+    tensor->FreeData();
+    tensor->set_data_type(type_id);
+    tensor->set_data(data);
+  }
+}
+
+}  // namespace mindspore::lite
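A note on the math: DequantWeight applies the usual affine dequantization, w = scale * (q - zero_point), once per quantization parameter. A minimal self-contained sketch of that mapping follows; the QuantArg struct and DequantPerChannel function here are illustrative stand-ins, not the lite::Tensor API:

    #include <cstdint>
    #include <vector>

    // Illustrative stand-in for one per-channel quantization parameter.
    struct QuantArg {
      float scale;
      int zero_point;
    };

    // w_float = scale * (q - zero_point), applied per channel:
    // each run of `channel_size` elements shares one QuantArg.
    std::vector<float> DequantPerChannel(const std::vector<int8_t> &quant_data,
                                         const std::vector<QuantArg> &quant_args, size_t channel_size) {
      std::vector<float> out(quant_data.size());
      for (size_t i = 0; i < quant_data.size(); ++i) {
        const QuantArg &arg = quant_args[i / channel_size];
        out[i] = arg.scale * (quant_data[i] - arg.zero_point);
      }
      return out;
    }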
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/dequant.h b/mindspore/lite/src/dequant.h
similarity index 94%
rename from mindspore/lite/src/runtime/kernel/arm/base/dequant.h
rename to mindspore/lite/src/dequant.h
index 641f5b66bd..b052515103 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/dequant.h
+++ b/mindspore/lite/src/dequant.h
@@ -17,6 +17,8 @@
 #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
 
+#include <map>
+#include <utility>
 #include <queue>
 #include <vector>
 #include <cmath>
@@ -24,13 +26,18 @@
 #include "src/common/utils.h"
 #include "src/tensor.h"
 
-namespace mindspore::kernel {
+namespace mindspore::lite {
 class DequantUtil {
  public:
   static float *DequantWeight(lite::Tensor *input_tensor);
 
   static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
 
+  static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
+                                                                     TypeId data_type);
+
+  static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);
+
   template <typename ST, typename DT = float>
   static DT *DequantData(lite::Tensor *input_tensor) {
     const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
@@ -108,7 +115,7 @@ class DequantUtil {
   static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int,
                          size_t *count, bool is_last) {
     T2 uint_result = 0;
-    T1 result = 0;
+    T1 result;
     UnPackFromUintToOrigin<T2>(packed_data, unpack_bit_data);
     while (static_cast<int>(unpack_bit_data->size()) >= origin_bit) {
       for (int k = 0; k < origin_bit; k++) {
@@ -163,6 +170,6 @@ class DequantUtil {
     }
   }
 };
-}  // namespace mindspore::kernel
+}  // namespace mindspore::lite
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index 2ef9e880d5..438f79e276 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -27,7 +27,7 @@
 #include "src/common/graph_util.h"
 #include "src/kernel_registry.h"
 #include "src/lite_model.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
+#include "src/dequant.h"
 #if SUPPORT_NPU
 #include "src/runtime/agent/npu/npu_manager.h"
 #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h"
@@ -120,7 +120,7 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
       MS_LOG(ERROR) << "Malloc data for tensor failed ";
       return RET_ERROR;
     }
-    kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
+    DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
     copyed_tensor_idxes_.emplace_back(tensor_index);
   } else {
     dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));
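The UnPackToInt call above exists because weights quantized below 8 bits are stored packed. A rough, self-contained illustration of the unpacking idea for the 4-bit case follows (Unpack4Bit is hypothetical; the real UnPackUtil handles arbitrary bit widths that may straddle byte boundaries):

    #include <cstdint>
    #include <vector>

    // Illustrative only: expand bytes holding two 4-bit signed values each into int8_t.
    std::vector<int8_t> Unpack4Bit(const std::vector<uint8_t> &packed) {
      std::vector<int8_t> out;
      out.reserve(packed.size() * 2);
      for (uint8_t byte : packed) {
        for (int shift = 0; shift <= 4; shift += 4) {
          int8_t nibble = (byte >> shift) & 0x0F;
          if (nibble & 0x08) nibble -= 16;  // sign-extend the 4-bit value
          out.push_back(nibble);
        }
      }
      return out;
    }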
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc
index ad25700d9a..fde2c532ae 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc
@@ -25,7 +25,6 @@
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
 #include "src/runtime/runtime_api.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -359,22 +358,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
 
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  bool dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data_type(kNumberTypeFloat32);
-    weight_tensor->set_data(dequant_weight);
-  }
-
   auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
   kernel::LiteKernel *kernel = nullptr;
   if (conv_param->group_ == 1) {
@@ -385,11 +368,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
 
   if (kernel == nullptr) {
     MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -398,20 +376,9 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
   if (ret != RET_OK) {
     MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_
                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
   return kernel;
 }
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator)
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc
index 623ebfffe5..1605dc5c27 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc
@@ -22,7 +22,6 @@
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
 #include "src/runtime/runtime_api.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -138,22 +137,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
 
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  bool dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data_type(kNumberTypeFloat32);
-    weight_tensor->set_data(dequant_weight);
-  }
-
   auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
   kernel::LiteKernel *kernel;
   if (conv_param->input_channel_ < 32) {
@@ -164,11 +147,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
   }
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -176,19 +154,9 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
index 6e28e19a30..b8f289c63f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
@@ -20,7 +20,6 @@
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "src/runtime/runtime_api.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 #include "nnacl/fp16/conv_fp16.h"
 #include "nnacl/fp16/matmul_fp16.h"
 #include "nnacl/fp16/cast_fp16.h"
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc
index 3aa167aa76..b9f9fb1012 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc
@@ -20,7 +20,6 @@
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
 #include "src/runtime/runtime_api.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -212,30 +211,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor *>
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  auto dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data_type(kNumberTypeFloat32);
-    weight_tensor->set_data(dequant_weight);
-  }
-
   auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -243,19 +221,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor *>
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
index 8a199e0880..520d1885d9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
@@ -17,7 +17,6 @@
 #include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h"
 #include "src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.h"
 #include "src/runtime/runtime_api.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -220,22 +219,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
 
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  auto dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data_type(kNumberTypeFloat32);
-    weight_tensor->set_data(dequant_weight);
-  }
-
   kernel::LiteKernel *kernel;
   auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
   if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) &&
@@ -247,11 +230,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
 
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -259,19 +237,9 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
   return kernel;
 }
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator)
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
index d34c0dd098..7df1608a53 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
@@ -234,30 +234,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::Tensor *>
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  bool dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data_type(kNumberTypeFloat32);
-    weight_tensor->set_data(dequant_weight);
-  }
   auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -265,19 +244,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::Tensor *>
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.h
index cce92802ca..146bd34604 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.h
@@ -24,7 +24,6 @@
 #include "nnacl/fp16/matmul_fp16.h"
 #include "nnacl/fp16/cast_fp16.h"
 #include "src/lite_kernel.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 namespace mindspore::kernel {
 class FullconnectionFP16CPUKernel : public LiteKernel {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc
index 9f964307fa..f3f2a8f6b2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc
@@ -20,7 +20,6 @@
 #include "src/runtime/runtime_api.h"
 #include "include/errorcode.h"
 #include "src/kernel_registry.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
@@ -330,29 +329,9 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
                                                const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
                                                const lite::InnerContext *ctx, const kernel::KernelKey &desc,
                                                const mindspore::lite::PrimitiveC *primitive) {
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  bool dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data_type(kNumberTypeFloat32);
-    weight_tensor->set_data(dequant_weight);
-  }
   auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -361,18 +340,8 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     delete kernel;
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     return nullptr;
   }
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc
index 948c45c0e0..7b5b7f3f97 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc
@@ -22,7 +22,6 @@
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
 #include "src/runtime/runtime_api.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -356,22 +355,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
   MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
   MS_ASSERT(desc.data_type == kNumberTypeFloat32);
 
-  // if get quantized weight, dequantize it to float32 type data.
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  bool dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(op_parameter);
-      return nullptr;
-    }
-    weight_tensor->set_data(dequant_weight);
-  }
-
   auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
   kernel::LiteKernel *kernel = nullptr;
   if (conv_param->group_ == 1) {
@@ -382,11 +365,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
 
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(op_parameter);
     return nullptr;
   }
@@ -395,20 +373,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
   if (ret != RET_OK && ret != RET_INFER_INVALID) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc
index 305f2716c0..e7f597ac13 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc
@@ -21,7 +21,6 @@
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
 #include "src/runtime/runtime_api.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -126,19 +125,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
 
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data(dequant_weight);
-  }
-
   auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
   kernel::LiteKernel *kernel = nullptr;
   if (primitive != nullptr && primitive->infer_flag()) {
@@ -162,11 +148,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
   }
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -174,21 +155,10 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
   if (ret != RET_OK && ret != RET_INFER_INVALID) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-  if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
-
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
index 9c8416038a..bf01d96bab 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
@@ -19,7 +19,6 @@
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "src/runtime/runtime_api.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 #include "nnacl/fp32/conv_fp32.h"
 #include "nnacl/fp32/matmul_fp32.h"
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise_fp32.cc
index 486b2b0fcf..9840a17f19 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise_fp32.cc
@@ -19,7 +19,6 @@
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
 #include "src/runtime/runtime_api.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -202,29 +201,10 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor *>
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  bool dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data(dequant_weight);
-  }
   auto kernel =
     new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -232,19 +212,9 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor *>
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc
index d4f2b0a394..cb92adbd72 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc
@@ -17,7 +17,6 @@
 #include "src/runtime/kernel/arm/fp32/deconvolution_fp32.h"
 #include "src/runtime/kernel/arm/fp32/deconvolution_winograd_fp32.h"
 #include "src/runtime/runtime_api.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -240,20 +239,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
                                                const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  bool dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data(dequant_weight);
-  }
 
   kernel::LiteKernel *kernel;
   auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
@@ -266,11 +251,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
 
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -278,21 +258,9 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
-
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
index 9f47dfee49..e460d81271 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
@@ -228,28 +228,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::Tensor *>
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  bool dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      return nullptr;
-    }
-    weight_tensor->set_data(dequant_weight);
-  }
   auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (!kernel) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -257,19 +238,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::Tensor *>
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.h
index 9847cd5fe7..c4ef67b33f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.h
@@ -22,7 +22,6 @@
 #include "include/errorcode.h"
 #include "nnacl/fp32/matmul_fp32.h"
 #include "src/lite_kernel.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::lite::InnerContext;
 namespace mindspore::kernel {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
index 86dd24277d..fe735320ea 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
@@ -19,7 +19,6 @@
 #include "nnacl/fp32/matmul_fp32.h"
 #include "src/runtime/runtime_api.h"
 #include "src/kernel_registry.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_INPUT_TENSOR_ERROR;
@@ -417,30 +416,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
                                                const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_MatMul);
-
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->data_c();
-  auto restore_type = weight_tensor->data_type();
-  bool dequant_flag =
-    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
-  if (dequant_flag) {
-    auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      free(opParameter);
-      return nullptr;
-    }
-    weight_tensor->set_data(dequant_weight);
-  }
-
   auto kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     free(opParameter);
     return nullptr;
   }
@@ -448,21 +426,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (dequant_flag) {
-      weight_tensor->FreeData();
-      weight_tensor->set_data(restore_data);
-      weight_tensor->set_data_type(restore_type);
-    }
     delete kernel;
     return nullptr;
   }
-
-  if (dequant_flag) {
-    weight_tensor->FreeData();
-    weight_tensor->set_data(restore_data);
-    weight_tensor->set_data_type(restore_type);
-  }
-
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
index d7ef223208..a1d8665b60 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
@@ -15,7 +15,7 @@
  */
 
 #include "src/runtime/kernel/opencl/opencl_kernel.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
+#include "mindspore/lite/src/dequant.h"
 
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
@@ -263,10 +263,10 @@ int OpenCLKernel::DequantWeight() {
   if (is_fp16) {
 #ifdef ENABLE_ARM64
     if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
-      dequant_weight = kernel::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
+      dequant_weight = lite::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
       weight_tensor->set_data_type(kNumberTypeFloat16);
     } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
-      dequant_weight = kernel::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
+      dequant_weight = lite::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
       weight_tensor->set_data_type(kNumberTypeFloat16);
     } else {
       set_flag = false;
@@ -276,10 +276,10 @@ int OpenCLKernel::DequantWeight() {
 #endif
   } else {
     if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
-      dequant_weight = kernel::DequantUtil::DequantData<int8_t, float>(weight_tensor);
+      dequant_weight = lite::DequantUtil::DequantData<int8_t, float>(weight_tensor);
       weight_tensor->set_data_type(kNumberTypeFloat32);
     } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
-      dequant_weight = kernel::DequantUtil::DequantData<int16_t, float>(weight_tensor);
+      dequant_weight = lite::DequantUtil::DequantData<int16_t, float>(weight_tensor);
       weight_tensor->set_data_type(kNumberTypeFloat32);
     } else {
       set_flag = false;
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
index 69ff966f6e..5e68c6c7d8 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
@@ -25,7 +25,7 @@
 #include "src/lite_kernel.h"
 #include "include/errorcode.h"
 #include "src/runtime/opencl/opencl_runtime.h"
-#include "src/runtime/kernel/arm/base/dequant.h"
+#include "mindspore/lite/src/dequant.h"
 #include "src/runtime/kernel/opencl/utils.h"
 
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc
index 745dd8c69e..6523110f60 100644
--- a/mindspore/lite/src/scheduler.cc
+++ b/mindspore/lite/src/scheduler.cc
@@ -27,6 +27,7 @@
 #include "src/common/utils.h"
 #include "src/kernel_registry.h"
 #include "src/sub_graph_kernel.h"
+#include "src/dequant.h"
 #if SUPPORT_GPU
 #include "src/runtime/kernel/opencl/opencl_subgraph.h"
 #include "src/runtime/opencl/opencl_runtime.h"
@@ -213,8 +214,10 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
   if (mindspore::lite::IsSupportFloat16() &&
       ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
     kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
+    auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type);
    auto *kernel =
      KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
+    DequantUtil::RestoreTensorData(tensor_origin_data_map);
     if (kernel != nullptr) {
       MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " "
                     << node->name_;
@@ -225,7 +228,9 @@
     MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
     desc.data_type = kNumberTypeFloat32;
   }
+  auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type);
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
+  DequantUtil::RestoreTensorData(tensor_origin_data_map);
   if (kernel != nullptr) {
     return kernel;
   }
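The two bracketed GetKernel calls above are now the only dequantization sites for kernel lookup: the kernel's Init() is expected to repack the temporary float weights into its own buffers, after which RestoreTensorData frees the float copy and reinstates the original data pointer and type. A minimal self-contained model of that save/restore idiom, using a stand-in Tensor type rather than lite::Tensor:

    #include <map>
    #include <utility>
    #include <vector>

    struct Tensor {
      int data_type;
      void *data;
    };

    std::map<Tensor *, std::pair<int, void *>> SwapInFloatData(const std::vector<Tensor *> &tensors,
                                                               void *float_copy, int float_type) {
      std::map<Tensor *, std::pair<int, void *>> origin;
      for (auto *t : tensors) {
        origin[t] = {t->data_type, t->data};  // remember the original type/data pair
        t->data = float_copy;                 // hand callers a float view instead
        t->data_type = float_type;
      }
      return origin;
    }

    void RestoreOriginal(const std::map<Tensor *, std::pair<int, void *>> &origin) {
      // In the real RestoreTensorData, FreeData() releases the temporary float
      // buffer before the original quantized pointer is put back.
      for (const auto &kv : origin) {
        kv.first->data_type = kv.second.first;
        kv.first->data = kv.second.second;
      }
    }

Keying the map by tensor pointer means tensors that were never dequantized are simply absent from the map and untouched on restore.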
op."; desc.data_type = kNumberTypeFloat32; } + auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type); auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); + DequantUtil::RestoreTensorData(tensor_origin_data_map); if (kernel != nullptr) { return kernel; } diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 9ec364865d..1b592aaae9 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -126,6 +126,7 @@ set(TEST_LITE_SRC ${LITE_DIR}/src/kernel_registry.cc ${LITE_DIR}/src/lite_kernel.cc ${LITE_DIR}/src/lite_session.cc + ${LITE_DIR}/src/dequant.cc ${LITE_DIR}/src/sub_graph_kernel.cc ${LITE_DIR}/src/lite_model.cc ${LITE_DIR}/src/scheduler.cc diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index fa3fac0a1e..aa3d61c558 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -95,6 +95,7 @@ set(LITE_SRC ${SRC_DIR}/executor.cc ${SRC_DIR}/lite_model.cc ${SRC_DIR}/errorcode.cc + ${SRC_DIR}/dequant.cc ) if (SUPPORT_TRAIN) set(LITE_SRC diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 1ec399c060..d854bcdaae 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -782,4 +782,27 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) { return new_func_graph; } +void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value) { + MS_ASSERT(node != nullptr); + MS_ASSERT(param_node != nullptr); + MS_ASSERT(param_value != nullptr); + + auto op_name = node->fullname_with_scope(); + + *param_node = node->cast(); + if (*param_node == nullptr) { + MS_LOG(INFO) << op_name << " can not cast to ParameterPtr"; + return; + } + if (!(*param_node)->has_default()) { + MS_LOG(INFO) << op_name << " not has_default"; + return; + } + + *param_value = std::static_pointer_cast((*param_node)->default_param()); + if (*param_value == nullptr) { + MS_LOG(INFO) << "default_param can not cast to ParamValueLite"; + return; + } +} } // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 7fae9738ef..3f1a42218c 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -75,9 +75,10 @@ class QuantStrategy { bool CanMulOpQuantized(const CNodePtr &node) const; bool CanOpPostQuantized(AnfNodePtr &node) const; - private: size_t mWeightSize; size_t mConvWeightQuantChannelThreshold; + + private: static const std::vector conv_types; static const std::vector mul_types; }; @@ -356,5 +357,8 @@ STATUS CopyInputDataToTensor(size_t input_index, size_t image_index, const std::vector> &images, mindspore::tensor::MSTensor *tensor); FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &); + +void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value); + } // namespace mindspore::lite::quant #endif diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index 04038d6ce7..19165c4057 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ 
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
index 04038d6ce7..19165c4057 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
@@ -20,7 +20,6 @@
 #include <vector>
 #include <utility>
 #include "src/common/common.h"
-#include "ir/dtype/type_id.h"
 
 using std::string;
 using std::vector;
@@ -73,13 +72,13 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_f
   this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
   auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
   quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
-  quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
-  quant_min = -(1 << (unsigned int)(this->bit_num_ - 1));
-  // parse type_id
+  quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
+  quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
+  // parse type_id_
   if (this->bit_num_ > 0 && this->bit_num_ <= 8) {
-    type_id = kNumberTypeInt8;
+    type_id_ = kNumberTypeInt8;
   } else if (this->bit_num_ <= 16) {
-    type_id = kNumberTypeInt16;
+    type_id_ = kNumberTypeInt16;
   } else {
     MS_LOG(ERROR) << "invalid input bits";
   }
@@ -90,7 +89,7 @@ WeightQuantizer::~WeightQuantizer() { delete fp32_session_; }
 STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node,
                                     std::shared_ptr<PrimitiveC> primitive_c) {
   // set dtype
-  param_value->set_tensor_type(type_id);
+  param_value->set_tensor_type(type_id_);
   auto abstract_base = param_node->abstract();
   if (abstract_base == nullptr) {
     MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
@@ -101,49 +100,158 @@ STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr
     return RET_ERROR;
   }
   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
-  abstract_tensor->element()->set_type(TypeIdToType(type_id));
+  abstract_tensor->element()->set_type(TypeIdToType(type_id_));
   primitive_c->set_quant_type(schema::QuantType_WeightQuant);
   return RET_OK;
 }
 
-STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
-  for (auto &cnode : nodes) {
-    if (!quant_strategy_->CanConvOpQuantized(cnode)) {
-      continue;
-    }
+STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) {
+  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+  if (primitive_c == nullptr) {
+    MS_LOG(ERROR) << "primitive_c is nullptr";
+    return RET_ERROR;
+  }
 
-    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
-    if (primitive_c == nullptr) {
-      MS_LOG(ERROR) << "primitive_c is nullptr";
-      return RET_ERROR;
-    }
+  auto input_node = cnode->input(2);
+  if (!input_node->isa<Parameter>()) {
+    return RET_ERROR;
+  }
 
-    auto input_node = cnode->input(2);
-    if (!input_node->isa<Parameter>()) {
-      return RET_ERROR;
-    }
+  ParameterPtr param_node;
+  ParamValueLitePtr param_value;
 
-    auto param_node = input_node->cast<ParameterPtr>();
-    if (!param_node->has_default()) {
-      return RET_ERROR
+  GetLiteParameter(input_node, &param_node, &param_value);
+  if (param_node == nullptr || param_value == nullptr) {
+    MS_LOG(ERROR) << "GetLiteParameter error";
+    return RET_ERROR;
+  }
+
+  if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
+    MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type();
+    return RET_ERROR;
+  }
+  auto status = RET_ERROR;
+  if (type_id_ == kNumberTypeInt8) {
+    status =
+      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
+  } else if (type_id_ == kNumberTypeInt16) {
+    status =
+      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
+  }
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "QuantFilter failed : " << status;
+    return status;
+  }
status; + } + status = SetAbstract(param_value, param_node, primitive_c); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetAbstract failed : " << status; + return RET_ERROR; + } + return RET_OK; +} + +STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { + auto already_quant = false; + ParamValueLitePtr param_value = nullptr; + ParameterPtr param_node = nullptr; + for (size_t i = 1; i < cnode->size(); i++) { + auto inputNode = cnode->input(i); + if (inputNode->isa()) { + param_node = inputNode->cast(); + if ((param_node != nullptr) && param_node->has_default()) { + param_value = std::static_pointer_cast(param_node->default_param()); + if ((param_value == nullptr) || (param_value->tensor_size() == 0) || (param_value->tensor_addr() == nullptr)) { + param_value = nullptr; + continue; + } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 || + param_value->tensor_type() == mindspore::kNumberTypeInt16) { + MS_LOG(INFO) << "the node: " << cnode->fullname_with_scope() << " input_i: " << i << "has been " + << " quantized"; + already_quant = true; + break; + } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { + param_value = nullptr; + continue; + } else { + break; + } + } } + } - ParamValueLitePtr param_value = std::static_pointer_cast(param_node->default_param()); - if (param_value == nullptr) { + if (already_quant) { + return RET_OK; + } + + if (param_value == nullptr) { + MS_LOG(ERROR) << "No valid input param node !"; + return RET_ERROR; + } + + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; + return RET_ERROR; + } + + auto status = RET_ERROR; + if (type_id_ == kNumberTypeInt8) { + status = + QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); + } else if (type_id_ == kNumberTypeInt16) { + status = + QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); + } + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantFilter failed : " << status; + return status; + } + status = SetAbstract(param_value, param_node, primitive_c); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetAbstract failed : " << status; + return RET_ERROR; + } + + return RET_OK; +} + +STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) { + MS_ASSERT(cnode != nullptr); + auto op_name = cnode->fullname_with_scope(); + + auto primitive_c = GetValueNode>(cnode->input(0)); + MS_ASSERT(primitive_c != nullptr); + + if (cnode->inputs().size() < 4) { + MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size(); + return RET_ERROR; + } + { + auto weight_i = cnode->input(2); + ParameterPtr param_node; + ParamValueLitePtr param_value; + GetLiteParameter(weight_i, ¶m_node, ¶m_value); + if (param_node == nullptr || param_value == nullptr) { + MS_LOG(ERROR) << "GetLiteParameter error"; return RET_ERROR; } - if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { - MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type(); - return RET_ERROR; + if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { + MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; + return RET_OK; + } + if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) { + MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < " + << quant_strategy_->mWeightSize; + return RET_OK; } auto status = RET_ERROR; - if (type_id == 
-    if (type_id == kNumberTypeInt8) {
+    if (type_id_ == kNumberTypeInt8) {
       status =
-        QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
-    } else if (type_id == kNumberTypeInt16) {
+        QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+    } else if (type_id_ == kNumberTypeInt16) {
       status =
-        QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
+        QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
     }
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
@@ -155,65 +263,26 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
       return RET_ERROR;
     }
   }
-  return RET_OK;
-}
-
-STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
-  for (auto &node : nodes) {
-    if (!quant_strategy_->CanMulOpQuantized(node)) {
-      continue;
-    }
-    auto already_quant = false;
-    ParamValueLitePtr param_value = nullptr;
-    ParameterPtr param_node = nullptr;
-    for (size_t i = 1; i < node->size(); i++) {
-      auto inputNode = node->input(i);
-      if (inputNode->isa<Parameter>()) {
-        param_node = inputNode->cast<ParameterPtr>();
-        if ((param_node != nullptr) && param_node->has_default()) {
-          param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
-          if ((param_value == nullptr) || (param_value->tensor_size() == 0) ||
-              (param_value->tensor_addr() == nullptr)) {
-            param_value = nullptr;
-            continue;
-          } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 ||
-                     param_value->tensor_type() == mindspore::kNumberTypeInt16) {
-            MS_LOG(INFO) << "the node: " << node->fullname_with_scope() << " input_i: " << i << "has been "
-                         << " quantized";
-            already_quant = true;
-            break;
-          } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
-            param_value = nullptr;
-            continue;
-          } else {
-            break;
-          }
-        }
-      }
-    }
-
-    if (already_quant) {
-      continue;
-    }
-
-    if (param_value == nullptr) {
-      MS_LOG(ERROR) << "No valid input param node !";
+  {
+    auto weight_h = cnode->input(3);
+    ParameterPtr param_node;
+    ParamValueLitePtr param_value;
+    GetLiteParameter(weight_h, &param_node, &param_value);
+    if (param_node == nullptr || param_value == nullptr) {
+      MS_LOG(ERROR) << "GetLiteParameter error";
       return RET_ERROR;
     }
-
-    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
-    if (primitive_c == nullptr) {
-      MS_LOG(ERROR) << "primitive_c is nullptr";
+    if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
+      MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
       return RET_ERROR;
     }
-
     auto status = RET_ERROR;
-    if (type_id == kNumberTypeInt8) {
+    if (type_id_ == kNumberTypeInt8) {
       status =
-        QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
-    } else if (type_id == kNumberTypeInt16) {
+        QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+    } else if (type_id_ == kNumberTypeInt16) {
       status =
-        QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
+        QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
     }
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
@@ -225,7 +294,78 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
       return RET_ERROR;
     }
   }
+  {
+    if (cnode->inputs().size() > 4) {
+      auto bias = cnode->input(4);
+      ParameterPtr param_node;
+      ParamValueLitePtr param_value;
+      GetLiteParameter(bias, &param_node, &param_value);
+      if (param_node == nullptr || param_value == nullptr) {
+        MS_LOG(ERROR) << "GetLiteParameter error";
+        return RET_ERROR;
+      }
+      if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
+        MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
+        return RET_ERROR;
+      }
+      auto status = RET_ERROR;
+      if (type_id_ == kNumberTypeInt8) {
+        status =
+          QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+      } else if (type_id_ == kNumberTypeInt16) {
+        status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                      false);
+      }
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "QuantFilter failed : " << status;
+        return status;
+      }
+      status = SetAbstract(param_value, param_node, primitive_c);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "SetAbstract failed : " << status;
+        return RET_ERROR;
+      }
+    }
+  }
+  return RET_OK;
+}
+
+STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
+  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+  MS_ASSERT(primitive_c != nullptr);
+
+  auto weight_h = cnode->input(1);
+  ParameterPtr param_node;
+  ParamValueLitePtr param_value;
+  GetLiteParameter(weight_h, &param_node, &param_value);
+  if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
+    MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight";
+    return RET_OK;
+  }
+
+  if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
+    MS_LOG(INFO) << cnode->fullname_with_scope() << " param cnt: " << param_value->tensor_size() / 4 << " < "
+                 << quant_strategy_->mWeightSize;
+    return RET_OK;
+  }
+  auto status = RET_ERROR;
+  if (type_id_ == kNumberTypeInt8) {
+    status =
+      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+  } else if (type_id_ == kNumberTypeInt16) {
+    status =
+      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+  }
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "QuantFilter failed : " << status;
+    return status;
+  }
+  status = SetAbstract(param_value, param_node, primitive_c);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "SetAbstract failed : " << status;
+    return RET_ERROR;
+  }
   return RET_OK;
 }
@@ -315,6 +455,23 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
   }
 
   auto cnodes = func_graph->GetOrderedCnodes();
+  for (auto &cnode : cnodes) {
+    auto op_type = NodePrimitiveType(cnode);
+    if (op_type == schema::PrimitiveType_Lstm) {
+      status = DoLstmQuntize(cnode);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "DoLstmQuntize error";
+        return RET_ERROR;
+      }
+    } else if (op_type == schema::PrimitiveType_Gather) {
+      status = DoGatherQuntize(cnode);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "DoGatherQuntize error";
+        return RET_ERROR;
+      }
+    }
+  }
+
   for (auto iter = cnodes.end(); iter != cnodes.begin();) {
     auto cnode = *(--iter);
     auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
@@ -357,18 +514,18 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
   }
   // 1. try quant
   for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
-    type_id = TypeId::kNumberTypeInt8;
+    type_id_ = TypeId::kNumberTypeInt8;
     int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
     int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
-    if (type_id == TypeId::kNumberTypeInt8) {
+    if (type_id_ == TypeId::kNumberTypeInt8) {
       status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
                                    quant_min_t, bit_num_t, true);
-    } else if (type_id == TypeId::kNumberTypeInt16) {
+    } else if (type_id_ == TypeId::kNumberTypeInt16) {
       status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
                                     quant_min_t, bit_num_t, true);
     } else {
-      MS_LOG(ERROR) << "unexpected type_id: " << type_id;
+      MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
       return RET_ERROR;
     }
     if (status != RET_OK) {
@@ -456,13 +613,53 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
   return RET_OK;
 }
 
+STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  for (auto &cnode : func_graph->GetOrderedCnodes()) {
+    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+    if (primitive_c == nullptr) {
+      MS_LOG(ERROR) << "primitive_c is nullptr";
+      return RET_ERROR;
+    }
+    auto op_name = cnode->fullname_with_scope();
+    auto op_type = (schema::PrimitiveType)primitive_c->Type();
+
+    if (quant_strategy_->CanConvOpQuantized(cnode)) {
+      auto status = DoConvQuantize(cnode);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "DoConvQuantize error";
+        return RET_ERROR;
+      }
+    } else if (quant_strategy_->CanMulOpQuantized(cnode)) {
+      auto status = DoMulQuantize(cnode);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "DoMulQuantize error";
+        return RET_ERROR;
+      }
+    } else if (op_type == schema::PrimitiveType_Lstm) {
+      auto status = DoLstmQuntize(cnode);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "DoLstmQuntize error";
+        return RET_ERROR;
+      }
+    } else if (op_type == schema::PrimitiveType_Gather) {
+      auto status = DoGatherQuntize(cnode);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "DoGatherQuntize error";
+        return RET_ERROR;
+      }
+    } else {
+      MS_LOG(DEBUG) << op_name << " of type: " << schema::EnumNamePrimitiveType(op_type) << " no need quant";
+    }
+  }
+  return RET_OK;
+}
+
 STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   MS_ASSERT(func_graph != nullptr);
-  STATUS ret;
-  auto cnodes = func_graph->GetOrderedCnodes();
   if (!config_file_.empty()) {
-    ret = ParseConfigFile(config_file_, &config_param_);
+    auto ret = ParseConfigFile(config_file_, &config_param_);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "ReadConfig error.";
       return RET_ERROR;
@@ -470,20 +667,14 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   }
 
   if (config_param_.mixed) {
+    bit_num_ = 8;
+    quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
+    quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
+    type_id_ = kNumberTypeInt8;
     MS_LOG(INFO) << "Do mixed bit quantization";
     return DoMiexedQuant(func_graph);
   }
 
-  ret = DoConvQuantize(cnodes);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "DoConvQuantize failed :" << ret;
-    return ret;
-  }
-  ret = DoMulQuantize(cnodes);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "DoMulQuantize failed :" << ret;
-    return ret;
-  }
-  return ret;
+  return DoFixedQuant(func_graph);
 }
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
index 7e7494d2b2..6382829b5f 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
@@ -41,19 +41,21 @@ class WeightQuantizer : public Quantizer {
   ~WeightQuantizer();
 
   STATUS DoQuantize(FuncGraphPtr func_graph) override;
-  STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
-  STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
+  STATUS DoConvQuantize(CNodePtr);
+  STATUS DoMulQuantize(CNodePtr);
+  STATUS DoLstmQuntize(CNodePtr cnode);
+  STATUS DoGatherQuntize(CNodePtr cnode);
   static STATUS WeightQuantInputCheck(const converter::Flags *config);
   static bool IsPosNum(const std::string &str);
-  int quant_max;
-  int quant_min;
-  TypeId type_id{kTypeUnknown};
+  int quant_max_{127};
+  int quant_min_{-128};
+  TypeId type_id_{kNumberTypeInt8};
   std::map<std::string, int> opname_bit_;
 
  private:
   std::unique_ptr<QuantStrategy> quant_strategy_;
-  size_t bit_num_;
+  size_t bit_num_{8};
   std::string config_file_;
   PostQuantConfig config_param_;
   std::vector<std::vector<std::string>> images_;  // multi_input, [[mode_input_0], [model_input_1]...]
@@ -61,6 +63,7 @@ class WeightQuantizer : public Quantizer {
   STATUS DoMiexedQuant(FuncGraphPtr);
   STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node,
                      std::shared_ptr<PrimitiveC> primitive_c);
+  STATUS DoFixedQuant(FuncGraphPtr);
 };
 }  // namespace mindspore::lite::quant
 #endif
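As a quick sanity check on the new quant_max_/quant_min_ defaults and the shift expressions used in the constructor and DoQuantize, the ranges work out to the usual signed bounds (standalone check, same arithmetic):

    #include <cassert>

    int main() {
      for (unsigned int bits = 2; bits <= 16; ++bits) {
        int quant_max = (1 << (bits - 1)) - 1;  // bits == 8  ->  127
        int quant_min = -(1 << (bits - 1));     // bits == 8  -> -128
        assert(quant_max == -quant_min - 1);    // symmetric signed range
      }
      return 0;
    }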