From: @hangangqiang
Reviewed-by: @zhang_xue_tong, @zhanghaibo5
Signed-off-by: @zhang_xue_tong
pull/14770/MERGE
@@ -135,7 +135,7 @@ set(LITE_SRC
   ${LITE_DIR}/src/sub_graph_split.cc
   ${LITE_DIR}/src/tensorlist.cc
   ${LITE_DIR}/src/tensor.cc
-  ${LITE_DIR}/src/dequant.cc
+  ${LITE_DIR}/src/weight_decoder.cc
   ${LITE_DIR}/src/huffman_decode.cc
   ${LITE_DIR}/src/common/log_adapter.cc
   ${LITE_DIR}/src/common/utils.cc
@@ -841,7 +841,7 @@ table Rsqrt {
 }

 table QuantDTypeCast {
-    src_t: long;
+    src_t: long; // deprecated
     dst_t: long;
 }
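
Keeping `src_t` in place but marking it deprecated preserves the flatbuffer field layout, so models serialized by older converters still parse. Readers are presumably expected to take the source type from the input tensor rather than from this field (an inference from the comment added above, not stated in the diff). A hypothetical consumer-side sketch:

    // Hypothetical sketch: prefer the runtime tensor's type over the
    // deprecated schema field when resolving the cast's source type.
    TypeId ResolveCastSrcType(const schema::QuantDTypeCast &prim, const lite::Tensor &input) {
      (void)prim;  // prim.src_t() is deprecated; do not rely on it
      return input.data_type();
    }
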
@@ -61,7 +61,7 @@ set(LITE_SRC
   ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
-  ${CMAKE_CURRENT_SOURCE_DIR}/dequant.cc
+  ${CMAKE_CURRENT_SOURCE_DIR}/weight_decoder.cc
   ${CMAKE_CURRENT_SOURCE_DIR}/huffman_decode.cc
 )
 if(DEFINED ARCHS)
@@ -28,7 +28,7 @@
 #include "src/common/graph_util.h"
 #include "src/kernel_registry.h"
 #include "src/lite_model.h"
-#include "src/dequant.h"
+#include "src/weight_decoder.h"
 #ifdef ENABLE_MINDRT
 #include "src/mindrt_executor.h"
 #endif
@@ -57,13 +57,13 @@ int DecompressTensor(const schema::Tensor &src_tensor, Tensor *dst_tensor) {
   // huffman code and bit pack are not assumed to be performed at same time
   STATUS ret = RET_ERROR;
   if (src_tensor.enableHuffmanCode()) {
-    ret = DequantUtil::DecodeHuffmanCode(src_tensor, dst_tensor);
+    ret = WeightDecoder::DecodeHuffmanCode(src_tensor, dst_tensor);
     if (ret != RET_OK && ret != RET_NO_CHANGE) {
       MS_LOG(ERROR) << "Decode huffman code failed: " << ret;
       return ret;
     }
   } else if (need_bit_unpack) {
-    ret = DequantUtil::UnPackToInt(src_tensor, dst_tensor);
+    ret = WeightDecoder::UnPackToInt(src_tensor, dst_tensor);
     if (ret != RET_OK && ret != RET_NO_CHANGE) {
       MS_LOG(ERROR) << "Unpack to int8 failed: " << ret;
       return ret;
@@ -91,8 +91,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
     if (origin_bias_data_type_ == kNumberTypeFloat16) {
       memcpy(bias_data_, origin_bias_, output_channel * sizeof(float16_t));
     } else {
-      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_),
-                       output_channel);
+      MS_LOG(ERROR) << "Conv1x1 fp16 only support fp16 bias";
+      return RET_ERROR;
     }
     memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size);
   }
@@ -41,18 +41,8 @@ int ConvolutionBaseFP16CPUKernel::GetExecuteTensor() {
 int ConvolutionBaseFP16CPUKernel::GetExecuteFilter(lite::Tensor *weight_tensor, void *origin_data) {
   MS_ASSERT(origin_weight_data_type_ == kNumberTypeFloat32 || origin_weight_data_type_ == kNumberTypeFloat16);
   if (origin_weight_data_type_ == kNumberTypeFloat32) {
-    float *origin_weight = reinterpret_cast<float *>(origin_data);
-    size_t fp16_weight_size = weight_tensor->Channel() * weight_tensor->Batch() * weight_tensor->Height() *
-                              weight_tensor->Width() * sizeof(float16_t);
-    fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
-    if (fp16_weight_ == nullptr) {
-      MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
-      return RET_ERROR;
-    }
-    for (size_t i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
-      fp16_weight_[i] = (float16_t)origin_weight[i];
-    }
-    execute_weight_ = fp16_weight_;
+    MS_LOG(ERROR) << "Conv fp16 only support fp16 weight";
+    return RET_ERROR;
   } else {
     execute_weight_ = reinterpret_cast<float16_t *>(origin_data);
     fp16_weight_ = nullptr;
@@ -247,6 +247,10 @@ kernel::LiteKernel *CreateDelegateConvFp16(const std::vector<lite::Tensor *> &in
                                            const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                            const InnerContext *ctx) {
   auto weight_data_type = inputs.at(1)->data_type();
+  if (weight_data_type != kNumberTypeFloat16) {
+    MS_LOG(ERROR) << "Conv fp16 only support fp16 weight";
+    return nullptr;
+  }
   TypeId bias_data_type = kTypeUnknown;
   if (inputs.size() == 3) {
     bias_data_type = inputs.at(2)->data_type();
@@ -59,7 +59,8 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
     if (origin_bias_data_type_ == kNumberTypeFloat16) {
       memcpy(bias_data_, origin_bias_, out_channel * sizeof(float16_t));
     } else {
-      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_), out_channel);
+      MS_LOG(ERROR) << "Conv fp16 only support fp16 bias";
+      return RET_ERROR;
     }
   } else {
     MS_ASSERT(in_tensors_.size() == kInputSize1);
@@ -96,7 +96,8 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
     if (origin_bias_data_type_ == kNumberTypeFloat16) {
       memcpy(bias_data_, origin_bias_, out_channel * sizeof(float16_t));
     } else {
-      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_), out_channel);
+      MS_LOG(ERROR) << "Conv winograd fp16 only support fp16 bias";
+      return RET_ERROR;
     }
   } else {
     MS_ASSERT(in_tensors_.size() == kInputSize1);
@@ -67,7 +67,7 @@ int DeConvolutionFp16CPUKernel::InitWeightBias() {
   if (in_tensors_.size() == 3 && in_tensors_.at(kBiasIndex)->shape().size() == 1 &&
       in_tensors_.at(kBiasIndex)->DimensionSize(0) == output_channel) {
     if (in_tensors_.at(2)->data_type() != kNumberTypeFloat16) {
-      MS_LOG(ERROR) << "deconv fp16 kernel require fp16 bias";
+      MS_LOG(ERROR) << "DeConv fp16 only support fp16 bias";
       return RET_ERROR;
     }
     if (bias_size != in_tensors_.at(2)->Size()) {
@@ -78,7 +78,8 @@ int MatmulBaseFP16CPUKernel::InitBias() {
   }
   memset(bias_ptr_, 0, max_bias_data * sizeof(float16_t));
   if (in_tensors_[2]->data_type() == kNumberTypeFloat32) {
-    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[2]->data_c()), bias_ptr_, bias_tensor->ElementsNum());
+    MS_LOG(ERROR) << "Matmul fp16 only support fp16 bias";
+    return RET_ERROR;
   } else if (in_tensors_[2]->data_type() == kNumberTypeFloat16) {
     memcpy(bias_ptr_, in_tensors_[2]->data_c(), max_bias_data * sizeof(float16_t));
   } else {
@@ -16,7 +16,7 @@

 #include <mindspore/lite/src/runtime/infer_manager.h>
 #include "src/runtime/kernel/opencl/opencl_kernel.h"
-#include "mindspore/lite/src/dequant.h"
+#include "mindspore/lite/src/weight_decoder.h"

 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
@@ -25,7 +25,7 @@
 #include "src/lite_kernel.h"
 #include "include/errorcode.h"
 #include "src/runtime/gpu/opencl/opencl_runtime.h"
-#include "mindspore/lite/src/dequant.h"
+#include "mindspore/lite/src/weight_decoder.h"
 #include "src/runtime/kernel/opencl/utils.h"
 #include "nnacl/resize_parameter.h"
@@ -31,8 +31,7 @@
 #include "src/common/prim_util.h"
 #include "src/runtime/infer_manager.h"
 #include "src/sub_graph_split.h"
-#include "src/dequant.h"
-#include "nnacl/matmul_parameter.h"
+#include "src/weight_decoder.h"
 #if GPU_OPENCL
 #include "src/runtime/kernel/opencl/opencl_subgraph.h"
 #include "src/runtime/gpu/opencl/opencl_runtime.h"
@@ -216,7 +215,7 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter
 namespace {
 #ifndef SUPPORT_TRAIN
-int CopyConstTensor(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
+int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
   MS_ASSERT(restored_origin_tensors != nullptr);
   MS_ASSERT(tensor != nullptr);
   if (dst_data_type != kNumberTypeFloat32 && dst_data_type != kNumberTypeFloat16) {
@@ -248,6 +247,26 @@ int CopyConstTensor(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origi
 #else
     MS_LOG(ERROR) << "Unsupported dst data type: float16";
     return RET_ERROR;
+#endif
+  } else if (tensor->data_type() == kNumberTypeFloat16 && dst_data_type == kNumberTypeFloat32) {
+#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
+    auto restore_tensor = Tensor::CopyTensor(*tensor, false);
+    restore_tensor->set_data(origin_data);
+    restore_tensor->set_own_data(tensor->own_data());
+    tensor->set_data(nullptr);
+    tensor->set_data_type(kNumberTypeFloat32);
+    auto ret = tensor->MallocData();
+    if (RET_OK != ret) {
+      MS_LOG(ERROR) << "malloc data failed";
+      return ret;
+    }
+    auto new_tensor_data = tensor->data_c();
+    MS_ASSERT(new_tensor_data != nullptr);
+    Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
+    (*restored_origin_tensors)[tensor] = restore_tensor;
+#else
+    MS_LOG(ERROR) << "Unsupported src data type: float16";
+    return RET_ERROR;
 #endif
   } else {
     if (tensor->own_data()) {
@@ -290,19 +309,6 @@ inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tens
   }
   FreeRestoreTensors(restored_origin_tensors);
 }
-
-inline bool IsChannelFirst(int index, OpParameter *op_parameter) {
-  MS_ASSERT(op_parameter != nullptr);
-  if (op_parameter->type_ == schema::PrimitiveType_MatMul) {
-    const auto *param = reinterpret_cast<MatMulParameter *>(op_parameter);
-    if (index == 0) {
-      return !(param->a_transpose_);
-    } else if (index == 1) {
-      return param->b_transpose_;
-    }
-  }
-  return true;
-}
 }  // namespace

 kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors,
@@ -321,23 +327,21 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
     }
     cpu_desc.data_type = kNumberTypeFloat16;
   }
+  auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type);
+  if (ret != RET_OK) {
+    MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
+    return nullptr;
+  }
   std::map<Tensor *, Tensor *> restored_origin_tensors;
-  int index = 0;
-  for (auto &tensor : in_tensors) {
-    auto channel_first = IsChannelFirst(index++, op_parameter);
-    auto *restore_tensor = DequantUtil::DequantTensor(tensor, cpu_desc.data_type, channel_first, kernel_data_type);
-    if (restore_tensor != nullptr) {
-      restored_origin_tensors[tensor] = restore_tensor;
-    } else {
 #ifndef SUPPORT_TRAIN
-      auto ret = CopyConstTensor(tensor, &restored_origin_tensors, kernel_data_type);
-      if (ret != RET_OK) {
-        MS_LOG(DEBUG) << "CopyConstTensor failed: " << ret;
-        return nullptr;
-      }
-#endif
+  for (auto &tensor : in_tensors) {
+    ret = CastConstTensorData(tensor, &restored_origin_tensors, kernel_data_type);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "CastConstTensorData failed: " << ret;
+      return nullptr;
     }
   }
+#endif
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
   if (kernel != nullptr) {
     MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);
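
For readers skimming the diff, the new preparation path in FindCpuKernel reduces to two stages; the helper below is a simplified restatement (a hypothetical wrapper, same calls as the hunk above, not code from this PR):

    // Simplified shape of the new preparation path (illustrative only).
    int PrepareCpuKernelInputs(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
                               TypeId kernel_data_type,
                               std::map<Tensor *, Tensor *> *restored_origin_tensors) {
      // Stage 1: dequantize all weight-quantized inputs of the node at once;
      // per-tensor channel ordering is now decided inside WeightDecoder.
      auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type);
      if (ret != RET_OK) {
        return ret;
      }
      // Stage 2: cast remaining const fp32/fp16 tensors to the kernel's data
      // type, remembering originals so RestoreTensorData() can undo the cast
      // if kernel creation later fails.
      for (auto &tensor : in_tensors) {
        ret = CastConstTensorData(tensor, restored_origin_tensors, kernel_data_type);
        if (ret != RET_OK) {
          return ret;
        }
      }
      return RET_OK;
    }
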
@@ -362,24 +366,18 @@ kernel::LiteKernel *Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_ten
       gpu_desc.data_type = kNumberTypeInt8;
     }
-    // weight quant
-    std::map<Tensor *, Tensor *> restored_origin_tensors;
-    for (auto &tensor : in_tensors) {
-      int index = 0;
-      auto channel_first = IsChannelFirst(index++, op_parameter);
-      auto *restore_tensor = DequantUtil::DequantTensor(tensor, desc.data_type, channel_first, kNumberTypeFloat32);
-      if (restore_tensor != nullptr) {
-        restored_origin_tensors[tensor] = restore_tensor;
-      }
+    // weight dequant
+    auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
+      return nullptr;
     }
     auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, gpu_desc, op_parameter);
     if (kernel != nullptr) {
       MS_LOG(DEBUG) << "Get gpu op success: " << PrimitiveCurVersionTypeName(gpu_desc.type);
-      FreeRestoreTensors(&restored_origin_tensors);
     } else {
       MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(gpu_desc.type);
-      RestoreTensorData(&restored_origin_tensors);
     }
     return kernel;
   } else {
@@ -396,26 +394,20 @@ kernel::LiteKernel *Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_ten
     if (npu_desc.data_type == kNumberTypeFloat16) {
       npu_desc.data_type = kNumberTypeFloat32;
     }
+    auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
+      return nullptr;
+    }
     for (auto tensor : in_tensors) {
       if (tensor->data_type() == kNumberTypeFloat16) {
         tensor->set_data_type(kNumberTypeFloat32);
       }
     }
-    std::map<Tensor *, Tensor *> restored_origin_tensors;
-    for (auto &tensor : in_tensors) {
-      int index = 0;
-      auto channel_first = IsChannelFirst(index++, op_parameter);
-      auto *restore_tensor = DequantUtil::DequantTensor(tensor, desc.data_type, channel_first, kNumberTypeFloat32);
-      if (restore_tensor != nullptr) {
-        restored_origin_tensors[tensor] = restore_tensor;
-      }
-    }
     auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter);
     if (kernel != nullptr) {
-      FreeRestoreTensors(&restored_origin_tensors);
       MS_LOG(DEBUG) << "Get npu op success: " << PrimitiveCurVersionTypeName(npu_desc.type);
     } else {
-      RestoreTensorData(&restored_origin_tensors);
       MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(npu_desc.type);
     }
     return kernel;
@@ -178,6 +178,19 @@ class CpuFp16SubGraph : public CpuSubGraph {
   int PreProcess() override;
   int Run() override { return CpuSubGraph::Run(); }
   int Run(const KernelCallBack &before, const KernelCallBack &after) override {
+#ifdef Debug
+    for (const auto *node : nodes_) {
+      if (node->Type() == schema::PrimitiveType_PartialFusion) {
+        continue;
+      }
+      for (const auto *in_tensor : node->in_tensors()) {
+        if (in_tensor->data_type() == kNumberTypeFloat32) {
+          MS_LOG(ERROR) << "FP16 kernel can not accept float32 input";
+          return lite::RET_ERROR;
+        }
+      }
+    }
+#endif
     return CpuSubGraph::Run(before, after);
   };
   int PostProcess() override;
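
The added block compiles only under the `Debug` flag, so release builds pay no per-run cost; it asserts the invariant the kernel-side changes above establish: an fp16 subgraph never receives fp32 inputs. The same check extracted as a standalone predicate (a sketch, not part of the diff):

    // Debug-only invariant: every non-Partial node in an fp16 subgraph must
    // see fp16 (or non-float) input tensors.
    bool HasNoFloat32Inputs(const std::vector<kernel::LiteKernel *> &nodes) {
      for (const auto *node : nodes) {
        if (node->Type() == schema::PrimitiveType_PartialFusion) {
          continue;
        }
        for (const auto *in_tensor : node->in_tensors()) {
          if (in_tensor->data_type() == kNumberTypeFloat32) {
            return false;
          }
        }
      }
      return true;
    }
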
@@ -16,12 +16,11 @@

 #include <cmath>
 #include <string>
 #include <memory>
-#include "src/dequant.h"
+#include "src/weight_decoder.h"
 #include "src/huffman_decode.h"
-#include "nnacl/matmul_parameter.h"

 namespace mindspore::lite {
-int DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type) {
+int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type) {
   MS_ASSERT(input_tensor != nullptr);
   if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
     MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
@@ -69,7 +68,7 @@ int DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first, T
   return RET_OK;
 }

-int DequantUtil::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
+int WeightDecoder::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
   MS_ASSERT(dst_tensor != nullptr);
   if (!dst_tensor->IsConst() || !src_tensor.enableHuffmanCode()) {
     return RET_NO_CHANGE;
@@ -93,7 +92,7 @@ int DequantUtil::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tenso
   return RET_OK;
 }

-int DequantUtil::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
+int WeightDecoder::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
   MS_ASSERT(dst_tensor != nullptr);
   auto quant_params = src_tensor.quantParams();
   if (quant_params == nullptr || quant_params->size() == 0) {
@@ -127,26 +126,39 @@ int DequantUtil::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst
   }
 }

-Tensor *DequantUtil::DequantTensor(Tensor *tensor, TypeId data_type, bool channel_first, TypeId dst_data_type) {
+int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
+                               TypeId dst_data_type) {
+  if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) {
+    return RET_OK;
+  }
+  int index = 0;
+  for (auto &tensor : in_tensors) {
+    auto channel_first = IsChannelFirst(index++, op_parameter);
+    auto ret = WeightDecoder::DequantTensor(tensor, channel_first, dst_data_type);
+    if (ret != RET_OK && ret != RET_NO_CHANGE) {
+      MS_LOG(DEBUG) << "Dequant tensor failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+int WeightDecoder::DequantTensor(Tensor *tensor, bool channel_first, TypeId dst_data_type) {
   MS_ASSERT(tensor != nullptr);
-  Tensor *restore_tensor = nullptr;
-  if (!tensor->IsConst() || !(data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16)) {
-    return nullptr;
+  if (!tensor->IsConst() ||
+      !(dst_data_type == TypeId::kNumberTypeFloat32 || dst_data_type == TypeId::kNumberTypeFloat16)) {
+    return RET_NO_CHANGE;
   }
-  auto restore_type = tensor->data_type();
   bool need_dequant = !tensor->quant_params().empty() && tensor->quant_params().front().inited &&
-                      (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
+                      (tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16);
   if (!need_dequant) {
-    return nullptr;
+    return RET_NO_CHANGE;
   }
-  restore_tensor = Tensor::CopyTensor(*tensor, false);
-  restore_tensor->set_data(tensor->data_c());
-  restore_tensor->set_own_data(tensor->own_data());
-  auto ret = DequantUtil::DequantWeight(tensor, channel_first, dst_data_type);
+  auto ret = WeightDecoder::DequantWeight(tensor, channel_first, dst_data_type);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Dequant data failed: " << ret;
-    return nullptr;
+    return ret;
   }
-  return restore_tensor;
+  return RET_OK;
 }
 }  // namespace mindspore::lite
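
This replacement changes the contract: callers no longer receive and manage a restore tensor from DequantTensor; DequantNode mutates the tensors in place and returns RET_OK (including the no-op case for nodes that are not QuantType_QUANT_WEIGHT) or an error code. A minimal caller sketch under that contract (hypothetical wrapper name, not from this PR):

    // Hypothetical caller: dequantize a node's weights before kernel creation.
    int PrepareNodeWeights(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors) {
      // No-op unless the node was converted with QuantType_QUANT_WEIGHT.
      auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Dequant input tensors failed: " << ret;
      }
      return ret;
    }
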
@@ -22,19 +22,22 @@
 #include <vector>
 #include <queue>
 #include <cmath>
+#include "nnacl/matmul_parameter.h"
 #include "src/lite_kernel.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"

 namespace mindspore::lite {
-class DequantUtil {
+class WeightDecoder {
  public:
   static int UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor);
   static int DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor);
-  static Tensor *DequantTensor(Tensor *tensor, TypeId data_type, bool channel_first = true,
-                               TypeId dst_data_type = kNumberTypeFloat32);
+  static int DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type);
+
+ private:
+  static int DequantTensor(Tensor *tensor, bool channel_first = true, TypeId dst_data_type = kNumberTypeFloat32);

   template <typename ST, typename DT = float>
   static DT *DequantData(lite::Tensor *input_tensor, bool channel_first = true) {
@@ -102,22 +105,19 @@ class DequantUtil {
     return dequant_datas;
   }

-  template <typename T1, typename T2>
-  static void UnpackUtil(const T1 *weight_data, int pack_size, int origin_bit, void *unpack_int_data) {
-    if (weight_data == nullptr || unpack_int_data == nullptr) {
-      MS_LOG(ERROR) << "data is nullptr";
-      return;
-    }
-    std::queue<bool> unpack_bit_data;
-    size_t count = 0;
-    for (int i = 0; i < pack_size; ++i) {
-      T2 pack_data = (static_cast<const T2 *>(static_cast<const void *>(weight_data)))[i];
-      bool is_last = i == pack_size - 1;
-      UnPackData<T1, T2>(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last);
+  inline static bool IsChannelFirst(int index, const OpParameter *op_parameter) {
+    MS_ASSERT(op_parameter != nullptr);
+    if (op_parameter->type_ == schema::PrimitiveType_MatMul) {
+      const auto *param = reinterpret_cast<const MatMulParameter *>(op_parameter);
+      if (index == 0) {
+        return !(param->a_transpose_);
+      } else if (index == 1) {
+        return param->b_transpose_;
+      }
     }
+    return true;
   }

- private:
   static int DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type = kNumberTypeFloat32);

   template <typename T1, typename T2>
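
IsChannelFirst, moved here from scheduler.cc, decides whether the quantization channel is the leading dimension of a given input; only MatMul needs special handling, where the answer depends on the transpose flags. An illustration with assumed flag values (the MatMulParameter layout with a leading OpParameter member is inferred from the cast above):

    // For C = A x B: input 0 (A) is channel-first unless A is transposed;
    // input 1 (B) is channel-first only when B is transposed.
    MatMulParameter param = {};
    param.op_parameter_.type_ = schema::PrimitiveType_MatMul;
    param.a_transpose_ = false;  // assumed
    param.b_transpose_ = true;   // assumed
    // IsChannelFirst(0, &param.op_parameter_) == true   (A not transposed)
    // IsChannelFirst(1, &param.op_parameter_) == true   (B transposed)
    // IsChannelFirst(2, &param.op_parameter_) == true   (default, e.g. bias)
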
@@ -143,7 +143,7 @@ set(TEST_LITE_SRC
   ${LITE_DIR}/src/kernel_registry.cc
   ${LITE_DIR}/src/lite_kernel.cc
   ${LITE_DIR}/src/lite_session.cc
-  ${LITE_DIR}/src/dequant.cc
+  ${LITE_DIR}/src/weight_decoder.cc
   ${LITE_DIR}/src/huffman_decode.cc
   ${LITE_DIR}/src/sub_graph_kernel.cc
   ${LITE_DIR}/src/sub_graph_split.cc
@@ -115,7 +115,7 @@ set(LITE_SRC
   ${SRC_DIR}/executor.cc
   ${SRC_DIR}/lite_model.cc
   ${SRC_DIR}/errorcode.cc
-  ${SRC_DIR}/dequant.cc
+  ${SRC_DIR}/weight_decoder.cc
   ${SRC_DIR}/huffman_decode.cc
   ${SRC_DIR}/ops/ops_utils.cc
   ${SRC_DIR}/ops/ops_def.cc
@@ -339,6 +339,7 @@ STATUS FormatTransPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uniqu
       MS_LOG(ERROR) << "Crop error";
       return RET_ERROR;
     }
+    node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
     node->primitive->value.AsCrop()->offsets = offsets;
   }
   if (type == schema::PrimitiveType_SliceFusion || type == schema::PrimitiveType_StridedSlice) {
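
The added line fixes a dropped remap in ChangeOpAxis: when a Crop node's layout is transposed, its `axis` attribute must go through the same permutation as its offsets. An illustration with an assumed NCHW-to-NHWC axis_map (the map's concrete contents are not shown in this diff):

    // Sketch: translating a Crop axis through a layout permutation.
    std::map<int, int> axis_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};  // assumed NCHW -> NHWC
    int origin_axis = 1;                   // cropping along C in NCHW
    int new_axis = axis_map[origin_axis];  // becomes 3: C sits last in NHWC
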
@@ -15,7 +15,7 @@
  */

 #include "tools/converter/quantizer/huffman_encode.h"
-#include "src/dequant.h"
+#include "src/weight_decoder.h"
 #include "tools/converter/quantizer/quantize_util.h"

 namespace mindspore {