From: @xutianchun Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -39,6 +39,7 @@ set(LITE_SRC | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc | ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc | ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc | ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/dequant.cc | |||||
| ) | ) | ||||
| if (SUPPORT_GPU) | if (SUPPORT_GPU) | ||||
| @@ -14,9 +14,9 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <cmath> | #include <cmath> | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| #include "src/dequant.h" | |||||
| namespace mindspore::kernel { | |||||
| namespace mindspore::lite { | |||||
| float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | ||||
| MS_ASSERT(input_tensor != nullptr); | MS_ASSERT(input_tensor != nullptr); | ||||
| if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) { | if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) { | ||||
| @@ -35,6 +35,8 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | |||||
| } | } | ||||
| void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { | void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { | ||||
| MS_ASSERT(input_tensor != nullptr); | |||||
| MS_ASSERT(unpack_int_data != nullptr); | |||||
| auto quant_params = input_tensor->quantParams(); | auto quant_params = input_tensor->quantParams(); | ||||
| if (quant_params == nullptr) { | if (quant_params == nullptr) { | ||||
| MS_LOG(ERROR) << "low bits quantparams is empty."; | MS_LOG(ERROR) << "low bits quantparams is empty."; | ||||
| @@ -47,4 +49,41 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i | |||||
| UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data); | UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data); | ||||
| } | } | ||||
| } | } | ||||
| } // namespace mindspore::kernel | |||||
// Dequantizes the quantized tensors in `in_tensors` in place and returns what is needed to undo it.
//
// Work is only done when `data_type` is kNumberTypeFloat32 or kNumberTypeFloat16; for any other
// type the returned map is empty and no tensor is touched.
//
// in_tensors: tensors to inspect; each one is checked individually.
// data_type:  type id that gates whether dequantization is attempted at all.
// Returns a map from each modified tensor to its original {data type, data pointer}, suitable for
// passing to RestoreTensorData.
//
// NOTE(review): if DequantWeight fails for one tensor, the function logs an error and returns the
// entries collected so far. Tensors processed earlier stay dequantized, and the return value gives
// the caller no explicit failure indication.
| std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors, | |||||
| TypeId data_type) { | |||||
| std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data; | |||||
| if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) { | |||||
| for (auto weight_tensor : in_tensors) { | |||||
| MS_ASSERT(weight_tensor != nullptr); | |||||
// Capture the current buffer and type before they are overwritten below.
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
// A tensor is dequantized only if it has quant params, the first one is marked inited,
// and it currently holds data.
| bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && | |||||
| restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
// Early exit: the failing tensor is left unchanged and is not recorded in the map.
| return tensor_origin_data; | |||||
| } | |||||
// The tensor is marked float32 here regardless of which of the two accepted `data_type` values was passed.
| weight_tensor->set_data(dequant_weight); | |||||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||||
| tensor_origin_data[weight_tensor] = {restore_type, restore_data}; | |||||
| } | |||||
| } | |||||
| } | |||||
| return tensor_origin_data; | |||||
| } | |||||
// Puts every tensor in `tensor_origin_data_map` back to the data type and data pointer recorded
// for it (the map has the same shape as the one DequantTensor returns).
//
// tensor_origin_data_map: tensor -> {original data type, original data pointer}.
//
// NOTE(review): the tensor pointers are dereferenced without a null check here.
| void DequantUtil::RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map) { | |||||
| for (auto &kv : tensor_origin_data_map) { | |||||
| auto *tensor = kv.first; | |||||
| auto type_id = kv.second.first; | |||||
| auto data = kv.second.second; | |||||
// FreeData() runs before the original pointer is reinstalled, so it acts on whatever buffer the
// tensor holds at this point - presumably the dequantized one set by DequantTensor; TODO confirm
// against Tensor::FreeData.
| tensor->FreeData(); | |||||
| tensor->set_data_type(type_id); | |||||
| tensor->set_data(data); | |||||
| } | |||||
| } | |||||
| } // namespace mindspore::lite | |||||
| @@ -17,6 +17,8 @@ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | ||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | ||||
| #include <map> | |||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include <queue> | #include <queue> | ||||
| #include <cmath> | #include <cmath> | ||||
| @@ -24,13 +26,18 @@ | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| namespace mindspore::kernel { | |||||
| namespace mindspore::lite { | |||||
| class DequantUtil { | class DequantUtil { | ||||
| public: | public: | ||||
| static float *DequantWeight(lite::Tensor *input_tensor); | static float *DequantWeight(lite::Tensor *input_tensor); | ||||
| static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); | static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); | ||||
| static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors, | |||||
| TypeId data_type); | |||||
| static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map); | |||||
| template <typename ST, typename DT = float> | template <typename ST, typename DT = float> | ||||
| static DT *DequantData(lite::Tensor *input_tensor) { | static DT *DequantData(lite::Tensor *input_tensor) { | ||||
| const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData()); | const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData()); | ||||
| @@ -108,7 +115,7 @@ class DequantUtil { | |||||
| static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int, | static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int, | ||||
| size_t *count, bool is_last) { | size_t *count, bool is_last) { | ||||
| T2 uint_result = 0; | T2 uint_result = 0; | ||||
| T1 result = 0; | |||||
| T1 result; | |||||
| UnPackFromUintToOrigin<T2>(packed_data, unpack_bit_data); | UnPackFromUintToOrigin<T2>(packed_data, unpack_bit_data); | ||||
| while (static_cast<int>(unpack_bit_data->size()) >= origin_bit) { | while (static_cast<int>(unpack_bit_data->size()) >= origin_bit) { | ||||
| for (int k = 0; k < origin_bit; k++) { | for (int k = 0; k < origin_bit; k++) { | ||||
| @@ -163,6 +170,6 @@ class DequantUtil { | |||||
| } | } | ||||
| } | } | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ | ||||
| @@ -27,7 +27,7 @@ | |||||
| #include "src/common/graph_util.h" | #include "src/common/graph_util.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/lite_model.h" | #include "src/lite_model.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| #include "src/dequant.h" | |||||
| #if SUPPORT_NPU | #if SUPPORT_NPU | ||||
| #include "src/runtime/agent/npu/npu_manager.h" | #include "src/runtime/agent/npu/npu_manager.h" | ||||
| #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" | #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" | ||||
| @@ -120,7 +120,7 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde | |||||
| MS_LOG(ERROR) << "Malloc data for tensor failed "; | MS_LOG(ERROR) << "Malloc data for tensor failed "; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData()); | |||||
| DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData()); | |||||
| copyed_tensor_idxes_.emplace_back(tensor_index); | copyed_tensor_idxes_.emplace_back(tensor_index); | ||||
| } else { | } else { | ||||
| dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data())); | dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data())); | ||||
| @@ -25,7 +25,6 @@ | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -359,22 +358,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| bool dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | ||||
| kernel::LiteKernel *kernel = nullptr; | kernel::LiteKernel *kernel = nullptr; | ||||
| if (conv_param->group_ == 1) { | if (conv_param->group_ == 1) { | ||||
| @@ -385,11 +368,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; | MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -398,20 +376,9 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ | MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ | ||||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator) | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -138,22 +137,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| bool dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | ||||
| kernel::LiteKernel *kernel; | kernel::LiteKernel *kernel; | ||||
| if (conv_param->input_channel_ < 32) { | if (conv_param->input_channel_ < 32) { | ||||
| @@ -164,11 +147,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| } | } | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -176,19 +154,9 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -20,7 +20,6 @@ | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| #include "nnacl/fp16/conv_fp16.h" | #include "nnacl/fp16/conv_fp16.h" | ||||
| #include "nnacl/fp16/matmul_fp16.h" | #include "nnacl/fp16/matmul_fp16.h" | ||||
| #include "nnacl/fp16/cast_fp16.h" | #include "nnacl/fp16/cast_fp16.h" | ||||
| @@ -20,7 +20,6 @@ | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -212,30 +211,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| auto dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -243,19 +221,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -17,7 +17,6 @@ | |||||
| #include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h" | #include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h" | ||||
| #include "src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.h" | #include "src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -220,22 +219,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| auto dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| kernel::LiteKernel *kernel; | kernel::LiteKernel *kernel; | ||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | ||||
| if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) && | if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) && | ||||
| @@ -247,11 +230,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -259,19 +237,9 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator) | ||||
| @@ -234,30 +234,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||||
| OpParameter *opParameter, const lite::InnerContext *ctx, | OpParameter *opParameter, const lite::InnerContext *ctx, | ||||
| const kernel::KernelKey &desc, | const kernel::KernelKey &desc, | ||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| // data of second tensor of fc may be nullptr | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| bool dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -265,19 +244,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -24,7 +24,6 @@ | |||||
| #include "nnacl/fp16/matmul_fp16.h" | #include "nnacl/fp16/matmul_fp16.h" | ||||
| #include "nnacl/fp16/cast_fp16.h" | #include "nnacl/fp16/cast_fp16.h" | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| class FullconnectionFP16CPUKernel : public LiteKernel { | class FullconnectionFP16CPUKernel : public LiteKernel { | ||||
| @@ -20,7 +20,6 @@ | |||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| @@ -330,29 +329,9 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | ||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | const lite::InnerContext *ctx, const kernel::KernelKey &desc, | ||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| bool dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -361,18 +340,8 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| delete kernel; | delete kernel; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -356,22 +355,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); | ||||
| MS_ASSERT(desc.data_type == kNumberTypeFloat32); | MS_ASSERT(desc.data_type == kNumberTypeFloat32); | ||||
| // if get quantized weight, dequantize it to float32 type data. | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| bool dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(op_parameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | ||||
| kernel::LiteKernel *kernel = nullptr; | kernel::LiteKernel *kernel = nullptr; | ||||
| if (conv_param->group_ == 1) { | if (conv_param->group_ == 1) { | ||||
| @@ -382,11 +365,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(op_parameter); | free(op_parameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -395,20 +373,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||||
| if (ret != RET_OK && ret != RET_INFER_INVALID) { | if (ret != RET_OK && ret != RET_INFER_INVALID) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -21,7 +21,6 @@ | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -126,19 +125,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | ||||
| kernel::LiteKernel *kernel = nullptr; | kernel::LiteKernel *kernel = nullptr; | ||||
| if (primitive != nullptr && primitive->infer_flag()) { | if (primitive != nullptr && primitive->infer_flag()) { | ||||
| @@ -162,11 +148,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| } | } | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -174,21 +155,10 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| if (ret != RET_OK && ret != RET_INFER_INVALID) { | if (ret != RET_OK && ret != RET_INFER_INVALID) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -19,7 +19,6 @@ | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| #include "nnacl/fp32/conv_fp32.h" | #include "nnacl/fp32/conv_fp32.h" | ||||
| #include "nnacl/fp32/matmul_fp32.h" | #include "nnacl/fp32/matmul_fp32.h" | ||||
| @@ -19,7 +19,6 @@ | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -202,29 +201,10 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| bool dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| auto kernel = | auto kernel = | ||||
| new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -232,19 +212,9 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -17,7 +17,6 @@ | |||||
| #include "src/runtime/kernel/arm/fp32/deconvolution_fp32.h" | #include "src/runtime/kernel/arm/fp32/deconvolution_fp32.h" | ||||
| #include "src/runtime/kernel/arm/fp32/deconvolution_winograd_fp32.h" | #include "src/runtime/kernel/arm/fp32/deconvolution_winograd_fp32.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -240,20 +239,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| bool dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| kernel::LiteKernel *kernel; | kernel::LiteKernel *kernel; | ||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | ||||
| @@ -266,11 +251,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -278,21 +258,9 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -228,28 +228,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection); | MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| // data of second tensor of fc may be nullptr | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| bool dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive); | auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (!kernel) { | if (!kernel) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -257,19 +238,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "nnacl/fp32/matmul_fp32.h" | #include "nnacl/fp32/matmul_fp32.h" | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::lite::InnerContext; | using mindspore::lite::InnerContext; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| @@ -19,7 +19,6 @@ | |||||
| #include "nnacl/fp32/matmul_fp32.h" | #include "nnacl/fp32/matmul_fp32.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| using mindspore::lite::RET_INPUT_TENSOR_ERROR; | using mindspore::lite::RET_INPUT_TENSOR_ERROR; | ||||
| @@ -417,30 +416,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_MatMul); | MS_ASSERT(desc.type == schema::PrimitiveType_MatMul); | ||||
| auto *weight_tensor = inputs.at(kWeightIndex); | |||||
| auto *restore_data = weight_tensor->data_c(); | |||||
| auto restore_type = weight_tensor->data_type(); | |||||
| bool dequant_flag = | |||||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||||
| if (dequant_flag) { | |||||
| auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor); | |||||
| if (dequant_weight == nullptr) { | |||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| weight_tensor->set_data(dequant_weight); | |||||
| } | |||||
| auto kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive); | auto kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| free(opParameter); | free(opParameter); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -448,21 +426,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (dequant_flag) { | |||||
| weight_tensor->FreeData(); | |||||
| weight_tensor->set_data(restore_data); | |||||
| weight_tensor->set_data_type(restore_type); | |||||
| } | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | #include "src/runtime/kernel/opencl/opencl_kernel.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| #include "mindspore/lite/src/dequant.h" | |||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| @@ -263,10 +263,10 @@ int OpenCLKernel::DequantWeight() { | |||||
| if (is_fp16) { | if (is_fp16) { | ||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { | if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { | ||||
| dequant_weight = kernel::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor); | |||||
| dequant_weight = lite::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor); | |||||
| weight_tensor->set_data_type(kNumberTypeFloat16); | weight_tensor->set_data_type(kNumberTypeFloat16); | ||||
| } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { | } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { | ||||
| dequant_weight = kernel::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor); | |||||
| dequant_weight = lite::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor); | |||||
| weight_tensor->set_data_type(kNumberTypeFloat16); | weight_tensor->set_data_type(kNumberTypeFloat16); | ||||
| } else { | } else { | ||||
| set_flag = false; | set_flag = false; | ||||
| @@ -276,10 +276,10 @@ int OpenCLKernel::DequantWeight() { | |||||
| #endif | #endif | ||||
| } else { | } else { | ||||
| if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { | if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { | ||||
| dequant_weight = kernel::DequantUtil::DequantData<int8_t, float>(weight_tensor); | |||||
| dequant_weight = lite::DequantUtil::DequantData<int8_t, float>(weight_tensor); | |||||
| weight_tensor->set_data_type(kNumberTypeFloat32); | weight_tensor->set_data_type(kNumberTypeFloat32); | ||||
| } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { | } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { | ||||
| dequant_weight = kernel::DequantUtil::DequantData<int16_t, float>(weight_tensor); | |||||
| dequant_weight = lite::DequantUtil::DequantData<int16_t, float>(weight_tensor); | |||||
| weight_tensor->set_data_type(kNumberTypeFloat32); | weight_tensor->set_data_type(kNumberTypeFloat32); | ||||
| } else { | } else { | ||||
| set_flag = false; | set_flag = false; | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/runtime/opencl/opencl_runtime.h" | #include "src/runtime/opencl/opencl_runtime.h" | ||||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||||
| #include "mindspore/lite/src/dequant.h" | |||||
| #include "src/runtime/kernel/opencl/utils.h" | #include "src/runtime/kernel/opencl/utils.h" | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/sub_graph_kernel.h" | #include "src/sub_graph_kernel.h" | ||||
| #include "src/dequant.h" | |||||
| #if SUPPORT_GPU | #if SUPPORT_GPU | ||||
| #include "src/runtime/kernel/opencl/opencl_subgraph.h" | #include "src/runtime/kernel/opencl/opencl_subgraph.h" | ||||
| #include "src/runtime/opencl/opencl_runtime.h" | #include "src/runtime/opencl/opencl_runtime.h" | ||||
| @@ -213,8 +214,10 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||||
| if (mindspore::lite::IsSupportFloat16() && | if (mindspore::lite::IsSupportFloat16() && | ||||
| ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { | ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { | ||||
| kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; | kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; | ||||
| auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type); | |||||
| auto *kernel = | auto *kernel = | ||||
| KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); | KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); | ||||
| DequantUtil::RestoreTensorData(tensor_origin_data_map); | |||||
| if (kernel != nullptr) { | if (kernel != nullptr) { | ||||
| MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " " | MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " " | ||||
| << node->name_; | << node->name_; | ||||
| @@ -225,7 +228,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||||
| MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; | MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; | ||||
| desc.data_type = kNumberTypeFloat32; | desc.data_type = kNumberTypeFloat32; | ||||
| } | } | ||||
| auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type); | |||||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | ||||
| DequantUtil::RestoreTensorData(tensor_origin_data_map); | |||||
| if (kernel != nullptr) { | if (kernel != nullptr) { | ||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -126,6 +126,7 @@ set(TEST_LITE_SRC | |||||
| ${LITE_DIR}/src/kernel_registry.cc | ${LITE_DIR}/src/kernel_registry.cc | ||||
| ${LITE_DIR}/src/lite_kernel.cc | ${LITE_DIR}/src/lite_kernel.cc | ||||
| ${LITE_DIR}/src/lite_session.cc | ${LITE_DIR}/src/lite_session.cc | ||||
| ${LITE_DIR}/src/dequant.cc | |||||
| ${LITE_DIR}/src/sub_graph_kernel.cc | ${LITE_DIR}/src/sub_graph_kernel.cc | ||||
| ${LITE_DIR}/src/lite_model.cc | ${LITE_DIR}/src/lite_model.cc | ||||
| ${LITE_DIR}/src/scheduler.cc | ${LITE_DIR}/src/scheduler.cc | ||||
| @@ -95,6 +95,7 @@ set(LITE_SRC | |||||
| ${SRC_DIR}/executor.cc | ${SRC_DIR}/executor.cc | ||||
| ${SRC_DIR}/lite_model.cc | ${SRC_DIR}/lite_model.cc | ||||
| ${SRC_DIR}/errorcode.cc | ${SRC_DIR}/errorcode.cc | ||||
| ${SRC_DIR}/dequant.cc | |||||
| ) | ) | ||||
| if (SUPPORT_TRAIN) | if (SUPPORT_TRAIN) | ||||
| set(LITE_SRC | set(LITE_SRC | ||||
| @@ -782,4 +782,27 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) { | |||||
| return new_func_graph; | return new_func_graph; | ||||
| } | } | ||||
| void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value) { | |||||
| MS_ASSERT(node != nullptr); | |||||
| MS_ASSERT(param_node != nullptr); | |||||
| MS_ASSERT(param_value != nullptr); | |||||
| auto op_name = node->fullname_with_scope(); | |||||
| *param_node = node->cast<ParameterPtr>(); | |||||
| if (*param_node == nullptr) { | |||||
| MS_LOG(INFO) << op_name << " can not cast to ParameterPtr"; | |||||
| return; | |||||
| } | |||||
| if (!(*param_node)->has_default()) { | |||||
| MS_LOG(INFO) << op_name << " not has_default"; | |||||
| return; | |||||
| } | |||||
| *param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param()); | |||||
| if (*param_value == nullptr) { | |||||
| MS_LOG(INFO) << "default_param can not cast to ParamValueLite"; | |||||
| return; | |||||
| } | |||||
| } | |||||
| } // namespace mindspore::lite::quant | } // namespace mindspore::lite::quant | ||||
| @@ -75,9 +75,10 @@ class QuantStrategy { | |||||
| bool CanMulOpQuantized(const CNodePtr &node) const; | bool CanMulOpQuantized(const CNodePtr &node) const; | ||||
| bool CanOpPostQuantized(AnfNodePtr &node) const; | bool CanOpPostQuantized(AnfNodePtr &node) const; | ||||
| private: | |||||
| size_t mWeightSize; | size_t mWeightSize; | ||||
| size_t mConvWeightQuantChannelThreshold; | size_t mConvWeightQuantChannelThreshold; | ||||
| private: | |||||
| static const std::vector<schema::PrimitiveType> conv_types; | static const std::vector<schema::PrimitiveType> conv_types; | ||||
| static const std::vector<schema::PrimitiveType> mul_types; | static const std::vector<schema::PrimitiveType> mul_types; | ||||
| }; | }; | ||||
| @@ -356,5 +357,8 @@ STATUS CopyInputDataToTensor(size_t input_index, size_t image_index, | |||||
| const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor); | const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor); | ||||
| FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &); | FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &); | ||||
| void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value); | |||||
| } // namespace mindspore::lite::quant | } // namespace mindspore::lite::quant | ||||
| #endif | #endif | ||||
| @@ -20,7 +20,6 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "src/common/common.h" | #include "src/common/common.h" | ||||
| #include "ir/dtype/type_id.h" | |||||
| using std::string; | using std::string; | ||||
| using std::vector; | using std::vector; | ||||
| @@ -73,13 +72,13 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_f | |||||
| this->bit_num_ = static_cast<size_t>(std::stoull(bitNum)); | this->bit_num_ = static_cast<size_t>(std::stoull(bitNum)); | ||||
| auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold)); | auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold)); | ||||
| quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); | quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); | ||||
| quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; | |||||
| quant_min = -(1 << (unsigned int)(this->bit_num_ - 1)); | |||||
| // parse type_id | |||||
| quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; | |||||
| quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1)); | |||||
| // parse type_id_ | |||||
| if (this->bit_num_ > 0 && this->bit_num_ <= 8) { | if (this->bit_num_ > 0 && this->bit_num_ <= 8) { | ||||
| type_id = kNumberTypeInt8; | |||||
| type_id_ = kNumberTypeInt8; | |||||
| } else if (this->bit_num_ <= 16) { | } else if (this->bit_num_ <= 16) { | ||||
| type_id = kNumberTypeInt16; | |||||
| type_id_ = kNumberTypeInt16; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "invalid input bits"; | MS_LOG(ERROR) << "invalid input bits"; | ||||
| } | } | ||||
| @@ -90,7 +89,7 @@ WeightQuantizer::~WeightQuantizer() { delete fp32_session_; } | |||||
| STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, | STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, | ||||
| std::shared_ptr<PrimitiveC> primitive_c) { | std::shared_ptr<PrimitiveC> primitive_c) { | ||||
| // set dtype | // set dtype | ||||
| param_value->set_tensor_type(type_id); | |||||
| param_value->set_tensor_type(type_id_); | |||||
| auto abstract_base = param_node->abstract(); | auto abstract_base = param_node->abstract(); | ||||
| if (abstract_base == nullptr) { | if (abstract_base == nullptr) { | ||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | ||||
| @@ -101,49 +100,158 @@ STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | ||||
| abstract_tensor->element()->set_type(TypeIdToType(type_id)); | |||||
| abstract_tensor->element()->set_type(TypeIdToType(type_id_)); | |||||
| primitive_c->set_quant_type(schema::QuantType_WeightQuant); | primitive_c->set_quant_type(schema::QuantType_WeightQuant); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||||
| for (auto &cnode : nodes) { | |||||
| if (!quant_strategy_->CanConvOpQuantized(cnode)) { | |||||
| continue; | |||||
| } | |||||
| STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) { | |||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto input_node = cnode->input(2); | |||||
| if (!input_node->isa<Parameter>()) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto input_node = cnode->input(2); | |||||
| if (!input_node->isa<Parameter>()) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| ParameterPtr param_node; | |||||
| ParamValueLitePtr param_value; | |||||
| auto param_node = input_node->cast<ParameterPtr>(); | |||||
| if (!param_node->has_default()) { | |||||
| return RET_ERROR; | |||||
| GetLiteParameter(input_node, &param_node, &param_value); | |||||
| if (param_node == nullptr || param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { | |||||
| MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto status = RET_ERROR; | |||||
| if (type_id_ == kNumberTypeInt8) { | |||||
| status = | |||||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); | |||||
| } else if (type_id_ == kNumberTypeInt16) { | |||||
| status = | |||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); | |||||
| } | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||||
| return status; | |||||
| } | |||||
| status = SetAbstract(param_value, param_node, primitive_c); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { | |||||
| auto already_quant = false; | |||||
| ParamValueLitePtr param_value = nullptr; | |||||
| ParameterPtr param_node = nullptr; | |||||
| for (size_t i = 1; i < cnode->size(); i++) { | |||||
| auto inputNode = cnode->input(i); | |||||
| if (inputNode->isa<Parameter>()) { | |||||
| param_node = inputNode->cast<ParameterPtr>(); | |||||
| if ((param_node != nullptr) && param_node->has_default()) { | |||||
| param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||||
| if ((param_value == nullptr) || (param_value->tensor_size() == 0) || (param_value->tensor_addr() == nullptr)) { | |||||
| param_value = nullptr; | |||||
| continue; | |||||
| } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 || | |||||
| param_value->tensor_type() == mindspore::kNumberTypeInt16) { | |||||
| MS_LOG(INFO) << "the node: " << cnode->fullname_with_scope() << " input_i: " << i << "has been " | |||||
| << " quantized"; | |||||
| already_quant = true; | |||||
| break; | |||||
| } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { | |||||
| param_value = nullptr; | |||||
| continue; | |||||
| } else { | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | |||||
| ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||||
| if (param_value == nullptr) { | |||||
| if (already_quant) { | |||||
| return RET_OK; | |||||
| } | |||||
| if (param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "No valid input param node !"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto status = RET_ERROR; | |||||
| if (type_id_ == kNumberTypeInt8) { | |||||
| status = | |||||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); | |||||
| } else if (type_id_ == kNumberTypeInt16) { | |||||
| status = | |||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); | |||||
| } | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||||
| return status; | |||||
| } | |||||
| status = SetAbstract(param_value, param_node, primitive_c); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) { | |||||
| MS_ASSERT(cnode != nullptr); | |||||
| auto op_name = cnode->fullname_with_scope(); | |||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||||
| MS_ASSERT(primitive_c != nullptr); | |||||
| if (cnode->inputs().size() < 4) { | |||||
| MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| { | |||||
| auto weight_i = cnode->input(2); | |||||
| ParameterPtr param_node; | |||||
| ParamValueLitePtr param_value; | |||||
| GetLiteParameter(weight_i, &param_node, &param_value); | |||||
| if (param_node == nullptr || param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { | |||||
| MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type(); | |||||
| return RET_ERROR; | |||||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||||
| MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; | |||||
| return RET_OK; | |||||
| } | |||||
| if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) { | |||||
| MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < " | |||||
| << quant_strategy_->mWeightSize; | |||||
| return RET_OK; | |||||
| } | } | ||||
| auto status = RET_ERROR; | auto status = RET_ERROR; | ||||
| if (type_id == kNumberTypeInt8) { | |||||
| if (type_id_ == kNumberTypeInt8) { | |||||
| status = | status = | ||||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||||
| } else if (type_id == kNumberTypeInt16) { | |||||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||||
| } else if (type_id_ == kNumberTypeInt16) { | |||||
| status = | status = | ||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | MS_LOG(ERROR) << "QuantFilter failed : " << status; | ||||
| @@ -155,65 +263,26 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | |||||
| } | |||||
| STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||||
| for (auto &node : nodes) { | |||||
| if (!quant_strategy_->CanMulOpQuantized(node)) { | |||||
| continue; | |||||
| } | |||||
| auto already_quant = false; | |||||
| ParamValueLitePtr param_value = nullptr; | |||||
| ParameterPtr param_node = nullptr; | |||||
| for (size_t i = 1; i < node->size(); i++) { | |||||
| auto inputNode = node->input(i); | |||||
| if (inputNode->isa<Parameter>()) { | |||||
| param_node = inputNode->cast<ParameterPtr>(); | |||||
| if ((param_node != nullptr) && param_node->has_default()) { | |||||
| param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||||
| if ((param_value == nullptr) || (param_value->tensor_size() == 0) || | |||||
| (param_value->tensor_addr() == nullptr)) { | |||||
| param_value = nullptr; | |||||
| continue; | |||||
| } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 || | |||||
| param_value->tensor_type() == mindspore::kNumberTypeInt16) { | |||||
| MS_LOG(INFO) << "the node: " << node->fullname_with_scope() << " input_i: " << i << "has been " | |||||
| << " quantized"; | |||||
| already_quant = true; | |||||
| break; | |||||
| } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { | |||||
| param_value = nullptr; | |||||
| continue; | |||||
| } else { | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| if (already_quant) { | |||||
| continue; | |||||
| } | |||||
| if (param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "No valid input param node !"; | |||||
| { | |||||
| auto weight_h = cnode->input(3); | |||||
| ParameterPtr param_node; | |||||
| ParamValueLitePtr param_value; | |||||
| GetLiteParameter(weight_h, &param_node, &param_value); | |||||
| if (param_node == nullptr || param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0)); | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||||
| MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto status = RET_ERROR; | auto status = RET_ERROR; | ||||
| if (type_id == kNumberTypeInt8) { | |||||
| if (type_id_ == kNumberTypeInt8) { | |||||
| status = | status = | ||||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||||
| } else if (type_id == kNumberTypeInt16) { | |||||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||||
| } else if (type_id_ == kNumberTypeInt16) { | |||||
| status = | status = | ||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | MS_LOG(ERROR) << "QuantFilter failed : " << status; | ||||
| @@ -225,7 +294,78 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| { | |||||
| if (cnode->inputs().size() > 4) { | |||||
| auto bias = cnode->input(4); | |||||
| ParameterPtr param_node; | |||||
| ParamValueLitePtr param_value; | |||||
| GetLiteParameter(bias, &param_node, &param_value); | |||||
| if (param_node == nullptr || param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||||
| MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto status = RET_ERROR; | |||||
| if (type_id_ == kNumberTypeInt8) { | |||||
| status = | |||||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||||
| } else if (type_id_ == kNumberTypeInt16) { | |||||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||||
| false); | |||||
| } | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||||
| return status; | |||||
| } | |||||
| status = SetAbstract(param_value, param_node, primitive_c); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) { | |||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||||
| MS_ASSERT(primitive_c != nullptr); | |||||
| auto weight_h = cnode->input(1); | |||||
| ParameterPtr param_node; | |||||
| ParamValueLitePtr param_value; | |||||
| GetLiteParameter(weight_h, &param_node, &param_value); | |||||
| if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||||
| MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight"; | |||||
| return RET_OK; | |||||
| } | |||||
| if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) { | |||||
| MS_LOG(INFO) << cnode->fullname_with_scope() << " param cnt: " << param_value->tensor_size() / 4 << " < " | |||||
| << quant_strategy_->mWeightSize; | |||||
| return RET_OK; | |||||
| } | |||||
| auto status = RET_ERROR; | |||||
| if (type_id_ == kNumberTypeInt8) { | |||||
| status = | |||||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||||
| } else if (type_id_ == kNumberTypeInt16) { | |||||
| status = | |||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false); | |||||
| } | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||||
| return status; | |||||
| } | |||||
| status = SetAbstract(param_value, param_node, primitive_c); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -315,6 +455,23 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) { | |||||
| } | } | ||||
| auto cnodes = func_graph->GetOrderedCnodes(); | auto cnodes = func_graph->GetOrderedCnodes(); | ||||
| for (auto &cnode : cnodes) { | |||||
| auto op_type = NodePrimitiveType(cnode); | |||||
| if (op_type == schema::PrimitiveType_Lstm) { | |||||
| status = DoLstmQuntize(cnode); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoLstmQuntize error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else if (op_type == schema::PrimitiveType_Gather) { | |||||
| status = DoGatherQuntize(cnode); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoGatherQuntize error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| for (auto iter = cnodes.end(); iter != cnodes.begin();) { | for (auto iter = cnodes.end(); iter != cnodes.begin();) { | ||||
| auto cnode = *(--iter); | auto cnode = *(--iter); | ||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | ||||
| @@ -357,18 +514,18 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) { | |||||
| } | } | ||||
| // 1. try quant | // 1. try quant | ||||
| for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { | for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { | ||||
| type_id = TypeId::kNumberTypeInt8; | |||||
| type_id_ = TypeId::kNumberTypeInt8; | |||||
| int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; | int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; | ||||
| int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); | int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); | ||||
| if (type_id == TypeId::kNumberTypeInt8) { | |||||
| if (type_id_ == TypeId::kNumberTypeInt8) { | |||||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | ||||
| quant_min_t, bit_num_t, true); | quant_min_t, bit_num_t, true); | ||||
| } else if (type_id == TypeId::kNumberTypeInt16) { | |||||
| } else if (type_id_ == TypeId::kNumberTypeInt16) { | |||||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | ||||
| quant_min_t, bit_num_t, true); | quant_min_t, bit_num_t, true); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "unexpected type_id: " << type_id; | |||||
| MS_LOG(ERROR) << "unexpected type_id_: " << type_id_; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -456,13 +613,53 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| for (auto &cnode : func_graph->GetOrderedCnodes()) { | |||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto op_name = cnode->fullname_with_scope(); | |||||
| auto op_type = (schema::PrimitiveType)primitive_c->Type(); | |||||
| if (quant_strategy_->CanConvOpQuantized(cnode)) { | |||||
| auto status = DoConvQuantize(cnode); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoConvQuantize error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else if (quant_strategy_->CanMulOpQuantized(cnode)) { | |||||
| auto status = DoMulQuantize(cnode); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoMulQuantize error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else if (op_type == schema::PrimitiveType_Lstm) { | |||||
| auto status = DoLstmQuntize(cnode); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoLstmQuntize error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else if (op_type == schema::PrimitiveType_Gather) { | |||||
| auto status = DoGatherQuntize(cnode); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoGatherQuntize error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(DEBUG) << op_name << " of type: " << schema::EnumNamePrimitiveType(op_type) << " no need quant"; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { | STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { | ||||
| MS_ASSERT(func_graph != nullptr); | MS_ASSERT(func_graph != nullptr); | ||||
| STATUS ret; | |||||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||||
| if (!config_file_.empty()) { | if (!config_file_.empty()) { | ||||
| ret = ParseConfigFile(config_file_, &config_param_); | |||||
| auto ret = ParseConfigFile(config_file_, &config_param_); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "ReadConfig error."; | MS_LOG(ERROR) << "ReadConfig error."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -470,20 +667,14 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||||
| } | } | ||||
| if (config_param_.mixed) { | if (config_param_.mixed) { | ||||
| bit_num_ = 8; | |||||
| quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; | |||||
| quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1)); | |||||
| type_id_ = kNumberTypeInt8; | |||||
| MS_LOG(INFO) << "Do mixed bit quantization"; | MS_LOG(INFO) << "Do mixed bit quantization"; | ||||
| return DoMiexedQuant(func_graph); | return DoMiexedQuant(func_graph); | ||||
| } | } | ||||
| ret = DoConvQuantize(cnodes); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoConvQuantize failed :" << ret; | |||||
| return ret; | |||||
| } | |||||
| ret = DoMulQuantize(cnodes); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoMulQuantize failed :" << ret; | |||||
| return ret; | |||||
| } | |||||
| return ret; | |||||
| return DoFixedQuant(func_graph); | |||||
| } | } | ||||
| } // namespace mindspore::lite::quant | } // namespace mindspore::lite::quant | ||||
| @@ -41,19 +41,21 @@ class WeightQuantizer : public Quantizer { | |||||
| ~WeightQuantizer(); | ~WeightQuantizer(); | ||||
| STATUS DoQuantize(FuncGraphPtr func_graph) override; | STATUS DoQuantize(FuncGraphPtr func_graph) override; | ||||
| STATUS DoConvQuantize(const std::list<CNodePtr> &nodes); | |||||
| STATUS DoMulQuantize(const std::list<CNodePtr> &nodes); | |||||
| STATUS DoConvQuantize(CNodePtr); | |||||
| STATUS DoMulQuantize(CNodePtr); | |||||
| STATUS DoLstmQuntize(CNodePtr cnode); | |||||
| STATUS DoGatherQuntize(CNodePtr cnode); | |||||
| static STATUS WeightQuantInputCheck(const converter::Flags *config); | static STATUS WeightQuantInputCheck(const converter::Flags *config); | ||||
| static bool IsPosNum(const std::string &str); | static bool IsPosNum(const std::string &str); | ||||
| int quant_max; | |||||
| int quant_min; | |||||
| TypeId type_id{kTypeUnknown}; | |||||
| int quant_max_{127}; | |||||
| int quant_min_{-128}; | |||||
| TypeId type_id_{kNumberTypeInt8}; | |||||
| std::map<std::string, int> opname_bit_; | std::map<std::string, int> opname_bit_; | ||||
| private: | private: | ||||
| std::unique_ptr<QuantStrategy> quant_strategy_; | std::unique_ptr<QuantStrategy> quant_strategy_; | ||||
| size_t bit_num_; | |||||
| size_t bit_num_{8}; | |||||
| std::string config_file_; | std::string config_file_; | ||||
| PostQuantConfig config_param_; | PostQuantConfig config_param_; | ||||
| std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...] | std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...] | ||||
| @@ -61,6 +63,7 @@ class WeightQuantizer : public Quantizer { | |||||
| STATUS DoMiexedQuant(FuncGraphPtr); | STATUS DoMiexedQuant(FuncGraphPtr); | ||||
| STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c); | STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c); | ||||
| STATUS DoFixedQuant(FuncGraphPtr); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite::quant | } // namespace mindspore::lite::quant | ||||
| #endif | #endif | ||||