diff --git a/mindspore/lite/micro/cmake/file_list.cmake b/mindspore/lite/micro/cmake/file_list.cmake
index 8777af0662..eda4d8c7e3 100644
--- a/mindspore/lite/micro/cmake/file_list.cmake
+++ b/mindspore/lite/micro/cmake/file_list.cmake
@@ -135,7 +135,7 @@ set(LITE_SRC
     ${LITE_DIR}/src/sub_graph_split.cc
     ${LITE_DIR}/src/tensorlist.cc
     ${LITE_DIR}/src/tensor.cc
-    ${LITE_DIR}/src/dequant.cc
+    ${LITE_DIR}/src/weight_decoder.cc
     ${LITE_DIR}/src/huffman_decode.cc
     ${LITE_DIR}/src/common/log_adapter.cc
     ${LITE_DIR}/src/common/utils.cc
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 2af4e527cd..2f019de155 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -841,7 +841,7 @@ table Rsqrt {
 }
 
 table QuantDTypeCast {
-    src_t: long;
+    src_t: long; // deprecated
     dst_t: long;
 }
diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index d363f8f2d5..e5ba183a17 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -61,7 +61,7 @@ set(LITE_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
-    ${CMAKE_CURRENT_SOURCE_DIR}/dequant.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/weight_decoder.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/huffman_decode.cc
     )
 if(DEFINED ARCHS)
diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index 7001a3f331..ad979bcb08 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -28,7 +28,7 @@
 #include "src/common/graph_util.h"
 #include "src/kernel_registry.h"
 #include "src/lite_model.h"
-#include "src/dequant.h"
+#include "src/weight_decoder.h"
 #ifdef ENABLE_MINDRT
 #include "src/mindrt_executor.h"
 #endif
@@ -57,13 +57,13 @@ int DecompressTensor(const schema::Tensor &src_tensor, Tensor *dst_tensor) {
   // huffman code and bit pack are not assumed to be performed at same time
   STATUS ret = RET_ERROR;
   if (src_tensor.enableHuffmanCode()) {
-    ret = DequantUtil::DecodeHuffmanCode(src_tensor, dst_tensor);
+    ret = WeightDecoder::DecodeHuffmanCode(src_tensor, dst_tensor);
     if (ret != RET_OK && ret != RET_NO_CHANGE) {
       MS_LOG(ERROR) << "Decode huffman code failed: " << ret;
       return ret;
     }
   } else if (need_bit_unpack) {
-    ret = DequantUtil::UnPackToInt(src_tensor, dst_tensor);
+    ret = WeightDecoder::UnPackToInt(src_tensor, dst_tensor);
     if (ret != RET_OK && ret != RET_NO_CHANGE) {
       MS_LOG(ERROR) << "Unpack to int8 failed: " << ret;
       return ret;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc
index b9152a4804..7e2e3356d8 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc
@@ -91,8 +91,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
     if (origin_bias_data_type_ == kNumberTypeFloat16) {
       memcpy(bias_data_, origin_bias_, output_channel * sizeof(float16_t));
     } else {
-      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_),
-                       output_channel);
+      MS_LOG(ERROR) << "Conv1x1 only support fp16 bias";
+      return RET_ERROR;
     }
     memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size);
   }
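
The fp16 kernels in this patch stop converting fp32 constants at init time and fail fast instead, so the cast has to happen before the runtime ever sees the tensor (e.g. in the converter, which now emits fp16 weights directly). A minimal sketch of that cast, modeled on the loop deleted from GetExecuteFilter() in convolution_base_fp16.cc just below; it assumes an ARM toolchain where __fp16 exists, and CastFp32ToFp16 is an illustrative name, not a MindSpore API:

```cpp
#include <cstddef>
#include <cstdlib>

using float16_t = __fp16;  // ARM-specific half type, as used by the fp16 kernels

// Mirrors the removed per-element conversion; the caller owns the result.
static float16_t *CastFp32ToFp16(const float *src, size_t count) {
  auto *dst = static_cast<float16_t *>(malloc(count * sizeof(float16_t)));
  if (dst == nullptr) {
    return nullptr;  // the deleted kernel code logged an error and bailed out here
  }
  for (size_t i = 0; i < count; ++i) {
    dst[i] = static_cast<float16_t>(src[i]);
  }
  return dst;
}
```
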
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc
index b1402a90ed..bdaf6cd6f0 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc
@@ -41,18 +41,8 @@ int ConvolutionBaseFP16CPUKernel::GetExecuteTensor() {
 int ConvolutionBaseFP16CPUKernel::GetExecuteFilter(lite::Tensor *weight_tensor, void *origin_data) {
   MS_ASSERT(origin_weight_data_type_ == kNumberTypeFloat32 || origin_weight_data_type_ == kNumberTypeFloat16);
   if (origin_weight_data_type_ == kNumberTypeFloat32) {
-    float *origin_weight = reinterpret_cast<float *>(origin_data);
-    size_t fp16_weight_size = weight_tensor->Channel() * weight_tensor->Batch() * weight_tensor->Height() *
-                              weight_tensor->Width() * sizeof(float16_t);
-    fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
-    if (fp16_weight_ == nullptr) {
-      MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
-      return RET_ERROR;
-    }
-    for (size_t i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
-      fp16_weight_[i] = (float16_t)origin_weight[i];
-    }
-    execute_weight_ = fp16_weight_;
+    MS_LOG(ERROR) << "Conv fp16 only support fp16 weight";
+    return RET_ERROR;
   } else {
     execute_weight_ = reinterpret_cast<float16_t *>(origin_data);
     fp16_weight_ = nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc
index b1b880e9fd..13c4862a76 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.cc
@@ -247,6 +247,10 @@ kernel::LiteKernel *CreateDelegateConvFp16(const std::vector<lite::Tensor *> &in
                                            const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                            const InnerContext *ctx) {
   auto weight_data_type = inputs.at(1)->data_type();
+  if (weight_data_type != kNumberTypeFloat16) {
+    MS_LOG(ERROR) << "Convfp16 only support fp16 weight";
+    return nullptr;
+  }
   TypeId bias_data_type = kTypeUnknown;
   if (inputs.size() == 3) {
     bias_data_type = inputs.at(2)->data_type();
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
index 67f4cb3d66..44563fb30d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
@@ -59,7 +59,8 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
     if (origin_bias_data_type_ == kNumberTypeFloat16) {
       memcpy(bias_data_, origin_bias_, out_channel * sizeof(float16_t));
     } else {
-      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_), out_channel);
+      MS_LOG(ERROR) << "Conv fp16 only support fp16 bias";
+      return RET_ERROR;
     }
   } else {
     MS_ASSERT(in_tensors_.size() == kInputSize1);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc
index e655e35234..46b2c7bfec 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc
@@ -96,7 +96,8 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
     if (origin_bias_data_type_ == kNumberTypeFloat16) {
       memcpy(bias_data_, origin_bias_, out_channel * sizeof(float16_t));
     } else {
-      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_), out_channel);
+      MS_LOG(ERROR) << "Conv winograd fp16 only support fp16 bias";
+      return RET_ERROR;
     }
   } else {
     MS_ASSERT(in_tensors_.size() == kInputSize1);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
index 0a750215c9..353f172bed 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
@@ -67,7 +67,7 @@ int DeConvolutionFp16CPUKernel::InitWeightBias() {
   if (in_tensors_.size() == 3 && in_tensors_.at(kBiasIndex)->shape().size() == 1 &&
       in_tensors_.at(kBiasIndex)->DimensionSize(0) == output_channel) {
     if (in_tensors_.at(2)->data_type() != kNumberTypeFloat16) {
-      MS_LOG(ERROR) << "deconv fp16 kernel require fp16 bias";
+      MS_LOG(ERROR) << "DeConv fp16 only support fp16 bias";
       return RET_ERROR;
     }
     if (bias_size != in_tensors_.at(2)->Size()) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.cc
index beaaf0d22d..a395b54b73 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_base_fp16.cc
@@ -78,7 +78,8 @@ int MatmulBaseFP16CPUKernel::InitBias() {
   }
   memset(bias_ptr_, 0, max_bias_data * sizeof(float16_t));
   if (in_tensors_[2]->data_type() == kNumberTypeFloat32) {
-    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[2]->data_c()), bias_ptr_, bias_tensor->ElementsNum());
+    MS_LOG(ERROR) << "Matmul fp16 only support fp16 bias";
+    return RET_ERROR;
   } else if (in_tensors_[2]->data_type() == kNumberTypeFloat16) {
     memcpy(bias_ptr_, in_tensors_[2]->data_c(), max_bias_data * sizeof(float16_t));
   } else {
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
index 601afb9797..23b5ab3bf4 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
@@ -16,7 +16,7 @@
 #include
 #include "src/runtime/kernel/opencl/opencl_kernel.h"
-#include "mindspore/lite/src/dequant.h"
+#include "mindspore/lite/src/weight_decoder.h"
 
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
index 9d245025bb..0594b967a8 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
@@ -25,7 +25,7 @@
 #include "src/lite_kernel.h"
 #include "include/errorcode.h"
 #include "src/runtime/gpu/opencl/opencl_runtime.h"
-#include "mindspore/lite/src/dequant.h"
+#include "mindspore/lite/src/weight_decoder.h"
 #include "src/runtime/kernel/opencl/utils.h"
 #include "nnacl/resize_parameter.h"
diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc
index 2a8e735a07..787ff03f4c 100644
--- a/mindspore/lite/src/scheduler.cc
+++ b/mindspore/lite/src/scheduler.cc
@@ -31,8 +31,7 @@
 #include "src/common/prim_util.h"
 #include "src/runtime/infer_manager.h"
 #include "src/sub_graph_split.h"
-#include "src/dequant.h"
-#include "nnacl/matmul_parameter.h"
+#include "src/weight_decoder.h"
 #if GPU_OPENCL
 #include "src/runtime/kernel/opencl/opencl_subgraph.h"
 #include "src/runtime/gpu/opencl/opencl_runtime.h"
@@ -216,7 +215,7 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter
 
 namespace {
 #ifndef SUPPORT_TRAIN
-int CopyConstTensor(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
+int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
   MS_ASSERT(restored_origin_tensors != nullptr);
   MS_ASSERT(tensor != nullptr);
   if (dst_data_type != kNumberTypeFloat32 && dst_data_type != kNumberTypeFloat16) {
@@ -248,6 +247,26 @@ int CopyConstTensor(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origi
 #else
     MS_LOG(ERROR) << "Unsupported dst data type: float16";
     return RET_ERROR;
+#endif
+  } else if (tensor->data_type() == kNumberTypeFloat16 && dst_data_type == kNumberTypeFloat32) {
+#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
+    auto restore_tensor = Tensor::CopyTensor(*tensor, false);
+    restore_tensor->set_data(origin_data);
+    restore_tensor->set_own_data(tensor->own_data());
+    tensor->set_data(nullptr);
+    tensor->set_data_type(kNumberTypeFloat32);
+    auto ret = tensor->MallocData();
+    if (RET_OK != ret) {
+      MS_LOG(ERROR) << "malloc data failed";
+      return ret;
+    }
+    auto new_tensor_data = tensor->data_c();
+    MS_ASSERT(new_tensor_data != nullptr);
+    Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
+    (*restored_origin_tensors)[tensor] = restore_tensor;
+#else
+    MS_LOG(ERROR) << "Unsupported dst data type: float16";
+    return RET_ERROR;
 #endif
   } else {
     if (tensor->own_data()) {
@@ -290,19 +309,6 @@ inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tens
   }
   FreeRestoreTensors(restored_origin_tensors);
 }
-
-inline bool IsChannelFirst(int index, OpParameter *op_parameter) {
-  MS_ASSERT(op_parameter != nullptr);
-  if (op_parameter->type_ == schema::PrimitiveType_MatMul) {
-    const auto *param = reinterpret_cast<const MatMulParameter *>(op_parameter);
-    if (index == 0) {
-      return !(param->a_transpose_);
-    } else if (index == 1) {
-      return param->b_transpose_;
-    }
-  }
-  return true;
-}
 }  // namespace
 
 kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors,
@@ -321,23 +327,21 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
     }
     cpu_desc.data_type = kNumberTypeFloat16;
   }
+  auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type);
+  if (ret != RET_OK) {
+    MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
+    return nullptr;
+  }
   std::map<Tensor *, Tensor *> restored_origin_tensors;
-  int index = 0;
-  for (auto &tensor : in_tensors) {
-    auto channel_first = IsChannelFirst(index++, op_parameter);
-    auto *restore_tensor = DequantUtil::DequantTensor(tensor, cpu_desc.data_type, channel_first, kernel_data_type);
-    if (restore_tensor != nullptr) {
-      restored_origin_tensors[tensor] = restore_tensor;
-    } else {
 #ifndef SUPPORT_TRAIN
-      auto ret = CopyConstTensor(tensor, &restored_origin_tensors, kernel_data_type);
-      if (ret != RET_OK) {
-        MS_LOG(DEBUG) << "CopyConstTensor failed: " << ret;
-        return nullptr;
-      }
-#endif
+  for (auto &tensor : in_tensors) {
+    ret = CastConstTensorData(tensor, &restored_origin_tensors, kernel_data_type);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "CastConstTensorData failed: " << ret;
+      return nullptr;
     }
   }
+#endif
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
   if (kernel != nullptr) {
     MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);
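
With IsChannelFirst removed from this file (it moves into WeightDecoder, further down), FindCpuKernel now handles const inputs in two explicit phases instead of one interleaved loop. A condensed, self-contained sketch of the new order, with stub types standing in for lite::Tensor and the real signatures:

```cpp
#include <map>
#include <vector>

struct Tensor;                            // stand-in for lite::Tensor
constexpr int RET_OK = 0, RET_ERROR = 1;  // stand-ins for include/errorcode.h

// Phase 1: WeightDecoder::DequantNode dequantizes weight-quantized const
// inputs in place, once, before any kernel lookup.
int DequantNodeStub(const std::vector<Tensor *> &) { return RET_OK; }

// Phase 2: CastConstTensorData casts remaining fp32/fp16 const tensors to the
// kernel data type, keeping the originals for restore on lookup failure.
int CastConstTensorDataStub(Tensor *, std::map<Tensor *, Tensor *> *) { return RET_OK; }

int FindCpuKernelFlow(const std::vector<Tensor *> &in_tensors) {
  if (DequantNodeStub(in_tensors) != RET_OK) {
    return RET_ERROR;  // dequant failure now aborts before kernel creation
  }
  std::map<Tensor *, Tensor *> restored_origin_tensors;
  for (auto *tensor : in_tensors) {
    if (CastConstTensorDataStub(tensor, &restored_origin_tensors) != RET_OK) {
      return RET_ERROR;
    }
  }
  return RET_OK;  // KernelRegistry lookup follows here in the real code
}
```

One behavioral consequence in the GPU/NPU paths below: dequantized weights are no longer restored when kernel lookup fails, which appears safe because the CPU fallback re-runs DequantNode and already-float tensors simply come back as RET_NO_CHANGE.
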
@@ -362,24 +366,18 @@ kernel::LiteKernel *Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_ten
       gpu_desc.data_type = kNumberTypeInt8;
     }
-    // weight quant
-    std::map<Tensor *, Tensor *> restored_origin_tensors;
-    for (auto &tensor : in_tensors) {
-      int index = 0;
-      auto channel_first = IsChannelFirst(index++, op_parameter);
-      auto *restore_tensor = DequantUtil::DequantTensor(tensor, desc.data_type, channel_first, kNumberTypeFloat32);
-      if (restore_tensor != nullptr) {
-        restored_origin_tensors[tensor] = restore_tensor;
-      }
+    // weight dequant
+    auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
+      return nullptr;
     }
     auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, gpu_desc, op_parameter);
     if (kernel != nullptr) {
       MS_LOG(DEBUG) << "Get gpu op success: " << PrimitiveCurVersionTypeName(gpu_desc.type);
-      FreeRestoreTensors(&restored_origin_tensors);
     } else {
       MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(gpu_desc.type);
-      RestoreTensorData(&restored_origin_tensors);
     }
     return kernel;
   } else {
@@ -396,26 +394,20 @@ kernel::LiteKernel *Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_ten
   if (npu_desc.data_type == kNumberTypeFloat16) {
     npu_desc.data_type = kNumberTypeFloat32;
   }
+  auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
+  if (ret != RET_OK) {
+    MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
+    return nullptr;
+  }
   for (auto tensor : in_tensors) {
     if (tensor->data_type() == kNumberTypeFloat16) {
       tensor->set_data_type(kNumberTypeFloat32);
     }
   }
-  std::map<Tensor *, Tensor *> restored_origin_tensors;
-  for (auto &tensor : in_tensors) {
-    int index = 0;
-    auto channel_first = IsChannelFirst(index++, op_parameter);
-    auto *restore_tensor = DequantUtil::DequantTensor(tensor, desc.data_type, channel_first, kNumberTypeFloat32);
-    if (restore_tensor != nullptr) {
-      restored_origin_tensors[tensor] = restore_tensor;
-    }
-  }
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter);
   if (kernel != nullptr) {
-    FreeRestoreTensors(&restored_origin_tensors);
     MS_LOG(DEBUG) << "Get npu op success: " << PrimitiveCurVersionTypeName(npu_desc.type);
   } else {
-    RestoreTensorData(&restored_origin_tensors);
     MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(npu_desc.type);
   }
   return kernel;
diff --git a/mindspore/lite/src/sub_graph_kernel.h b/mindspore/lite/src/sub_graph_kernel.h
index 0ae5a2d111..728eb5ad24 100644
--- a/mindspore/lite/src/sub_graph_kernel.h
+++ b/mindspore/lite/src/sub_graph_kernel.h
@@ -178,6 +178,19 @@ class CpuFp16SubGraph : public CpuSubGraph {
   int PreProcess() override;
   int Run() override { return CpuSubGraph::Run(); }
   int Run(const KernelCallBack &before, const KernelCallBack &after) override {
+#ifdef Debug
+    for (const auto *node : nodes_) {
+      if (node->Type() == schema::PrimitiveType_PartialFusion) {
+        continue;
+      }
+      for (const auto *in_tensor : node->in_tensors()) {
+        if (in_tensor->data_type() == kNumberTypeFloat32) {
+          MS_LOG(ERROR) << "FP16 kernel can not accept float32 input";
+          return lite::RET_ERROR;
+        }
+      }
+    }
+#endif
     return CpuSubGraph::Run(before, after);
   };
   int PostProcess() override;
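
The guard added to CpuFp16SubGraph::Run is compiled only when the Debug macro is defined (note the exact capitalization, matching the define used for debug builds); it asserts the invariant the rest of the patch establishes: once the converter emits fp16 weights and the scheduler casts const inputs, no fp32 tensor should reach an fp16 kernel at run time. PartialFusion nodes are exempt, presumably because their inputs are consumed by the callee subgraph rather than by an fp16 kernel directly.
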
diff --git a/mindspore/lite/src/dequant.cc b/mindspore/lite/src/weight_decoder.cc
similarity index 77%
rename from mindspore/lite/src/dequant.cc
rename to mindspore/lite/src/weight_decoder.cc
index 9aee49ab08..cc451a87fc 100644
--- a/mindspore/lite/src/dequant.cc
+++ b/mindspore/lite/src/weight_decoder.cc
@@ -16,12 +16,11 @@
 #include
 #include
 #include
-#include "src/dequant.h"
+#include "src/weight_decoder.h"
 #include "src/huffman_decode.h"
-#include "nnacl/matmul_parameter.h"
 
 namespace mindspore::lite {
-int DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type) {
+int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type) {
   MS_ASSERT(input_tensor != nullptr);
   if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
     MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
@@ -69,7 +68,7 @@ int DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first, T
   return RET_OK;
 }
 
-int DequantUtil::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
+int WeightDecoder::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
   MS_ASSERT(dst_tensor != nullptr);
   if (!dst_tensor->IsConst() || !src_tensor.enableHuffmanCode()) {
     return RET_NO_CHANGE;
@@ -93,7 +92,7 @@ int DequantUtil::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tenso
   return RET_OK;
 }
 
-int DequantUtil::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
+int WeightDecoder::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
   MS_ASSERT(dst_tensor != nullptr);
   auto quant_params = src_tensor.quantParams();
   if (quant_params == nullptr || quant_params->size() == 0) {
@@ -127,26 +126,39 @@ int DequantUtil::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst
   }
 }
 
-Tensor *DequantUtil::DequantTensor(Tensor *tensor, TypeId data_type, bool channel_first, TypeId dst_data_type) {
+int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
+                               TypeId dst_data_type) {
+  if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) {
+    return RET_OK;
+  }
+  int index = 0;
+  for (auto &tensor : in_tensors) {
+    auto channel_first = IsChannelFirst(index++, op_parameter);
+    auto ret = WeightDecoder::DequantTensor(tensor, channel_first, dst_data_type);
+    if (ret != RET_OK && ret != RET_NO_CHANGE) {
+      MS_LOG(DEBUG) << "Dequant tensor failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+int WeightDecoder::DequantTensor(Tensor *tensor, bool channel_first, TypeId dst_data_type) {
   MS_ASSERT(tensor != nullptr);
-  Tensor *restore_tensor = nullptr;
-  if (!tensor->IsConst() || !(data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16)) {
-    return nullptr;
+  if (!tensor->IsConst() ||
+      !(dst_data_type == TypeId::kNumberTypeFloat32 || dst_data_type == TypeId::kNumberTypeFloat16)) {
+    return RET_NO_CHANGE;
   }
-  auto restore_type = tensor->data_type();
   bool need_dequant = !tensor->quant_params().empty() && tensor->quant_params().front().inited &&
-                      (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
+                      (tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16);
   if (!need_dequant) {
-    return nullptr;
+    return RET_NO_CHANGE;
   }
-  restore_tensor = Tensor::CopyTensor(*tensor, false);
-  restore_tensor->set_data(tensor->data_c());
-  restore_tensor->set_own_data(tensor->own_data());
-  auto ret = DequantUtil::DequantWeight(tensor, channel_first, dst_data_type);
+  auto ret = WeightDecoder::DequantWeight(tensor, channel_first, dst_data_type);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Dequant data failed: " << ret;
-    return nullptr;
+    return ret;
   }
-  return restore_tensor;
+  return RET_OK;
 }
 }  // namespace mindspore::lite
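
DequantTensor's contract changes from returning a backup Tensor* (with nullptr meaning "nothing done") to returning a status, where RET_NO_CHANGE is the benign "skipped" outcome for non-const inputs, float inputs, and tensors without inited quant params. A self-contained toy of the contract DequantNode builds on (enum values illustrative; the real codes live in include/errorcode.h):

```cpp
#include <vector>

enum Status { kOk, kError, kNoChange };  // illustrative stand-ins for RET_* codes

// Per tensor: skip (kNoChange) rather than fail when there is nothing to do.
Status DequantTensorToy(bool needs_dequant) { return needs_dequant ? kOk : kNoChange; }

// Per node: only hard failures abort; kNoChange counts as success.
Status DequantNodeToy(const std::vector<bool> &inputs) {
  for (bool needs : inputs) {
    auto s = DequantTensorToy(needs);
    if (s != kOk && s != kNoChange) {
      return kError;
    }
  }
  return kOk;
}
```
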
diff --git a/mindspore/lite/src/dequant.h b/mindspore/lite/src/weight_decoder.h
similarity index 88%
rename from mindspore/lite/src/dequant.h
rename to mindspore/lite/src/weight_decoder.h
index 66e52f8db4..d7dd90c47d 100644
--- a/mindspore/lite/src/dequant.h
+++ b/mindspore/lite/src/weight_decoder.h
@@ -22,19 +22,22 @@
 #include
 #include
 #include
+#include "nnacl/matmul_parameter.h"
 #include "src/lite_kernel.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"
 
 namespace mindspore::lite {
-class DequantUtil {
+class WeightDecoder {
  public:
   static int UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor);
   static int DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor);
-  static Tensor *DequantTensor(Tensor *tensor, TypeId data_type, bool channel_first = true,
-                               TypeId dst_data_type = kNumberTypeFloat32);
+  static int DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type);
+
+ private:
+  static int DequantTensor(Tensor *tensor, bool channel_first = true, TypeId dst_data_type = kNumberTypeFloat32);
 
   template <typename DT>
   static DT *DequantData(lite::Tensor *input_tensor, bool channel_first = true) {
@@ -102,22 +105,19 @@ class DequantUtil {
     return dequant_datas;
   }
 
-  template <typename T1, typename T2>
-  static void UnpackUtil(const T1 *weight_data, int pack_size, int origin_bit, void *unpack_int_data) {
-    if (weight_data == nullptr || unpack_int_data == nullptr) {
-      MS_LOG(ERROR) << "data is nullptr";
-      return;
-    }
-    std::queue<bool> unpack_bit_data;
-    size_t count = 0;
-    for (int i = 0; i < pack_size; ++i) {
-      T2 pack_data = (static_cast<const T2 *>(static_cast<const void *>(weight_data)))[i];
-      bool is_last = i == pack_size - 1;
-      UnPackData(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last);
+  inline static bool IsChannelFirst(int index, const OpParameter *op_parameter) {
+    MS_ASSERT(op_parameter != nullptr);
+    if (op_parameter->type_ == schema::PrimitiveType_MatMul) {
+      const auto *param = reinterpret_cast<const MatMulParameter *>(op_parameter);
+      if (index == 0) {
+        return !(param->a_transpose_);
+      } else if (index == 1) {
+        return param->b_transpose_;
+      }
     }
+    return true;
   }
 
-  private:
   static int DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type = kNumberTypeFloat32);
 
   template
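
IsChannelFirst moves here as a private helper so DequantNode can decide, per input index, whether per-channel quant parameters run along the first or the last axis. For MatMul that depends on the transpose flags; a self-contained illustration, with ToyMatMulParam standing in for nnacl's MatMulParameter:

```cpp
#include <cassert>

struct ToyMatMulParam {  // stand-in for MatMulParameter's transpose flags
  bool a_transpose_;
  bool b_transpose_;
};

// index 0 is activation A, index 1 is weight B; other inputs default to true.
bool IsChannelFirstToy(int index, const ToyMatMulParam &p) {
  if (index == 0) return !p.a_transpose_;
  if (index == 1) return p.b_transpose_;
  return true;
}

int main() {
  // B stored [deep, col]: output channels vary along the last axis.
  assert(!IsChannelFirstToy(1, {false, false}));
  // B stored [col, deep] (b_transpose_ set): output channels come first.
  assert(IsChannelFirstToy(1, {false, true}));
  return 0;
}
```
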
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index f466970e3e..cd3e070ec8 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -143,7 +143,7 @@ set(TEST_LITE_SRC
     ${LITE_DIR}/src/kernel_registry.cc
     ${LITE_DIR}/src/lite_kernel.cc
     ${LITE_DIR}/src/lite_session.cc
-    ${LITE_DIR}/src/dequant.cc
+    ${LITE_DIR}/src/weight_decoder.cc
     ${LITE_DIR}/src/huffman_decode.cc
     ${LITE_DIR}/src/sub_graph_kernel.cc
     ${LITE_DIR}/src/sub_graph_split.cc
diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
index 8d9fb8815f..0097ced63e 100644
--- a/mindspore/lite/tools/converter/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/CMakeLists.txt
@@ -115,7 +115,7 @@ set(LITE_SRC
     ${SRC_DIR}/executor.cc
     ${SRC_DIR}/lite_model.cc
     ${SRC_DIR}/errorcode.cc
-    ${SRC_DIR}/dequant.cc
+    ${SRC_DIR}/weight_decoder.cc
     ${SRC_DIR}/huffman_decode.cc
     ${SRC_DIR}/ops/ops_utils.cc
     ${SRC_DIR}/ops/ops_def.cc
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc
index f8ea5c98af..f8bbed3524 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc
@@ -339,6 +339,7 @@ STATUS FormatTransPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uniqu
       MS_LOG(ERROR) << "Crop error";
       return RET_ERROR;
     }
+    node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
     node->primitive->value.AsCrop()->offsets = offsets;
   }
   if (type == schema::PrimitiveType_SliceFusion || type == schema::PrimitiveType_StridedSlice) {
diff --git a/mindspore/lite/tools/converter/quantizer/huffman_encode.cc b/mindspore/lite/tools/converter/quantizer/huffman_encode.cc
index ced6ca8873..a1f3e4b6bd 100644
--- a/mindspore/lite/tools/converter/quantizer/huffman_encode.cc
+++ b/mindspore/lite/tools/converter/quantizer/huffman_encode.cc
@@ -15,7 +15,7 @@
  */
 
 #include "tools/converter/quantizer/huffman_encode.h"
-#include "src/dequant.h"
+#include "src/weight_decoder.h"
 #include "tools/converter/quantizer/quantize_util.h"
 
 namespace mindspore {
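
The one-line format_trans_pass.cc change above completes ChangeOpAxis for Crop: the offsets were already rewritten for the NCHW-to-NHWC transition, but the axis attribute itself kept its NCHW numbering. A hedged illustration, assuming axis_map is the usual NCHW-to-NHWC permutation this pass applies to other ops:

```cpp
#include <cassert>

int main() {
  const int axis_map[4] = {0, 3, 1, 2};  // assumed NCHW index -> NHWC index
  const int origin_axis = 1;             // channel axis under NCHW
  assert(axis_map[origin_axis] == 3);    // channel axis under NHWC
  return 0;
}
```
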