Browse Source

!11212 Reconstruct the tensor dequant mechanism when weight_quant && support lstm/gather weight_quant

From: @xutianchun
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
bc64a45ed9
29 changed files with 400 additions and 508 deletions
  1. +1
    -0
      mindspore/lite/src/CMakeLists.txt
  2. +42
    -3
      mindspore/lite/src/dequant.cc
  3. +10
    -3
      mindspore/lite/src/dequant.h
  4. +2
    -2
      mindspore/lite/src/lite_session.cc
  5. +0
    -33
      mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc
  6. +0
    -32
      mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc
  7. +0
    -1
      mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
  8. +0
    -32
      mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc
  9. +0
    -32
      mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
  10. +0
    -31
      mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
  11. +0
    -1
      mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.h
  12. +0
    -31
      mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc
  13. +0
    -33
      mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc
  14. +0
    -30
      mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc
  15. +0
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
  16. +0
    -30
      mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise_fp32.cc
  17. +0
    -32
      mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc
  18. +0
    -29
      mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
  19. +0
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.h
  20. +0
    -34
      mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
  21. +5
    -5
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
  22. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
  23. +5
    -0
      mindspore/lite/src/scheduler.cc
  24. +1
    -0
      mindspore/lite/test/CMakeLists.txt
  25. +1
    -0
      mindspore/lite/tools/converter/CMakeLists.txt
  26. +23
    -0
      mindspore/lite/tools/converter/quantizer/quantize_util.cc
  27. +5
    -1
      mindspore/lite/tools/converter/quantizer/quantize_util.h
  28. +295
    -104
      mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
  29. +9
    -6
      mindspore/lite/tools/converter/quantizer/weight_quantizer.h

+ 1
- 0
mindspore/lite/src/CMakeLists.txt View File

@@ -39,6 +39,7 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
${CMAKE_CURRENT_SOURCE_DIR}/dequant.cc
) )


if (SUPPORT_GPU) if (SUPPORT_GPU)


mindspore/lite/src/runtime/kernel/arm/base/dequant.cc → mindspore/lite/src/dequant.cc View File

@@ -14,9 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */
#include <cmath> #include <cmath>
#include "src/runtime/kernel/arm/base/dequant.h"
#include "src/dequant.h"


namespace mindspore::kernel {
namespace mindspore::lite {
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
MS_ASSERT(input_tensor != nullptr); MS_ASSERT(input_tensor != nullptr);
if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) { if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
@@ -35,6 +35,8 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
} }


void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
MS_ASSERT(input_tensor != nullptr);
MS_ASSERT(unpack_int_data != nullptr);
auto quant_params = input_tensor->quantParams(); auto quant_params = input_tensor->quantParams();
if (quant_params == nullptr) { if (quant_params == nullptr) {
MS_LOG(ERROR) << "low bits quantparams is empty."; MS_LOG(ERROR) << "low bits quantparams is empty.";
@@ -47,4 +49,41 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data); UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data);
} }
} }
} // namespace mindspore::kernel

std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
TypeId data_type) {
std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
for (auto weight_tensor : in_tensors) {
MS_ASSERT(weight_tensor != nullptr);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
return tensor_origin_data;
}
weight_tensor->set_data(dequant_weight);
weight_tensor->set_data_type(kNumberTypeFloat32);
tensor_origin_data[weight_tensor] = {restore_type, restore_data};
}
}
}
return tensor_origin_data;
}

void DequantUtil::RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map) {
for (auto &kv : tensor_origin_data_map) {
auto *tensor = kv.first;
auto type_id = kv.second.first;
auto data = kv.second.second;
tensor->FreeData();
tensor->set_data_type(type_id);
tensor->set_data(data);
}
}

} // namespace mindspore::lite

mindspore/lite/src/runtime/kernel/arm/base/dequant.h → mindspore/lite/src/dequant.h View File

@@ -17,6 +17,8 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_


#include <map>
#include <utility>
#include <vector> #include <vector>
#include <queue> #include <queue>
#include <cmath> #include <cmath>
@@ -24,13 +26,18 @@
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/tensor.h" #include "src/tensor.h"


namespace mindspore::kernel {
namespace mindspore::lite {
class DequantUtil { class DequantUtil {
public: public:
static float *DequantWeight(lite::Tensor *input_tensor); static float *DequantWeight(lite::Tensor *input_tensor);


static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);


static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
TypeId data_type);

static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);

template <typename ST, typename DT = float> template <typename ST, typename DT = float>
static DT *DequantData(lite::Tensor *input_tensor) { static DT *DequantData(lite::Tensor *input_tensor) {
const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData()); const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
@@ -108,7 +115,7 @@ class DequantUtil {
static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int, static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int,
size_t *count, bool is_last) { size_t *count, bool is_last) {
T2 uint_result = 0; T2 uint_result = 0;
T1 result = 0;
T1 result;
UnPackFromUintToOrigin<T2>(packed_data, unpack_bit_data); UnPackFromUintToOrigin<T2>(packed_data, unpack_bit_data);
while (static_cast<int>(unpack_bit_data->size()) >= origin_bit) { while (static_cast<int>(unpack_bit_data->size()) >= origin_bit) {
for (int k = 0; k < origin_bit; k++) { for (int k = 0; k < origin_bit; k++) {
@@ -163,6 +170,6 @@ class DequantUtil {
} }
} }
}; };
} // namespace mindspore::kernel
} // namespace mindspore::lite


#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_

+ 2
- 2
mindspore/lite/src/lite_session.cc View File

@@ -27,7 +27,7 @@
#include "src/common/graph_util.h" #include "src/common/graph_util.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/lite_model.h" #include "src/lite_model.h"
#include "src/runtime/kernel/arm/base/dequant.h"
#include "src/dequant.h"
#if SUPPORT_NPU #if SUPPORT_NPU
#include "src/runtime/agent/npu/npu_manager.h" #include "src/runtime/agent/npu/npu_manager.h"
#include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h"
@@ -120,7 +120,7 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
MS_LOG(ERROR) << "Malloc data for tensor failed "; MS_LOG(ERROR) << "Malloc data for tensor failed ";
return RET_ERROR; return RET_ERROR;
} }
kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
copyed_tensor_idxes_.emplace_back(tensor_index); copyed_tensor_idxes_.emplace_back(tensor_index);
} else { } else {
dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data())); dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));


+ 0
- 33
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc View File

@@ -25,7 +25,6 @@
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@@ -359,22 +358,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
MS_ASSERT(opParameter != nullptr); MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);


auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}

auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
kernel::LiteKernel *kernel = nullptr; kernel::LiteKernel *kernel = nullptr;
if (conv_param->group_ == 1) { if (conv_param->group_ == 1) {
@@ -385,11 +368,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &


if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -398,20 +376,9 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }

if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel; return kernel;
} }
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator)


+ 0
- 32
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc View File

@@ -22,7 +22,6 @@
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@@ -138,22 +137,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
MS_ASSERT(opParameter != nullptr); MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);


auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}

auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
kernel::LiteKernel *kernel; kernel::LiteKernel *kernel;
if (conv_param->input_channel_ < 32) { if (conv_param->input_channel_ < 32) {
@@ -164,11 +147,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
} }
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -176,19 +154,9 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel; return kernel;
} }




+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc View File

@@ -20,7 +20,6 @@
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
#include "nnacl/fp16/conv_fp16.h" #include "nnacl/fp16/conv_fp16.h"
#include "nnacl/fp16/matmul_fp16.h" #include "nnacl/fp16/matmul_fp16.h"
#include "nnacl/fp16/cast_fp16.h" #include "nnacl/fp16/cast_fp16.h"


+ 0
- 32
mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc View File

@@ -20,7 +20,6 @@
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@@ -212,30 +211,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
MS_ASSERT(opParameter != nullptr); MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);


auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
auto dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}

auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -243,19 +221,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel; return kernel;
} }




+ 0
- 32
mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc View File

@@ -17,7 +17,6 @@
#include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h" #include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h"
#include "src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.h" #include "src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@@ -220,22 +219,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
MS_ASSERT(opParameter != nullptr); MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);


auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
auto dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}

kernel::LiteKernel *kernel; kernel::LiteKernel *kernel;
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) && if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) &&
@@ -247,11 +230,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>


if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -259,19 +237,9 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel; return kernel;
} }
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator)


+ 0
- 31
mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc View File

@@ -234,30 +234,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
OpParameter *opParameter, const lite::InnerContext *ctx, OpParameter *opParameter, const lite::InnerContext *ctx,
const kernel::KernelKey &desc, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) { const mindspore::lite::PrimitiveC *primitive) {
auto *weight_tensor = inputs.at(kWeightIndex);
// data of second tensor of fc may be nullptr
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}
auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -265,19 +244,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel; return kernel;
} }




+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.h View File

@@ -24,7 +24,6 @@
#include "nnacl/fp16/matmul_fp16.h" #include "nnacl/fp16/matmul_fp16.h"
#include "nnacl/fp16/cast_fp16.h" #include "nnacl/fp16/cast_fp16.h"
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/dequant.h"


namespace mindspore::kernel { namespace mindspore::kernel {
class FullconnectionFP16CPUKernel : public LiteKernel { class FullconnectionFP16CPUKernel : public LiteKernel {


+ 0
- 31
mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc View File

@@ -20,7 +20,6 @@
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
@@ -330,29 +329,9 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc, const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) { const mindspore::lite::PrimitiveC *primitive) {
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}
auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -361,18 +340,8 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel; delete kernel;
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return nullptr; return nullptr;
} }
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel; return kernel;
} }




+ 0
- 33
mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc View File

@@ -22,7 +22,6 @@
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@@ -356,22 +355,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
MS_ASSERT(desc.data_type == kNumberTypeFloat32); MS_ASSERT(desc.data_type == kNumberTypeFloat32);


// if get quantized weight, dequantize it to float32 type data.
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(op_parameter);
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}

auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
kernel::LiteKernel *kernel = nullptr; kernel::LiteKernel *kernel = nullptr;
if (conv_param->group_ == 1) { if (conv_param->group_ == 1) {
@@ -382,11 +365,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &


if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(op_parameter); free(op_parameter);
return nullptr; return nullptr;
} }
@@ -395,20 +373,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
if (ret != RET_OK && ret != RET_INFER_INVALID) { if (ret != RET_OK && ret != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }

if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel; return kernel;
} }




+ 0
- 30
mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc View File

@@ -21,7 +21,6 @@
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@@ -126,19 +125,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
MS_ASSERT(opParameter != nullptr); MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);


auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}

auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
kernel::LiteKernel *kernel = nullptr; kernel::LiteKernel *kernel = nullptr;
if (primitive != nullptr && primitive->infer_flag()) { if (primitive != nullptr && primitive->infer_flag()) {
@@ -162,11 +148,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
} }
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -174,21 +155,10 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
if (ret != RET_OK && ret != RET_INFER_INVALID) { if (ret != RET_OK && ret != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }


if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}

return kernel; return kernel;
} }




+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc View File

@@ -19,7 +19,6 @@
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
#include "nnacl/fp32/conv_fp32.h" #include "nnacl/fp32/conv_fp32.h"
#include "nnacl/fp32/matmul_fp32.h" #include "nnacl/fp32/matmul_fp32.h"




+ 0
- 30
mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise_fp32.cc View File

@@ -19,7 +19,6 @@
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@@ -202,29 +201,10 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
const mindspore::lite::PrimitiveC *primitive) { const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr); MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}
auto kernel = auto kernel =
new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -232,19 +212,9 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel; return kernel;
} }




+ 0
- 32
mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc View File

@@ -17,7 +17,6 @@
#include "src/runtime/kernel/arm/fp32/deconvolution_fp32.h" #include "src/runtime/kernel/arm/fp32/deconvolution_fp32.h"
#include "src/runtime/kernel/arm/fp32/deconvolution_winograd_fp32.h" #include "src/runtime/kernel/arm/fp32/deconvolution_winograd_fp32.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@@ -240,20 +239,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
const mindspore::lite::PrimitiveC *primitive) { const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr); MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}


kernel::LiteKernel *kernel; kernel::LiteKernel *kernel;
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
@@ -266,11 +251,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>


if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -278,21 +258,9 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }

if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}

return kernel; return kernel;
} }




+ 0
- 29
mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc View File

@@ -228,28 +228,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
const mindspore::lite::PrimitiveC *primitive) { const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr); MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection); MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection);
auto *weight_tensor = inputs.at(kWeightIndex);
// data of second tensor of fc may be nullptr
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}
auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive); auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (!kernel) { if (!kernel) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -257,19 +238,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel; return kernel;
} }




+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.h View File

@@ -22,7 +22,6 @@
#include "include/errorcode.h" #include "include/errorcode.h"
#include "nnacl/fp32/matmul_fp32.h" #include "nnacl/fp32/matmul_fp32.h"
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::lite::InnerContext; using mindspore::lite::InnerContext;
namespace mindspore::kernel { namespace mindspore::kernel {


+ 0
- 34
mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc View File

@@ -19,7 +19,6 @@
#include "nnacl/fp32/matmul_fp32.h" #include "nnacl/fp32/matmul_fp32.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/base/dequant.h"


using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INPUT_TENSOR_ERROR; using mindspore::lite::RET_INPUT_TENSOR_ERROR;
@@ -417,30 +416,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
const mindspore::lite::PrimitiveC *primitive) { const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr); MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_MatMul); MS_ASSERT(desc.type == schema::PrimitiveType_MatMul);

auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}

auto kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive); auto kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter); free(opParameter);
return nullptr; return nullptr;
} }
@@ -448,21 +426,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel; delete kernel;
return nullptr; return nullptr;
} }

if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}

return kernel; return kernel;
} }




+ 5
- 5
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc View File

@@ -15,7 +15,7 @@
*/ */


#include "src/runtime/kernel/opencl/opencl_kernel.h" #include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/kernel/arm/base/dequant.h"
#include "mindspore/lite/src/dequant.h"


using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK; using mindspore::lite::RET_OK;
@@ -263,10 +263,10 @@ int OpenCLKernel::DequantWeight() {
if (is_fp16) { if (is_fp16) {
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
dequant_weight = kernel::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
dequant_weight = lite::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
weight_tensor->set_data_type(kNumberTypeFloat16); weight_tensor->set_data_type(kNumberTypeFloat16);
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
dequant_weight = kernel::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
dequant_weight = lite::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
weight_tensor->set_data_type(kNumberTypeFloat16); weight_tensor->set_data_type(kNumberTypeFloat16);
} else { } else {
set_flag = false; set_flag = false;
@@ -276,10 +276,10 @@ int OpenCLKernel::DequantWeight() {
#endif #endif
} else { } else {
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
dequant_weight = kernel::DequantUtil::DequantData<int8_t, float>(weight_tensor);
dequant_weight = lite::DequantUtil::DequantData<int8_t, float>(weight_tensor);
weight_tensor->set_data_type(kNumberTypeFloat32); weight_tensor->set_data_type(kNumberTypeFloat32);
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
dequant_weight = kernel::DequantUtil::DequantData<int16_t, float>(weight_tensor);
dequant_weight = lite::DequantUtil::DequantData<int16_t, float>(weight_tensor);
weight_tensor->set_data_type(kNumberTypeFloat32); weight_tensor->set_data_type(kNumberTypeFloat32);
} else { } else {
set_flag = false; set_flag = false;


+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h View File

@@ -25,7 +25,7 @@
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/opencl/opencl_runtime.h" #include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/base/dequant.h"
#include "mindspore/lite/src/dequant.h"
#include "src/runtime/kernel/opencl/utils.h" #include "src/runtime/kernel/opencl/utils.h"


using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;


+ 5
- 0
mindspore/lite/src/scheduler.cc View File

@@ -27,6 +27,7 @@
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/sub_graph_kernel.h" #include "src/sub_graph_kernel.h"
#include "src/dequant.h"
#if SUPPORT_GPU #if SUPPORT_GPU
#include "src/runtime/kernel/opencl/opencl_subgraph.h" #include "src/runtime/kernel/opencl/opencl_subgraph.h"
#include "src/runtime/opencl/opencl_runtime.h" #include "src/runtime/opencl/opencl_runtime.h"
@@ -213,8 +214,10 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
if (mindspore::lite::IsSupportFloat16() && if (mindspore::lite::IsSupportFloat16() &&
((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type);
auto *kernel = auto *kernel =
KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
DequantUtil::RestoreTensorData(tensor_origin_data_map);
if (kernel != nullptr) { if (kernel != nullptr) {
MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " " MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " "
<< node->name_; << node->name_;
@@ -225,7 +228,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
desc.data_type = kNumberTypeFloat32; desc.data_type = kNumberTypeFloat32;
} }
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type);
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
DequantUtil::RestoreTensorData(tensor_origin_data_map);
if (kernel != nullptr) { if (kernel != nullptr) {
return kernel; return kernel;
} }


+ 1
- 0
mindspore/lite/test/CMakeLists.txt View File

@@ -126,6 +126,7 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/kernel_registry.cc ${LITE_DIR}/src/kernel_registry.cc
${LITE_DIR}/src/lite_kernel.cc ${LITE_DIR}/src/lite_kernel.cc
${LITE_DIR}/src/lite_session.cc ${LITE_DIR}/src/lite_session.cc
${LITE_DIR}/src/dequant.cc
${LITE_DIR}/src/sub_graph_kernel.cc ${LITE_DIR}/src/sub_graph_kernel.cc
${LITE_DIR}/src/lite_model.cc ${LITE_DIR}/src/lite_model.cc
${LITE_DIR}/src/scheduler.cc ${LITE_DIR}/src/scheduler.cc


+ 1
- 0
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -95,6 +95,7 @@ set(LITE_SRC
${SRC_DIR}/executor.cc ${SRC_DIR}/executor.cc
${SRC_DIR}/lite_model.cc ${SRC_DIR}/lite_model.cc
${SRC_DIR}/errorcode.cc ${SRC_DIR}/errorcode.cc
${SRC_DIR}/dequant.cc
) )
if (SUPPORT_TRAIN) if (SUPPORT_TRAIN)
set(LITE_SRC set(LITE_SRC


+ 23
- 0
mindspore/lite/tools/converter/quantizer/quantize_util.cc View File

@@ -782,4 +782,27 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) {
return new_func_graph; return new_func_graph;
} }


void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value) {
MS_ASSERT(node != nullptr);
MS_ASSERT(param_node != nullptr);
MS_ASSERT(param_value != nullptr);

auto op_name = node->fullname_with_scope();

*param_node = node->cast<ParameterPtr>();
if (*param_node == nullptr) {
MS_LOG(INFO) << op_name << " can not cast to ParameterPtr";
return;
}
if (!(*param_node)->has_default()) {
MS_LOG(INFO) << op_name << " not has_default";
return;
}

*param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param());
if (*param_value == nullptr) {
MS_LOG(INFO) << "default_param can not cast to ParamValueLite";
return;
}
}
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant

+ 5
- 1
mindspore/lite/tools/converter/quantizer/quantize_util.h View File

@@ -75,9 +75,10 @@ class QuantStrategy {
bool CanMulOpQuantized(const CNodePtr &node) const; bool CanMulOpQuantized(const CNodePtr &node) const;
bool CanOpPostQuantized(AnfNodePtr &node) const; bool CanOpPostQuantized(AnfNodePtr &node) const;


private:
size_t mWeightSize; size_t mWeightSize;
size_t mConvWeightQuantChannelThreshold; size_t mConvWeightQuantChannelThreshold;

private:
static const std::vector<schema::PrimitiveType> conv_types; static const std::vector<schema::PrimitiveType> conv_types;
static const std::vector<schema::PrimitiveType> mul_types; static const std::vector<schema::PrimitiveType> mul_types;
}; };
@@ -356,5 +357,8 @@ STATUS CopyInputDataToTensor(size_t input_index, size_t image_index,
const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor); const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor);


FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &); FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &);

void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value);

} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant
#endif #endif

+ 295
- 104
mindspore/lite/tools/converter/quantizer/weight_quantizer.cc View File

@@ -20,7 +20,6 @@
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include "src/common/common.h" #include "src/common/common.h"
#include "ir/dtype/type_id.h"


using std::string; using std::string;
using std::vector; using std::vector;
@@ -73,13 +72,13 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_f
this->bit_num_ = static_cast<size_t>(std::stoull(bitNum)); this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold)); auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min = -(1 << (unsigned int)(this->bit_num_ - 1));
// parse type_id
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
// parse type_id_
if (this->bit_num_ > 0 && this->bit_num_ <= 8) { if (this->bit_num_ > 0 && this->bit_num_ <= 8) {
type_id = kNumberTypeInt8;
type_id_ = kNumberTypeInt8;
} else if (this->bit_num_ <= 16) { } else if (this->bit_num_ <= 16) {
type_id = kNumberTypeInt16;
type_id_ = kNumberTypeInt16;
} else { } else {
MS_LOG(ERROR) << "invalid input bits"; MS_LOG(ERROR) << "invalid input bits";
} }
@@ -90,7 +89,7 @@ WeightQuantizer::~WeightQuantizer() { delete fp32_session_; }
STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node,
std::shared_ptr<PrimitiveC> primitive_c) { std::shared_ptr<PrimitiveC> primitive_c) {
// set dtype // set dtype
param_value->set_tensor_type(type_id);
param_value->set_tensor_type(type_id_);
auto abstract_base = param_node->abstract(); auto abstract_base = param_node->abstract();
if (abstract_base == nullptr) { if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
@@ -101,49 +100,158 @@ STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr
return RET_ERROR; return RET_ERROR;
} }
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
abstract_tensor->element()->set_type(TypeIdToType(type_id));
abstract_tensor->element()->set_type(TypeIdToType(type_id_));
primitive_c->set_quant_type(schema::QuantType_WeightQuant); primitive_c->set_quant_type(schema::QuantType_WeightQuant);


return RET_OK; return RET_OK;
} }


STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
for (auto &cnode : nodes) {
if (!quant_strategy_->CanConvOpQuantized(cnode)) {
continue;
}
STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return RET_ERROR;
}


auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return RET_ERROR;
}
auto input_node = cnode->input(2);
if (!input_node->isa<Parameter>()) {
return RET_ERROR;
}


auto input_node = cnode->input(2);
if (!input_node->isa<Parameter>()) {
return RET_ERROR;
}
ParameterPtr param_node;
ParamValueLitePtr param_value;


auto param_node = input_node->cast<ParameterPtr>();
if (!param_node->has_default()) {
return RET_ERROR;
GetLiteParameter(input_node, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}

if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type();
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
} else if (type_id_ == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
return RET_OK;
}

STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
auto already_quant = false;
ParamValueLitePtr param_value = nullptr;
ParameterPtr param_node = nullptr;
for (size_t i = 1; i < cnode->size(); i++) {
auto inputNode = cnode->input(i);
if (inputNode->isa<Parameter>()) {
param_node = inputNode->cast<ParameterPtr>();
if ((param_node != nullptr) && param_node->has_default()) {
param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
if ((param_value == nullptr) || (param_value->tensor_size() == 0) || (param_value->tensor_addr() == nullptr)) {
param_value = nullptr;
continue;
} else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 ||
param_value->tensor_type() == mindspore::kNumberTypeInt16) {
MS_LOG(INFO) << "the node: " << cnode->fullname_with_scope() << " input_i: " << i << "has been "
<< " quantized";
already_quant = true;
break;
} else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
param_value = nullptr;
continue;
} else {
break;
}
}
} }
}


ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
if (param_value == nullptr) {
if (already_quant) {
return RET_OK;
}

if (param_value == nullptr) {
MS_LOG(ERROR) << "No valid input param node !";
return RET_ERROR;
}

auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return RET_ERROR;
}

auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
} else if (type_id_ == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}

return RET_OK;
}

// Weight-quantize an LSTM node: input-weight (input 2), hidden-weight (input 3)
// and, when present, the bias (input 4). Each tensor is quantized with the
// per-tensor path (last QuantFilter argument false), unlike the conv path.
// Returns RET_OK when the node is quantized or legitimately skipped.
STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
  MS_ASSERT(cnode != nullptr);
  auto op_name = cnode->fullname_with_scope();

  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  MS_ASSERT(primitive_c != nullptr);

  // An LSTM cnode needs at least primitive + input + weight_i + weight_h.
  if (cnode->inputs().size() < 4) {
    MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size();
    return RET_ERROR;
  }
  {
    // --- weight_i (input 2) ---
    auto weight_i = cnode->input(2);
    ParameterPtr param_node;
    ParamValueLitePtr param_value;
    GetLiteParameter(weight_i, &param_node, &param_value);
    if (param_node == nullptr || param_value == nullptr) {
      MS_LOG(ERROR) << "GetLiteParameter error";
      return RET_ERROR;
    }
    // Non-float32 weights are skipped (already quantized or unsupported).
    if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
      MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
      return RET_OK;
    }
    // Too-small tensors are not worth quantizing (size threshold is in
    // elements: tensor_size() is bytes, float32 is 4 bytes).
    if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
      MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
                   << quant_strategy_->mWeightSize;
      return RET_OK;
    }
    auto status = RET_ERROR;
    if (type_id_ == kNumberTypeInt8) {
      status =
        QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
    } else if (type_id_ == kNumberTypeInt16) {
      status =
        QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << "QuantFilter failed : " << status;
      return RET_ERROR;
    }
    // NOTE(review): the diff hunks for the weight sections are truncated in
    // the source; SetAbstract is applied here to match the bias section and
    // DoConvQuantize — confirm against upstream.
    status = SetAbstract(param_value, param_node, primitive_c);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "SetAbstract failed : " << status;
      return RET_ERROR;
    }
  }
  {
    // --- weight_h (input 3) ---
    auto weight_h = cnode->input(3);
    ParameterPtr param_node;
    ParamValueLitePtr param_value;
    GetLiteParameter(weight_h, &param_node, &param_value);
    if (param_node == nullptr || param_value == nullptr) {
      MS_LOG(ERROR) << "GetLiteParameter error";
      return RET_ERROR;
    }
    // Unlike weight_i, a non-float32 weight_h is treated as an error here.
    if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
      MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
      return RET_ERROR;
    }
    auto status = RET_ERROR;
    if (type_id_ == kNumberTypeInt8) {
      status =
        QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
    } else if (type_id_ == kNumberTypeInt16) {
      status =
        QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << "QuantFilter failed : " << status;
      return RET_ERROR;
    }
    status = SetAbstract(param_value, param_node, primitive_c);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "SetAbstract failed : " << status;
      return RET_ERROR;
    }
  }
  {
    // --- optional bias (input 4) ---
    if (cnode->inputs().size() > 4) {
      auto bias = cnode->input(4);
      ParameterPtr param_node;
      ParamValueLitePtr param_value;
      GetLiteParameter(bias, &param_node, &param_value);
      if (param_node == nullptr || param_value == nullptr) {
        MS_LOG(ERROR) << "GetLiteParameter error";
        return RET_ERROR;
      }
      if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
        MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
        return RET_ERROR;
      }
      auto status = RET_ERROR;
      if (type_id_ == kNumberTypeInt8) {
        status =
          QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
      } else if (type_id_ == kNumberTypeInt16) {
        status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                      false);
      }
      if (status != RET_OK) {
        MS_LOG(ERROR) << "QuantFilter failed : " << status;
        return status;
      }
      status = SetAbstract(param_value, param_node, primitive_c);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "SetAbstract failed : " << status;
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

// Weight-quantize a Mul-like node: scan its inputs for the first parameter
// holding a non-empty float32 tensor and quantize it per-channel (last
// QuantFilter argument true). Already-quantized (int8/int16) inputs make the
// node a no-op. Reconstructed from the removed list-based DoMulQuantize to
// match the per-node signature declared in the header and called by
// DoFixedQuant — confirm against upstream.
STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
  MS_ASSERT(cnode != nullptr);
  ParamValueLitePtr param_value = nullptr;
  ParameterPtr param_node = nullptr;
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input_node = cnode->input(i);
    if (!input_node->isa<Parameter>()) {
      continue;
    }
    param_node = input_node->cast<ParameterPtr>();
    if (param_node == nullptr || !param_node->has_default()) {
      continue;
    }
    param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
    if ((param_value == nullptr) || (param_value->tensor_size() == 0) || (param_value->tensor_addr() == nullptr)) {
      param_value = nullptr;
      continue;
    }
    if (param_value->tensor_type() == mindspore::kNumberTypeInt8 ||
        param_value->tensor_type() == mindspore::kNumberTypeInt16) {
      MS_LOG(INFO) << "the node: " << cnode->fullname_with_scope() << " input_i: " << i << "has been "
                   << " quantized";
      return RET_OK;
    }
    if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
      param_value = nullptr;
      continue;
    }
    break;  // found a quantizable float32 weight
  }

  if (param_value == nullptr) {
    MS_LOG(ERROR) << "No valid input param node !";
    return RET_ERROR;
  }

  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr";
    return RET_ERROR;
  }

  auto status = RET_ERROR;
  if (type_id_ == kNumberTypeInt8) {
    status =
      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
  } else if (type_id_ == kNumberTypeInt16) {
    status =
      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "QuantFilter failed : " << status;
    return RET_ERROR;
  }
  status = SetAbstract(param_value, param_node, primitive_c);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "SetAbstract failed : " << status;
    return RET_ERROR;
  }
  return RET_OK;
}

// Weight-quantize a Gather node's table tensor (input 1). Skips quietly
// (RET_OK) when the input is not a float32 parameter or is below the size
// threshold; quantization itself uses the per-tensor path (last QuantFilter
// argument false).
STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
  MS_ASSERT(cnode != nullptr);
  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
  MS_ASSERT(primitive_c != nullptr);

  auto weight_h = cnode->input(1);
  ParameterPtr param_node;
  ParamValueLitePtr param_value;
  GetLiteParameter(weight_h, &param_node, &param_value);
  // Not a parameter / no data / not float32: nothing to quantize, not an error.
  if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
    MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight";
    return RET_OK;
  }

  // Size threshold is in elements: tensor_size() is bytes, float32 is 4 bytes.
  if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
    MS_LOG(INFO) << cnode->fullname_with_scope() << " param cnt: " << param_value->tensor_size() / 4 << " < "
                 << quant_strategy_->mWeightSize;
    return RET_OK;
  }

  auto status = RET_ERROR;
  if (type_id_ == kNumberTypeInt8) {
    status =
      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
  } else if (type_id_ == kNumberTypeInt16) {
    status =
      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "QuantFilter failed : " << status;
    return status;
  }
  status = SetAbstract(param_value, param_node, primitive_c);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "SetAbstract failed : " << status;
    return RET_ERROR;
  }
  return RET_OK;
}


@@ -315,6 +455,23 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
} }


auto cnodes = func_graph->GetOrderedCnodes(); auto cnodes = func_graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
auto op_type = NodePrimitiveType(cnode);
if (op_type == schema::PrimitiveType_Lstm) {
status = DoLstmQuntize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoLstmQuntize error";
return RET_ERROR;
}
} else if (op_type == schema::PrimitiveType_Gather) {
status = DoGatherQuntize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoGatherQuntize error";
return RET_ERROR;
}
}
}

for (auto iter = cnodes.end(); iter != cnodes.begin();) { for (auto iter = cnodes.end(); iter != cnodes.begin();) {
auto cnode = *(--iter); auto cnode = *(--iter);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
@@ -357,18 +514,18 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
} }
// 1. try quant // 1. try quant
for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
type_id = TypeId::kNumberTypeInt8;
type_id_ = TypeId::kNumberTypeInt8;
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));


if (type_id == TypeId::kNumberTypeInt8) {
if (type_id_ == TypeId::kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true); quant_min_t, bit_num_t, true);
} else if (type_id == TypeId::kNumberTypeInt16) {
} else if (type_id_ == TypeId::kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true); quant_min_t, bit_num_t, true);
} else { } else {
MS_LOG(ERROR) << "unexpected type_id: " << type_id;
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
return RET_ERROR; return RET_ERROR;
} }
if (status != RET_OK) { if (status != RET_OK) {
@@ -456,13 +613,53 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
return RET_OK; return RET_OK;
} }


// Fixed-bit weight quantization: walk every cnode in order and dispatch to
// the per-op routine (conv, mul, lstm, gather). Any per-op failure aborts the
// whole pass with RET_ERROR; unsupported ops are logged at DEBUG and skipped.
STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
  MS_ASSERT(func_graph != nullptr);
  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    auto prim = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
    if (prim == nullptr) {
      MS_LOG(ERROR) << "primitive_c is nullptr";
      return RET_ERROR;
    }
    const auto node_name = cnode->fullname_with_scope();
    const auto node_type = (schema::PrimitiveType)prim->Type();

    // Strategy checks pick conv/mul candidates; lstm/gather are matched by
    // primitive type directly.
    if (quant_strategy_->CanConvOpQuantized(cnode)) {
      if (DoConvQuantize(cnode) != RET_OK) {
        MS_LOG(ERROR) << "DoConvQuantize error";
        return RET_ERROR;
      }
    } else if (quant_strategy_->CanMulOpQuantized(cnode)) {
      if (DoMulQuantize(cnode) != RET_OK) {
        MS_LOG(ERROR) << "DoMulQuantize error";
        return RET_ERROR;
      }
    } else if (node_type == schema::PrimitiveType_Lstm) {
      if (DoLstmQuntize(cnode) != RET_OK) {
        MS_LOG(ERROR) << "DoLstmQuntize error";
        return RET_ERROR;
      }
    } else if (node_type == schema::PrimitiveType_Gather) {
      if (DoGatherQuntize(cnode) != RET_OK) {
        MS_LOG(ERROR) << "DoGatherQuntize error";
        return RET_ERROR;
      }
    } else {
      MS_LOG(DEBUG) << node_name << " of type: " << schema::EnumNamePrimitiveType(node_type) << " no need quant";
    }
  }
  return RET_OK;
}

STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_ASSERT(func_graph != nullptr); MS_ASSERT(func_graph != nullptr);
STATUS ret;
auto cnodes = func_graph->GetOrderedCnodes();


if (!config_file_.empty()) { if (!config_file_.empty()) {
ret = ParseConfigFile(config_file_, &config_param_);
auto ret = ParseConfigFile(config_file_, &config_param_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ReadConfig error."; MS_LOG(ERROR) << "ReadConfig error.";
return RET_ERROR; return RET_ERROR;
@@ -470,20 +667,14 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
} }


if (config_param_.mixed) { if (config_param_.mixed) {
bit_num_ = 8;
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
type_id_ = kNumberTypeInt8;
MS_LOG(INFO) << "Do mixed bit quantization"; MS_LOG(INFO) << "Do mixed bit quantization";
return DoMiexedQuant(func_graph); return DoMiexedQuant(func_graph);
} }


ret = DoConvQuantize(cnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoConvQuantize failed :" << ret;
return ret;
}
ret = DoMulQuantize(cnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoMulQuantize failed :" << ret;
return ret;
}
return ret;
return DoFixedQuant(func_graph);
} }
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant

+ 9
- 6
mindspore/lite/tools/converter/quantizer/weight_quantizer.h View File

@@ -41,19 +41,21 @@ class WeightQuantizer : public Quantizer {
~WeightQuantizer(); ~WeightQuantizer();


STATUS DoQuantize(FuncGraphPtr func_graph) override; STATUS DoQuantize(FuncGraphPtr func_graph) override;
STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
STATUS DoConvQuantize(CNodePtr);
STATUS DoMulQuantize(CNodePtr);
STATUS DoLstmQuntize(CNodePtr cnode);
STATUS DoGatherQuntize(CNodePtr cnode);
static STATUS WeightQuantInputCheck(const converter::Flags *config); static STATUS WeightQuantInputCheck(const converter::Flags *config);
static bool IsPosNum(const std::string &str); static bool IsPosNum(const std::string &str);


int quant_max;
int quant_min;
TypeId type_id{kTypeUnknown};
int quant_max_{127};
int quant_min_{-128};
TypeId type_id_{kNumberTypeInt8};
std::map<std::string, int> opname_bit_; std::map<std::string, int> opname_bit_;


private: private:
std::unique_ptr<QuantStrategy> quant_strategy_; std::unique_ptr<QuantStrategy> quant_strategy_;
size_t bit_num_;
size_t bit_num_{8};
std::string config_file_; std::string config_file_;
PostQuantConfig config_param_; PostQuantConfig config_param_;
std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...] std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...]
@@ -61,6 +63,7 @@ class WeightQuantizer : public Quantizer {


STATUS DoMiexedQuant(FuncGraphPtr); STATUS DoMiexedQuant(FuncGraphPtr);
STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c); STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
STATUS DoFixedQuant(FuncGraphPtr);
}; };
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant
#endif #endif

Loading…
Cancel
Save