
!14770 remove fp16 kernel process float32 weight code

From: @hangangqiang
Reviewed-by: @zhang_xue_tong, @zhanghaibo5
Signed-off-by: @zhang_xue_tong
Branch: pull/14770/MERGE
Merged-by: mindspore-ci-bot, 4 years ago
Commit: 02e6b30400
21 changed files with 129 additions and 114 deletions:

 1. mindspore/lite/micro/cmake/file_list.cmake (+1 -1)
 2. mindspore/lite/schema/ops.fbs (+1 -1)
 3. mindspore/lite/src/CMakeLists.txt (+1 -1)
 4. mindspore/lite/src/lite_session.cc (+3 -3)
 5. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc (+2 -2)
 6. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc (+2 -12)
 7. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc (+4 -0)
 8. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc (+2 -1)
 9. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc (+2 -1)
10. mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc (+1 -1)
11. mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.cc (+2 -1)
12. mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc (+1 -1)
13. mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h (+1 -1)
14. mindspore/lite/src/scheduler.cc (+43 -51)
15. mindspore/lite/src/sub_graph_kernel.h (+13 -0)
16. mindspore/lite/src/weight_decoder.cc (+30 -18)
17. mindspore/lite/src/weight_decoder.h (+16 -16)
18. mindspore/lite/test/CMakeLists.txt (+1 -1)
19. mindspore/lite/tools/converter/CMakeLists.txt (+1 -1)
20. mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc (+1 -0)
21. mindspore/lite/tools/converter/quantizer/huffman_encode.cc (+1 -1)

mindspore/lite/micro/cmake/file_list.cmake (+1 -1)

@@ -135,7 +135,7 @@ set(LITE_SRC
         ${LITE_DIR}/src/sub_graph_split.cc
         ${LITE_DIR}/src/tensorlist.cc
         ${LITE_DIR}/src/tensor.cc
-        ${LITE_DIR}/src/dequant.cc
+        ${LITE_DIR}/src/weight_decoder.cc
         ${LITE_DIR}/src/huffman_decode.cc
         ${LITE_DIR}/src/common/log_adapter.cc
         ${LITE_DIR}/src/common/utils.cc


mindspore/lite/schema/ops.fbs (+1 -1)

@@ -841,7 +841,7 @@ table Rsqrt {
 }

 table QuantDTypeCast {
-    src_t: long;
+    src_t: long; // deprecated
     dst_t: long;
 }
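
Reviewer note: keeping the deprecated src_t field means previously serialized models still parse; only its meaning changes. A minimal sketch of how a reader might treat it, assuming (not confirmed by this diff) that the input tensor's own dtype is now authoritative and the attribute is only a fallback for old models:

    #include <cstdio>

    // Hypothetical helper; type ids are plain longs as in the schema.
    long ResolveCastSrcType(long input_tensor_dtype, long deprecated_src_t) {
      // 0 stands in for "unknown dtype"; fall back to the old attribute then.
      return input_tensor_dtype != 0 ? input_tensor_dtype : deprecated_src_t;
    }

    int main() {
      std::printf("%ld\n", ResolveCastSrcType(43, 35));  // prefers the tensor dtype
      return 0;
    }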




mindspore/lite/src/CMakeLists.txt (+1 -1)

@@ -61,7 +61,7 @@ set(LITE_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
-    ${CMAKE_CURRENT_SOURCE_DIR}/dequant.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/weight_decoder.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/huffman_decode.cc
 )
 if(DEFINED ARCHS)


mindspore/lite/src/lite_session.cc (+3 -3)

@@ -28,7 +28,7 @@
 #include "src/common/graph_util.h"
 #include "src/kernel_registry.h"
 #include "src/lite_model.h"
-#include "src/dequant.h"
+#include "src/weight_decoder.h"
 #ifdef ENABLE_MINDRT
 #include "src/mindrt_executor.h"
 #endif
@@ -57,13 +57,13 @@ int DecompressTensor(const schema::Tensor &src_tensor, Tensor *dst_tensor) {
   // huffman code and bit pack are not assumed to be performed at same time
   STATUS ret = RET_ERROR;
   if (src_tensor.enableHuffmanCode()) {
-    ret = DequantUtil::DecodeHuffmanCode(src_tensor, dst_tensor);
+    ret = WeightDecoder::DecodeHuffmanCode(src_tensor, dst_tensor);
     if (ret != RET_OK && ret != RET_NO_CHANGE) {
       MS_LOG(ERROR) << "Decode huffman code failed: " << ret;
       return ret;
     }
   } else if (need_bit_unpack) {
-    ret = DequantUtil::UnPackToInt(src_tensor, dst_tensor);
+    ret = WeightDecoder::UnPackToInt(src_tensor, dst_tensor);
     if (ret != RET_OK && ret != RET_NO_CHANGE) {
       MS_LOG(ERROR) << "Unpack to int8 failed: " << ret;
       return ret;
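
Reviewer note: both renamed calls keep the convention visible in the context lines: RET_NO_CHANGE means "the decoder inspected the tensor and chose to leave it alone" and is not an error. A self-contained sketch of that convention (status values are stand-ins, the decoder is a stub):

    #include <cstdio>

    enum Status { RET_OK = 0, RET_ERROR = -1, RET_NO_CHANGE = -100 };  // stand-in values

    Status DecodeStub(bool huffman_coded) {          // stub decoder
      return huffman_coded ? RET_OK : RET_NO_CHANGE;
    }

    Status DecompressIfNeeded(bool huffman_coded) {
      Status ret = DecodeStub(huffman_coded);
      if (ret != RET_OK && ret != RET_NO_CHANGE) {
        std::printf("decode failed: %d\n", ret);     // only genuine failures propagate
        return ret;
      }
      return RET_OK;                                 // decoded, or nothing to do
    }

    int main() { return DecompressIfNeeded(false); }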


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc (+2 -2)

@@ -91,8 +91,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
     if (origin_bias_data_type_ == kNumberTypeFloat16) {
       memcpy(bias_data_, origin_bias_, output_channel * sizeof(float16_t));
     } else {
-      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_),
-                       output_channel);
+      MS_LOG(ERROR) << "Conv1x1 only support fp16 weight";
+      return RET_ERROR;
     }
     memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size);
   }


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc (+2 -12)

@@ -41,18 +41,8 @@ int ConvolutionBaseFP16CPUKernel::GetExecuteTensor() {
 int ConvolutionBaseFP16CPUKernel::GetExecuteFilter(lite::Tensor *weight_tensor, void *origin_data) {
   MS_ASSERT(origin_weight_data_type_ == kNumberTypeFloat32 || origin_weight_data_type_ == kNumberTypeFloat16);
   if (origin_weight_data_type_ == kNumberTypeFloat32) {
-    float *origin_weight = reinterpret_cast<float *>(origin_data);
-    size_t fp16_weight_size = weight_tensor->Channel() * weight_tensor->Batch() * weight_tensor->Height() *
-                              weight_tensor->Width() * sizeof(float16_t);
-    fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
-    if (fp16_weight_ == nullptr) {
-      MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
-      return RET_ERROR;
-    }
-    for (size_t i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
-      fp16_weight_[i] = (float16_t)origin_weight[i];
-    }
-    execute_weight_ = fp16_weight_;
+    MS_LOG(ERROR) << "Conv fp16 only support fp16 weight";
+    return RET_ERROR;
   } else {
     execute_weight_ = reinterpret_cast<float16_t *>(origin_data);
     fp16_weight_ = nullptr;
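
Reviewer note: the deleted block was one of several per-kernel fp32-to-fp16 weight copies; after this commit the cast happens once, up front, in the scheduler (see CastConstTensorData below), and fp16 kernels reject fp32 data outright. The same rejection pattern repeats in the conv, winograd, deconv, and matmul kernels that follow. For reference, a self-contained sketch of the element-wise cast that was removed (DstT would be float16_t on ENABLE_ARM64 && ENABLE_FP16 builds; plain float keeps the sketch portable):

    #include <cstddef>
    #include <vector>

    template <typename SrcT, typename DstT>
    std::vector<DstT> CastWeights(const SrcT *src, size_t element_num) {
      std::vector<DstT> dst(element_num);
      for (size_t i = 0; i < element_num; ++i) {
        dst[i] = static_cast<DstT>(src[i]);  // same element-wise cast as the removed loop
      }
      return dst;
    }

    int main() {
      const double w[] = {0.5, 1.25, -3.0};
      auto halves = CastWeights<double, float>(w, 3);  // double->float stands in for fp32->fp16
      return halves.size() == 3 ? 0 : 1;
    }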


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc (+4 -0)

@@ -247,6 +247,10 @@ kernel::LiteKernel *CreateDelegateConvFp16(const std::vector<lite::Tensor *> &in
                                            const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                            const InnerContext *ctx) {
   auto weight_data_type = inputs.at(1)->data_type();
+  if (weight_data_type != kNumberTypeFloat16) {
+    MS_LOG(ERROR) << "Convfp16 only support fp16 weight";
+    return nullptr;
+  }
   TypeId bias_data_type = kTypeUnknown;
   if (inputs.size() == 3) {
     bias_data_type = inputs.at(2)->data_type();
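
Reviewer note: with the in-kernel conversions gone, this guard is the fail-fast point. By kernel-creation time the scheduler has already cast const weights to fp16 (CastConstTensorData in scheduler.cc below), so an fp32 weight here signals a real scheduling bug rather than a convert-me case. A trivial self-contained sketch of the guard's shape (type ids are stand-ins):

    #include <cstdio>

    enum TypeId { kNumberTypeFloat16 = 1, kNumberTypeFloat32 = 2 };  // stand-in ids

    const char *CreateConvFp16(TypeId weight_type) {
      if (weight_type != kNumberTypeFloat16) {
        std::fprintf(stderr, "Convfp16 only support fp16 weight\n");
        return nullptr;  // no kernel; the caller falls back or reports the error
      }
      return "conv_fp16_kernel";
    }

    int main() { return CreateConvFp16(kNumberTypeFloat32) == nullptr ? 0 : 1; }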


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc (+2 -1)

@@ -59,7 +59,8 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
     if (origin_bias_data_type_ == kNumberTypeFloat16) {
       memcpy(bias_data_, origin_bias_, out_channel * sizeof(float16_t));
     } else {
-      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_), out_channel);
+      MS_LOG(ERROR) << "Conv fp16 only support fp16 bias";
+      return RET_ERROR;
     }
   } else {
     MS_ASSERT(in_tensors_.size() == kInputSize1);


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc (+2 -1)

@@ -96,7 +96,8 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
     if (origin_bias_data_type_ == kNumberTypeFloat16) {
       memcpy(bias_data_, origin_bias_, out_channel * sizeof(float16_t));
     } else {
-      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_), out_channel);
+      MS_LOG(ERROR) << "Conv winograd fp16 only support fp16 bias";
+      return RET_ERROR;
     }
   } else {
     MS_ASSERT(in_tensors_.size() == kInputSize1);


mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc (+1 -1)

@@ -67,7 +67,7 @@ int DeConvolutionFp16CPUKernel::InitWeightBias() {
   if (in_tensors_.size() == 3 && in_tensors_.at(kBiasIndex)->shape().size() == 1 &&
       in_tensors_.at(kBiasIndex)->DimensionSize(0) == output_channel) {
     if (in_tensors_.at(2)->data_type() != kNumberTypeFloat16) {
-      MS_LOG(ERROR) << "deconv fp16 kernel require fp16 bias";
+      MS_LOG(ERROR) << "DeConv fp16 only support fp16 weight";
       return RET_ERROR;
     }
     if (bias_size != in_tensors_.at(2)->Size()) {


mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.cc (+2 -1)

@@ -78,7 +78,8 @@ int MatmulBaseFP16CPUKernel::InitBias() {
   }
   memset(bias_ptr_, 0, max_bias_data * sizeof(float16_t));
   if (in_tensors_[2]->data_type() == kNumberTypeFloat32) {
-    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[2]->data_c()), bias_ptr_, bias_tensor->ElementsNum());
+    MS_LOG(ERROR) << "Matmul fp16 only support fp16 weight";
+    return RET_ERROR;
   } else if (in_tensors_[2]->data_type() == kNumberTypeFloat16) {
     memcpy(bias_ptr_, in_tensors_[2]->data_c(), max_bias_data * sizeof(float16_t));
   } else {


mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc (+1 -1)

@@ -16,7 +16,7 @@

 #include <mindspore/lite/src/runtime/infer_manager.h>
 #include "src/runtime/kernel/opencl/opencl_kernel.h"
-#include "mindspore/lite/src/dequant.h"
+#include "mindspore/lite/src/weight_decoder.h"

 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;


mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h (+1 -1)

@@ -25,7 +25,7 @@
 #include "src/lite_kernel.h"
 #include "include/errorcode.h"
 #include "src/runtime/gpu/opencl/opencl_runtime.h"
-#include "mindspore/lite/src/dequant.h"
+#include "mindspore/lite/src/weight_decoder.h"
 #include "src/runtime/kernel/opencl/utils.h"
 #include "nnacl/resize_parameter.h"




mindspore/lite/src/scheduler.cc (+43 -51)

@@ -31,8 +31,7 @@
 #include "src/common/prim_util.h"
 #include "src/runtime/infer_manager.h"
 #include "src/sub_graph_split.h"
-#include "src/dequant.h"
-#include "nnacl/matmul_parameter.h"
+#include "src/weight_decoder.h"
 #if GPU_OPENCL
 #include "src/runtime/kernel/opencl/opencl_subgraph.h"
 #include "src/runtime/gpu/opencl/opencl_runtime.h"
@@ -216,7 +215,7 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter

 namespace {
 #ifndef SUPPORT_TRAIN
-int CopyConstTensor(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
+int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
   MS_ASSERT(restored_origin_tensors != nullptr);
   MS_ASSERT(tensor != nullptr);
   if (dst_data_type != kNumberTypeFloat32 && dst_data_type != kNumberTypeFloat16) {
@@ -248,6 +247,26 @@ int CopyConstTensor(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origi
 #else
     MS_LOG(ERROR) << "Unsupported dst data type: float16";
     return RET_ERROR;
+#endif
+  } else if (tensor->data_type() == kNumberTypeFloat16 && dst_data_type == kNumberTypeFloat32) {
+#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
+    auto restore_tensor = Tensor::CopyTensor(*tensor, false);
+    restore_tensor->set_data(origin_data);
+    restore_tensor->set_own_data(tensor->own_data());
+    tensor->set_data(nullptr);
+    tensor->set_data_type(kNumberTypeFloat32);
+    auto ret = tensor->MallocData();
+    if (RET_OK != ret) {
+      MS_LOG(ERROR) << "malloc data failed";
+      return ret;
+    }
+    auto new_tensor_data = tensor->data_c();
+    MS_ASSERT(new_tensor_data != nullptr);
+    Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
+    (*restored_origin_tensors)[tensor] = restore_tensor;
+#else
+    MS_LOG(ERROR) << "Unsupported dst data type: float16";
+    return RET_ERROR;
 #endif
   } else {
     if (tensor->own_data()) {
@@ -290,19 +309,6 @@ inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tens
   }
   FreeRestoreTensors(restored_origin_tensors);
 }
-
-inline bool IsChannelFirst(int index, OpParameter *op_parameter) {
-  MS_ASSERT(op_parameter != nullptr);
-  if (op_parameter->type_ == schema::PrimitiveType_MatMul) {
-    const auto *param = reinterpret_cast<MatMulParameter *>(op_parameter);
-    if (index == 0) {
-      return !(param->a_transpose_);
-    } else if (index == 1) {
-      return param->b_transpose_;
-    }
-  }
-  return true;
-}
 }  // namespace

 kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors,
@@ -321,23 +327,21 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
     }
     cpu_desc.data_type = kNumberTypeFloat16;
   }
+  auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type);
+  if (ret != RET_OK) {
+    MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
+    return nullptr;
+  }
   std::map<Tensor *, Tensor *> restored_origin_tensors;
-  int index = 0;
-  for (auto &tensor : in_tensors) {
-    auto channel_first = IsChannelFirst(index++, op_parameter);
-    auto *restore_tensor = DequantUtil::DequantTensor(tensor, cpu_desc.data_type, channel_first, kernel_data_type);
-    if (restore_tensor != nullptr) {
-      restored_origin_tensors[tensor] = restore_tensor;
-    } else {
 #ifndef SUPPORT_TRAIN
-      auto ret = CopyConstTensor(tensor, &restored_origin_tensors, kernel_data_type);
-      if (ret != RET_OK) {
-        MS_LOG(DEBUG) << "CopyConstTensor failed: " << ret;
-        return nullptr;
-      }
-#endif
+  for (auto &tensor : in_tensors) {
+    ret = CastConstTensorData(tensor, &restored_origin_tensors, kernel_data_type);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "CastConstTensorData failed: " << ret;
+      return nullptr;
     }
   }
+#endif
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
   if (kernel != nullptr) {
     MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);
@@ -362,24 +366,18 @@ kernel::LiteKernel *Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_ten
     gpu_desc.data_type = kNumberTypeInt8;
   }

-  // weight quant
-  std::map<Tensor *, Tensor *> restored_origin_tensors;
-  for (auto &tensor : in_tensors) {
-    int index = 0;
-    auto channel_first = IsChannelFirst(index++, op_parameter);
-    auto *restore_tensor = DequantUtil::DequantTensor(tensor, desc.data_type, channel_first, kNumberTypeFloat32);
-    if (restore_tensor != nullptr) {
-      restored_origin_tensors[tensor] = restore_tensor;
-    }
+  // weight dequant
+  auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
+  if (ret != RET_OK) {
+    MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
+    return nullptr;
   }

   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, gpu_desc, op_parameter);
   if (kernel != nullptr) {
     MS_LOG(DEBUG) << "Get gpu op success: " << PrimitiveCurVersionTypeName(gpu_desc.type);
-    FreeRestoreTensors(&restored_origin_tensors);
   } else {
     MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(gpu_desc.type);
-    RestoreTensorData(&restored_origin_tensors);
   }
   return kernel;
 } else {
@@ -396,26 +394,20 @@ kernel::LiteKernel *Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_ten
   if (npu_desc.data_type == kNumberTypeFloat16) {
     npu_desc.data_type = kNumberTypeFloat32;
   }
+  auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
+  if (ret != RET_OK) {
+    MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
+    return nullptr;
+  }
   for (auto tensor : in_tensors) {
     if (tensor->data_type() == kNumberTypeFloat16) {
       tensor->set_data_type(kNumberTypeFloat32);
     }
   }
-  std::map<Tensor *, Tensor *> restored_origin_tensors;
-  for (auto &tensor : in_tensors) {
-    int index = 0;
-    auto channel_first = IsChannelFirst(index++, op_parameter);
-    auto *restore_tensor = DequantUtil::DequantTensor(tensor, desc.data_type, channel_first, kNumberTypeFloat32);
-    if (restore_tensor != nullptr) {
-      restored_origin_tensors[tensor] = restore_tensor;
-    }
-  }
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter);
   if (kernel != nullptr) {
-    FreeRestoreTensors(&restored_origin_tensors);
     MS_LOG(DEBUG) << "Get npu op success: " << PrimitiveCurVersionTypeName(npu_desc.type);
   } else {
-    RestoreTensorData(&restored_origin_tensors);
     MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(npu_desc.type);
   }
   return kernel;
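
Reviewer note: the new branch in CastConstTensorData handles the reverse direction: fp16-stored weights scheduled onto an fp32 kernel are widened via Float16ToFloat32_fp16_handler (hardware fp16, ARM64-only builds), with the original buffer parked in restored_origin_tensors so a failed kernel match can be rolled back. For readers without an ARM64 toolchain, a portable, self-contained software equivalent of the binary16-to-binary32 widening:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Software decode of one IEEE-754 binary16 value to float; a portable
    // stand-in for what Float16ToFloat32_fp16_handler does with hardware fp16.
    float HalfToFloat(uint16_t h) {
      uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
      uint32_t exp = (h >> 10) & 0x1Fu;
      uint32_t frac = h & 0x3FFu;
      uint32_t bits;
      if (exp == 0x1Fu) {                    // infinity / NaN
        bits = sign | 0x7F800000u | (frac << 13);
      } else if (exp != 0) {                 // normal number: rebias the exponent
        bits = sign | ((exp + 127 - 15) << 23) | (frac << 13);
      } else if (frac != 0) {                // subnormal: renormalize
        exp = 127 - 15 + 1;
        while ((frac & 0x400u) == 0) { frac <<= 1; --exp; }
        bits = sign | (exp << 23) | ((frac & 0x3FFu) << 13);
      } else {                               // signed zero
        bits = sign;
      }
      float out;
      std::memcpy(&out, &bits, sizeof(out));
      return out;
    }

    int main() {
      std::printf("%g %g\n", HalfToFloat(0x3C00), HalfToFloat(0xC000));  // 1 -2
      return 0;
    }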


mindspore/lite/src/sub_graph_kernel.h (+13 -0)

@@ -178,6 +178,19 @@ class CpuFp16SubGraph : public CpuSubGraph {
   int PreProcess() override;
   int Run() override { return CpuSubGraph::Run(); }
   int Run(const KernelCallBack &before, const KernelCallBack &after) override {
+#ifdef Debug
+    for (const auto *node : nodes_) {
+      if (node->Type() == schema::PrimitiveType_PartialFusion) {
+        continue;
+      }
+      for (const auto *in_tensor : node->in_tensors()) {
+        if (in_tensor->data_type() == kNumberTypeFloat32) {
+          MS_LOG(ERROR) << "FP16 kernel can not accept float32 input";
+          return lite::RET_ERROR;
+        }
+      }
+    }
+#endif
     return CpuSubGraph::Run(before, after);
   };
   int PostProcess() override;
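
Reviewer note: this walk is compiled only when the Debug macro is defined, so release runs pay nothing; presumably it exists because, with the automatic fp32-to-fp16 weight conversion removed, a float32 tensor reaching an fp16 subgraph is now a hard wiring error worth catching early. A self-contained analogue of the scan:

    #include <vector>

    struct Node { std::vector<int> input_dtypes; };  // int stands in for TypeId

    // Fails as soon as any node input carries the forbidden dtype, mirroring
    // the per-node, per-input scan in CpuFp16SubGraph::Run above.
    bool InputsFreeOf(const std::vector<Node> &nodes, int forbidden_dtype) {
      for (const auto &node : nodes) {
        for (int dtype : node.input_dtypes) {
          if (dtype == forbidden_dtype) return false;
        }
      }
      return true;
    }

    int main() {
      std::vector<Node> nodes = {{{1, 1}}, {{1, 2}}};
      return InputsFreeOf(nodes, 2) ? 1 : 0;  // second node trips the check
    }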


mindspore/lite/src/dequant.cc → mindspore/lite/src/weight_decoder.cc (+30 -18)

@@ -16,12 +16,11 @@
 #include <cmath>
 #include <string>
 #include <memory>
-#include "src/dequant.h"
+#include "src/weight_decoder.h"
 #include "src/huffman_decode.h"
-#include "nnacl/matmul_parameter.h"

 namespace mindspore::lite {
-int DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type) {
+int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type) {
   MS_ASSERT(input_tensor != nullptr);
   if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
     MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
@@ -69,7 +68,7 @@ int DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first, T
   return RET_OK;
 }

-int DequantUtil::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
+int WeightDecoder::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
   MS_ASSERT(dst_tensor != nullptr);
   if (!dst_tensor->IsConst() || !src_tensor.enableHuffmanCode()) {
     return RET_NO_CHANGE;
@@ -93,7 +92,7 @@ int DequantUtil::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tenso
   return RET_OK;
 }

-int DequantUtil::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
+int WeightDecoder::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
   MS_ASSERT(dst_tensor != nullptr);
   auto quant_params = src_tensor.quantParams();
   if (quant_params == nullptr || quant_params->size() == 0) {
@@ -127,26 +126,39 @@ int DequantUtil::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst
   }
 }

-Tensor *DequantUtil::DequantTensor(Tensor *tensor, TypeId data_type, bool channel_first, TypeId dst_data_type) {
+int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
+                               TypeId dst_data_type) {
+  if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) {
+    return RET_OK;
+  }
+  int index = 0;
+  for (auto &tensor : in_tensors) {
+    auto channel_first = IsChannelFirst(index++, op_parameter);
+    auto ret = WeightDecoder::DequantTensor(tensor, channel_first, dst_data_type);
+    if (ret != RET_OK && ret != RET_NO_CHANGE) {
+      MS_LOG(DEBUG) << "Dequant tensor failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+int WeightDecoder::DequantTensor(Tensor *tensor, bool channel_first, TypeId dst_data_type) {
   MS_ASSERT(tensor != nullptr);
-  Tensor *restore_tensor = nullptr;
-  if (!tensor->IsConst() || !(data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16)) {
-    return nullptr;
+  if (!tensor->IsConst() ||
+      !(dst_data_type == TypeId::kNumberTypeFloat32 || dst_data_type == TypeId::kNumberTypeFloat16)) {
+    return RET_NO_CHANGE;
   }
-  auto restore_type = tensor->data_type();
   bool need_dequant = !tensor->quant_params().empty() && tensor->quant_params().front().inited &&
-                      (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
+                      (tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16);
   if (!need_dequant) {
-    return nullptr;
+    return RET_NO_CHANGE;
   }
-  restore_tensor = Tensor::CopyTensor(*tensor, false);
-  restore_tensor->set_data(tensor->data_c());
-  restore_tensor->set_own_data(tensor->own_data());
-  auto ret = DequantUtil::DequantWeight(tensor, channel_first, dst_data_type);
+  auto ret = WeightDecoder::DequantWeight(tensor, channel_first, dst_data_type);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Dequant data failed: " << ret;
-    return nullptr;
+    return ret;
   }
-  return restore_tensor;
+  return RET_OK;
 }
 } // namespace mindspore::lite
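
Reviewer note: DequantNode is now the single entry point the scheduler calls per node; it no-ops unless the node is weight-quantized and otherwise walks the inputs itself, which is what let the three Find*Kernel loops above collapse. The underlying DequantWeight math is unchanged by the rename; for reference, a self-contained sketch of the usual per-channel affine dequantization it implements (layout assumption: quantized values grouped channel-major):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct QuantArg { float scale; int zero_point; };

    std::vector<float> DequantPerChannel(const std::vector<int8_t> &q,
                                         const std::vector<QuantArg> &args,
                                         size_t per_channel_size) {
      std::vector<float> out(q.size());
      for (size_t i = 0; i < q.size(); ++i) {
        const QuantArg &a = args[i / per_channel_size];  // one {scale, zero point} per channel
        out[i] = a.scale * (static_cast<int>(q[i]) - a.zero_point);
      }
      return out;
    }

    int main() {
      std::vector<int8_t> q = {-128, 0, 127};
      std::vector<QuantArg> args = {{0.5f, 0}};
      auto out = DequantPerChannel(q, args, q.size());
      std::printf("%g %g %g\n", out[0], out[1], out[2]);  // -64 0 63.5
      return 0;
    }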

mindspore/lite/src/dequant.h → mindspore/lite/src/weight_decoder.h (+16 -16)

@@ -22,19 +22,22 @@
 #include <vector>
 #include <queue>
 #include <cmath>
+#include "nnacl/matmul_parameter.h"
 #include "src/lite_kernel.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"

 namespace mindspore::lite {
-class DequantUtil {
+class WeightDecoder {
  public:
   static int UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor);

   static int DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor);

-  static Tensor *DequantTensor(Tensor *tensor, TypeId data_type, bool channel_first = true,
-                               TypeId dst_data_type = kNumberTypeFloat32);
+  static int DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type);
+
+ private:
+  static int DequantTensor(Tensor *tensor, bool channel_first = true, TypeId dst_data_type = kNumberTypeFloat32);

   template <typename ST, typename DT = float>
   static DT *DequantData(lite::Tensor *input_tensor, bool channel_first = true) {
@@ -102,22 +105,19 @@ class DequantUtil {
     return dequant_datas;
   }

-  template <typename T1, typename T2>
-  static void UnpackUtil(const T1 *weight_data, int pack_size, int origin_bit, void *unpack_int_data) {
-    if (weight_data == nullptr || unpack_int_data == nullptr) {
-      MS_LOG(ERROR) << "data is nullptr";
-      return;
-    }
-    std::queue<bool> unpack_bit_data;
-    size_t count = 0;
-    for (int i = 0; i < pack_size; ++i) {
-      T2 pack_data = (static_cast<const T2 *>(static_cast<const void *>(weight_data)))[i];
-      bool is_last = i == pack_size - 1;
-      UnPackData<T1, T2>(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last);
+  inline static bool IsChannelFirst(int index, const OpParameter *op_parameter) {
+    MS_ASSERT(op_parameter != nullptr);
+    if (op_parameter->type_ == schema::PrimitiveType_MatMul) {
+      const auto *param = reinterpret_cast<const MatMulParameter *>(op_parameter);
+      if (index == 0) {
+        return !(param->a_transpose_);
+      } else if (index == 1) {
+        return param->b_transpose_;
+      }
     }
+    return true;
   }

- private:
   static int DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type = kNumberTypeFloat32);

   template <typename T1, typename T2>
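
Reviewer note: IsChannelFirst moved here from scheduler.cc so the now-private DequantTensor path can resolve per-channel orientation itself. The MatMul special case (input A is channel-first unless a_transpose_, input B only when b_transpose_) presumably tracks which dimension the per-channel quant params run along. A self-contained restatement:

    #include <cstdio>

    struct MatMulParams { bool a_transpose; bool b_transpose; };  // mirrors MatMulParameter

    bool IsChannelFirstSketch(int input_index, const MatMulParams &p) {
      if (input_index == 0) return !p.a_transpose;  // input A
      if (input_index == 1) return p.b_transpose;   // input B
      return true;                                  // bias and non-MatMul inputs
    }

    int main() {
      MatMulParams p{/*a_transpose=*/false, /*b_transpose=*/true};
      std::printf("A: %d, B: %d\n", IsChannelFirstSketch(0, p), IsChannelFirstSketch(1, p));  // A: 1, B: 1
      return 0;
    }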

mindspore/lite/test/CMakeLists.txt (+1 -1)

@@ -143,7 +143,7 @@ set(TEST_LITE_SRC
     ${LITE_DIR}/src/kernel_registry.cc
     ${LITE_DIR}/src/lite_kernel.cc
     ${LITE_DIR}/src/lite_session.cc
-    ${LITE_DIR}/src/dequant.cc
+    ${LITE_DIR}/src/weight_decoder.cc
     ${LITE_DIR}/src/huffman_decode.cc
     ${LITE_DIR}/src/sub_graph_kernel.cc
     ${LITE_DIR}/src/sub_graph_split.cc


mindspore/lite/tools/converter/CMakeLists.txt (+1 -1)

@@ -115,7 +115,7 @@ set(LITE_SRC
     ${SRC_DIR}/executor.cc
     ${SRC_DIR}/lite_model.cc
     ${SRC_DIR}/errorcode.cc
-    ${SRC_DIR}/dequant.cc
+    ${SRC_DIR}/weight_decoder.cc
     ${SRC_DIR}/huffman_decode.cc
     ${SRC_DIR}/ops/ops_utils.cc
     ${SRC_DIR}/ops/ops_def.cc


mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc (+1 -0)

@@ -339,6 +339,7 @@ STATUS FormatTransPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uniqu
       MS_LOG(ERROR) << "Crop error";
       return RET_ERROR;
     }
+    node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
     node->primitive->value.AsCrop()->offsets = offsets;
   }
   if (type == schema::PrimitiveType_SliceFusion || type == schema::PrimitiveType_StridedSlice) {
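
Reviewer note: the one added line is a bug fix. When ChangeOpAxis rewrites a Crop node for a layout change it previously remapped only the offsets, leaving axis pointing at the pre-transpose dimension. A self-contained sketch, assuming an NCHW to NHWC transpose (the actual axis_map is built elsewhere in this pass):

    #include <cstdio>

    int main() {
      // Position of each NCHW dimension after an NCHW -> NHWC transpose (assumption).
      const int axis_map[4] = {0, 3, 1, 2};  // N->0, C->3, H->1, W->2
      int origin_axis = 1;                   // cropping along C in NCHW
      int new_axis = axis_map[origin_axis];  // the fix: remap axis alongside offsets
      std::printf("crop axis %d -> %d\n", origin_axis, new_axis);
      return 0;
    }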


mindspore/lite/tools/converter/quantizer/huffman_encode.cc (+1 -1)

@@ -15,7 +15,7 @@
 */

 #include "tools/converter/quantizer/huffman_encode.h"
-#include "src/dequant.h"
+#include "src/weight_decoder.h"
 #include "tools/converter/quantizer/quantize_util.h"

 namespace mindspore {

