@@ -103,8 +103,9 @@ int ModelImpl::BuildOps() {
     auto cNode = meta_graph_->nodes()->GetAs<schema::CNode>(i);
     auto name = cNode->name()->str();
     auto srcPrim = cNode->primitive();
-    this->ops_[name] = PrimitiveC::UnPackFromSchemaPrimitive(const_cast<schema::Primitive *>(srcPrim));
+    auto prim = PrimitiveC::UnPackFromSchemaPrimitive(const_cast<schema::Primitive *>(srcPrim));
+    prim->SetQuantType(cNode->quantType());
+    this->ops_[name] = prim;
   }
   return 0;
 }

@@ -688,6 +688,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(const schema::Primitive *primi
   }
   return nullptr;
 }
+
+void PrimitiveC::SetQuantType(schema::QuantType quant_type) {
+  this->quant_type_ = quant_type;
+}
+
+schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; }
 #endif
 int PrimitiveC::Type() const {

@@ -145,6 +145,9 @@ class PrimitiveC {
   int Type() const;
+  void SetQuantType(schema::QuantType quant_type);
+  schema::QuantType GetQuantType() const;
+
 protected:
   template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
   static PrimitiveC *NewPrimitiveC(const schema::Primitive *primitive) {

@@ -194,6 +197,7 @@ class PrimitiveC {
   const schema::Primitive *primitive_ = nullptr;
   char *primitive_buf_ = nullptr;
   bool infer_flag_ = true;
+  schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
 };
 #endif
 }  // namespace lite

@@ -331,4 +331,46 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
   return RET_OK;
 }
+
+int ConvolutionBaseCPUKernel::RestoreFilter(lite::tensor::Tensor *input_tensor) {
+  MS_ASSERT(input_tensor != nullptr);
+  if (input_tensor->GetQuantParams().empty()) {
+    MS_LOG(ERROR) << "no quant param";
+    return RET_ERROR;
+  }
+  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
+  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
+  if (dequant_data == nullptr) {
+    MS_LOG(ERROR) << "malloc failed";
+    return RET_ERROR;
+  }
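+  // Dequantize the stored uint8 weights back to float: dequant = (quant - zero_point) * scale,
+  // applied per output channel when several quant params are present, otherwise per tensor.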
+  if (input_tensor->GetQuantParams().size() != kPerTensor) {
+    size_t channels = static_cast<size_t>(input_tensor->Batch());
+    if (input_tensor->GetQuantParams().size() != channels) {
+      MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << " " << channels;
+      return RET_ERROR;
+    }
+    size_t per_channel_size = input_tensor->DataSize() / channels;
+    auto quant_param = input_tensor->GetQuantParams();
+    for (size_t i = 0; i < channels; i++) {
+      auto param = quant_param.at(i);
+      auto scale = param.scale;
+      auto zero_point = param.zeroPoint;
+      for (size_t j = 0; j < per_channel_size; j++) {
+        dequant_data[per_channel_size * i + j] = static_cast<float>(
+          (quant_data[per_channel_size * i + j] - zero_point) * scale);
+      }
+    }
+  } else {
+    auto quant_param = input_tensor->GetQuantParams();
+    auto param = quant_param.front();
+    auto scale = param.scale;
+    auto zero_point = param.zeroPoint;
+    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
+      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
+    }
+  }
+  input_tensor->SetData(dequant_data);
+  return RET_OK;
+}
 }  // namespace mindspore::kernel

@@ -60,6 +60,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
   int SetQuantMultiplier();
   int CheckResizeValid();
   void FreeQuantParam();
+  static int RestoreFilter(lite::tensor::Tensor *input_tensor);

 protected:
   int tile_num_;

@@ -239,6 +239,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
     CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
   }
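+  // For weight-quantized models, keep the original quantized weight buffer and temporarily
+  // replace it with dequantized float data while the kernel is created; it is restored below.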
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->Data();
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
+  }
   kernel::LiteKernel *kernel;
   if (kernel_h == 1 && kernel_w == 1) {
     kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);

@@ -263,6 +269,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
     return nullptr;
   }
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
   return kernel;
 }

@@ -131,6 +131,13 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::T
                                                const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
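+  // Weight-quantized filters get the same dequantize-then-restore handling as in CpuConvFp32KernelCreator.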
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->Data();
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
+  }
   auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
   kernel::LiteKernel *kernel;
   if (conv_param->input_channel_ < 32) {

@@ -149,6 +156,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::T
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     return nullptr;
   }
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
   return kernel;
 }

@@ -64,7 +64,8 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
   MS_ASSERT(dst_node != nullptr);
   // add quant param
   dst_node->quantType = primitive->GetQuantType();
-  if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining) {
+  if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining
+      || dst_node->quantType == schema::QuantType_WeightQuant) {
     MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
     // activation
     auto input_quant_params = primitive->GetInputQuantParams();

@@ -103,7 +104,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
       }
     } else {
       for (auto output_quant_param : output_quant_params[0]) {
-        if (tensor_output->quantParams.empty()) {
+        if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
           std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
             std::make_unique<schema::QuantParamT>(output_quant_param);
           MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale

@@ -26,6 +26,7 @@
 #include "tools/optimizer/fusion/constant_folding_fusion.h"
 #include "tools/converter/quantizer/post_training_quantizer.h"
 #include "tools/converter/quantizer/quant_cast.h"
+#include "tools/converter/quantizer/weight_quantizer.h"

 using std::string;
 namespace mindspore {

@@ -57,11 +58,20 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
   FuncGraphPtr new_graph = optimizer->Optimize(old_graph);

   // quant
-  if (config != nullptr && config->quantType == schema::QuantType_PostTraining) {
-    this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
-    if (mQuantizer == nullptr) {
-      MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
-      return nullptr;
+  if (config != nullptr) {
+    if (config->quantType == schema::QuantType_PostTraining) {
+      this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
+      if (mQuantizer == nullptr) {
+        MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
+        return nullptr;
+      }
+    } else if (config->quantType == schema::QuantType_WeightQuant) {
+      this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantSize,
+                                                                  config->convWeightQuantChannelThreshold, config->bitNum);
+      if (mQuantizer == nullptr) {
+        MS_LOG(ERROR) << "New WeightQuantizer failed";
+        return nullptr;
+      }
     }
   }
   if (mQuantizer != nullptr) {

@@ -71,12 +81,14 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
       MS_LOG(ERROR) << "Quant failed " << status;
       return nullptr;
     }
-    quant::QuantCast quant_cast;
-    quant_cast.SetInputDataDType(kNumberTypeFloat32);
-    status = quant_cast.Run(new_graph);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "add QuantCast error";
-      return nullptr;
+    if (config->quantType == schema::QuantType_PostTraining) {
+      quant::QuantCast quant_cast;
+      quant_cast.SetInputDataDType(kNumberTypeFloat32);
+      status = quant_cast.Run(new_graph);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "add QuantCast error";
+        return nullptr;
+      }
     }
   }

@@ -36,6 +36,8 @@ Flags::Flags() {
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
   AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
   AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0");
+  AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold",
+          "Channel threshold for convolution weight quantization", "16");
   AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
   AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
 }

@@ -191,6 +191,7 @@ STATUS WeightFormatHardCodePass::HardCodeTFLITE(const std::unique_ptr<CNodeT> &n
   switch (this->quantType) {
     case QuantType_AwareTraining:
     case QuantType_PostTraining:
+    case QuantType_WeightQuant:
     case QuantType_QUANT_NONE: {
       if (opType == schema::PrimitiveType_Conv2D) {
         weightTensor->format = schema::Format_KHWC;

@@ -31,7 +31,7 @@ void WeightFormatTransformPass::SetDstFormat(Format format) { this->dstFormat =
 STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  if (this->quantType == QuantType_AwareTraining) {
+  if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_WeightQuant) {
     auto status = QuantDataFormatTrans(graph);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;

@@ -11,6 +11,7 @@ add_library(quantizer_mid OBJECT
         ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
         )

 if(ENABLE_ASAN)

@@ -530,7 +530,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
     return RET_ERROR;
   }
   auto status =
-    QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num, perchanel, depthwise);
+    QuantFilter<int8_t>(paramValue, primitive_c, QuantType_PostTraining, quant_max,
+                        quant_min, bit_num, perchanel, depthwise);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;

@@ -279,171 +279,6 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
   return RET_OK;
 }
-
-STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
-                   int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
-  auto dims = weight->tensor_shape();
-  if (per_channel) {
-    if (dims.size() != 4) {
-      MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
-      per_channel = false;
-    } else {
-      uint32_t channels = dims[0];
-      if (channels == 0) {
-        MS_LOG(ERROR) << "channels is 0";
-        return RET_ERROR;
-      }
-    }
-  }
-  vector<schema::QuantParamT> quant_params;
-  size_t elem_count = weight->tensor_shape_size();
-  auto *raw_datas = static_cast<float *>(weight->tensor_addr());
-  if (raw_datas == nullptr) {
-    MS_LOG(ERROR) << "rawDatas is nullptr";
-    return RET_ERROR;
-  }
-  vector<int8_t> quant_datas(elem_count);
-  if (per_channel) {
-    // notice:
-    // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
-    // if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
-    if (depth_wise) {
-      // channel at last
-      auto channels = dims[3];
-      if (channels == 0) {
-        MS_LOG(ERROR) << "channels is zero";
-        return RET_ERROR;
-      }
-      size_t one_filter_size = elem_count / channels;
-      for (int i = 0; i < channels; i++) {
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-        // find min and max
-        for (size_t j = 0; j < one_filter_size; j++) {
-          auto index = i + j * channels;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          min = std::min(min, raw_datas[index]);
-          max = std::max(max, raw_datas[index]);
-        }
-        schema::QuantParamT quant_param;
-        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
-        if (status != RET_OK) {
-          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-          return status;
-        }
-        quant_params.emplace_back(quant_param);
-        // do quantization
-        for (uint32_t j = 0; j < one_filter_size; j++) {
-          auto index = i + j * channels;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          float raw_data = raw_datas[index];
-          auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
-          quant_datas[index] = quant_data;
-        }
-      }
-      auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(),
-                          elem_count * sizeof(int8_t));
-      if (ret != EOK) {
-        MS_LOG(ERROR) << "memcpy error: " << ret;
-        return RET_ERROR;
-      }
-      weight->set_tensor_size(elem_count * sizeof(int8_t));
-    } else {
-      // channel at first
-      auto channels = dims[0];
-      if (channels == 0) {
-        MS_LOG(ERROR) << "channels is zero";
-        return RET_ERROR;
-      }
-      size_t one_filter_size = elem_count / channels;
-      for (int i = 0; i < channels; i++) {
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-        // find min and max
-        for (size_t j = 0; j < one_filter_size; j++) {
-          auto index = j + i * one_filter_size;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          min = std::min(min, raw_datas[index]);
-          max = std::max(max, raw_datas[index]);
-        }
-        schema::QuantParamT quant_param;
-        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
-        if (status != RET_OK) {
-          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-          return status;
-        }
-        quant_params.emplace_back(quant_param);
-        // do quantization
-        for (uint32_t j = 0; j < one_filter_size; j++) {
-          auto index = j + i * one_filter_size;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          float raw_data = raw_datas[index];
-          auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
-          quant_datas[index] = quant_data;
-        }
-      }
-      auto ret =
-        memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
-      if (ret != EOK) {
-        MS_LOG(ERROR) << "memcpy error: " << ret;
-        return RET_ERROR;
-      }
-      weight->set_tensor_size(elem_count * sizeof(int8_t));
-    }
-  } else {
-    // per layer
-    float min = FLT_MAX;
-    float max = -FLT_MIN;
-    for (uint32_t i = 0; i < elem_count; i++) {
-      // find max min
-      min = std::min(min, raw_datas[i]);
-      max = std::max(max, raw_datas[i]);
-    }
-    schema::QuantParamT quant_param;
-    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-      return status;
-    }
-    quant_params.emplace_back(quant_param);
-    // update data and datatype
-    for (uint32_t i = 0; i < elem_count; i++) {
-      float raw_data = raw_datas[i];
-      auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
-      quant_datas[i] = quant_data;
-    }
-    auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
-    if (ret != EOK) {
-      MS_LOG(ERROR) << "memcpy error: " << ret;
-      return RET_ERROR;
-    }
-    weight->set_tensor_size(elem_count * sizeof(int8_t));
-  }
-  if (quant_params.empty()) {
-    MS_LOG(ERROR) << "quant_params empty";
-    return RET_ERROR;
-  }
-  primitive_c->AddInputQuantParam(quant_params);
-  return RET_OK;
-}

 STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) {
   auto *rawDatas = reinterpret_cast<uint8_t *>(weight);
   vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize);

@@ -21,6 +21,8 @@
 #include <string>
 #include <cmath>
 #include <array>
+#include <vector>
+#include <algorithm>
 #include "tools/converter/quantizer/quantizer.h"
 #include "src/ops/primitive_c.h"
 #include "include/errorcode.h"

@@ -117,10 +119,171 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
     return static_cast<T>(quant_data);
   }();
 }
+
+template <typename T>
 STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
-                   int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false,
-                   bool depth_wise = false);
+                   int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
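+  // T is the integer storage type of the quantized weights
+  // (int8_t for post-training quantization, uint8_t for weight quantization).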
+  auto dims = weight->tensor_shape();
+  if (per_channel) {
+    if (dims.size() != 4) {
+      MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
+      per_channel = false;
+    } else {
+      uint32_t channels = dims[0];
+      if (channels == 0) {
+        MS_LOG(ERROR) << "channels is 0";
+        return RET_ERROR;
+      }
+    }
+  }
+  std::vector<schema::QuantParamT> quant_params;
+  size_t elem_count = weight->tensor_shape_size();
+  auto *raw_datas = static_cast<float *>(weight->tensor_addr());
+  if (raw_datas == nullptr) {
+    MS_LOG(ERROR) << "rawDatas is nullptr";
+    return RET_ERROR;
+  }
+  std::vector<T> quant_datas(elem_count);
+  if (per_channel) {
+    // notice:
+    // for now, for tflite models, Conv2D's weight format is KHWC, and so is DepthwiseConv2D's
+    // if TransWeightFormat is done before PostTrainingQuantization, the DepthwiseConv2D weight is CHWK
+    if (depth_wise) {
+      // channel at last
+      auto channels = dims[3];
+      if (channels == 0) {
+        MS_LOG(ERROR) << "channels is zero";
+        return RET_ERROR;
+      }
+      size_t one_filter_size = elem_count / channels;
+      for (int i = 0; i < channels; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+        // find min and max
+        for (size_t j = 0; j < one_filter_size; j++) {
+          auto index = i + j * channels;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "over flow!";
+            return RET_ERROR;
+          }
+          min = std::min(min, raw_datas[index]);
+          max = std::max(max, raw_datas[index]);
+        }
+        schema::QuantParamT quant_param;
+        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
+        if (status != RET_OK) {
+          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
+          return status;
+        }
+        quant_params.emplace_back(quant_param);
+        // do quantization
+        for (uint32_t j = 0; j < one_filter_size; j++) {
+          auto index = i + j * channels;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "over flow!";
+            return RET_ERROR;
+          }
+          float raw_data = raw_datas[index];
+          auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
+          quant_datas[index] = quant_data;
+        }
+      }
+      auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(),
+                          elem_count * sizeof(T));
+      if (ret != EOK) {
+        MS_LOG(ERROR) << "memcpy error: " << ret;
+        return RET_ERROR;
+      }
+      weight->set_tensor_size(elem_count * sizeof(T));
+    } else {
+      // channel at first
+      auto channels = dims[0];
+      if (channels == 0) {
+        MS_LOG(ERROR) << "channels is zero";
+        return RET_ERROR;
+      }
+      size_t one_filter_size = elem_count / channels;
+      for (int i = 0; i < channels; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+        // find min and max
+        for (size_t j = 0; j < one_filter_size; j++) {
+          auto index = j + i * one_filter_size;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "over flow!";
+            return RET_ERROR;
+          }
+          min = std::min(min, raw_datas[index]);
+          max = std::max(max, raw_datas[index]);
+        }
+        schema::QuantParamT quant_param;
+        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
+        if (status != RET_OK) {
+          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
+          return status;
+        }
+        quant_params.emplace_back(quant_param);
+        // do quantization
+        for (uint32_t j = 0; j < one_filter_size; j++) {
+          auto index = j + i * one_filter_size;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "over flow!";
+            return RET_ERROR;
+          }
+          float raw_data = raw_datas[index];
+          auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
+          quant_datas[index] = quant_data;
+        }
+      }
+      auto ret =
+        memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
+      if (ret != EOK) {
+        MS_LOG(ERROR) << "memcpy error: " << ret;
+        return RET_ERROR;
+      }
+      weight->set_tensor_size(elem_count * sizeof(T));
+    }
+  } else {
+    // per layer
+    float min = FLT_MAX;
+    float max = -FLT_MIN;
+    for (uint32_t i = 0; i < elem_count; i++) {
+      // find max min
+      min = std::min(min, raw_datas[i]);
+      max = std::max(max, raw_datas[i]);
+    }
+    schema::QuantParamT quant_param;
+    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
+      return status;
+    }
+    quant_params.emplace_back(quant_param);
+    // update data and datatype
+    for (uint32_t i = 0; i < elem_count; i++) {
+      float raw_data = raw_datas[i];
+      auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
+      quant_datas[i] = quant_data;
+    }
+    auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
+    if (ret != EOK) {
+      MS_LOG(ERROR) << "memcpy error: " << ret;
+      return RET_ERROR;
+    }
+    weight->set_tensor_size(elem_count * sizeof(T));
+  }
+  if (quant_params.empty()) {
+    MS_LOG(ERROR) << "quant_params empty";
+    return RET_ERROR;
+  }
+  primitive_c->AddInputQuantParam(quant_params);
+  return RET_OK;
+}

 STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);

 }  // namespace quant

@@ -0,0 +1,148 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/converter/quantizer/weight_quantizer.h"
+#include <list>
+#include <string>
+#include <vector>
+#include "src/common/common.h"
+#include "ir/dtype/type_id.h"
+
+using std::string;
+using std::vector;
+
+namespace mindspore {
+namespace lite {
+namespace quant {
+WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
+                                 const std::string &convWeightChannelThreshold, const std::string &bitNum)
+    : Quantizer(graph) {
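+  // The converter passes these settings as strings; parse them into numeric thresholds and bit width here.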
+  auto quantSize = static_cast<size_t>(std::stoull(weightSize));
+  this->bitNum = static_cast<size_t>(std::stoull(bitNum));
+  auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
+  mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold));
+}
+
+STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
+  for (auto &cnode : nodes) {
+    if (!mStrategy->CanConvOpQuantized(cnode)) {
+      continue;
+    }
+    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+    if (primitive_c == nullptr) {
+      MS_LOG(ERROR) << "primitive_c is nullptr";
+      return RET_ERROR;
+    }
+    auto inputNode = cnode->input(2);
+    if (!inputNode->isa<Parameter>()) {
+      return RET_ERROR;
+    }
+    auto paramNode = inputNode->cast<ParameterPtr>();
+    if (!paramNode->has_default()) {
+      return RET_ERROR;
+    }
+    std::vector<schema::QuantParamT> quant_params;
+    primitive_c->AddInputQuantParam(quant_params);
+    auto op_type = (schema::PrimitiveType)primitive_c->Type();
+    bool depthwise = (op_type == schema::PrimitiveType_DepthwiseConv2D);
+    ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
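+    // Per-channel quantize the conv weight to uint8 in [0, 255]; the computed quant params are attached to the primitive.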
+    auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0,
+                                       bitNum, true, depthwise);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "QuantFilter failed : " << status;
+      return status;
+    }
+    param_value->set_tensor_type(kNumberTypeUInt8);
+    primitive_c->SetQuantType(schema::QuantType_WeightQuant);
+  }
+  return RET_OK;
+}
+
+STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
+  for (auto &node : nodes) {
+    if (!mStrategy->CanMulOpQuantized(node)) {
+      continue;
+    }
+    ParamValueLitePtr param_value = nullptr;
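+    // Search the node inputs for a 4-D float32 weight parameter with data; the node is rejected if none is found.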
+    for (size_t i = 1; i < node->size(); i++) {
+      auto inputNode = node->input(i);
+      if (inputNode->isa<Parameter>() == true) {
+        auto paramNode = inputNode->cast<ParameterPtr>();
+        if ((paramNode != nullptr) && (paramNode->has_default() == true)) {
+          param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
+          if ((param_value == nullptr) || (param_value->tensor_size() == 0)
+              || (param_value->tensor_shape().size() != 4)
+              || (param_value->tensor_addr() == nullptr)
+              || (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) {
+            param_value = nullptr;
+            continue;
+          } else {
+            break;
+          }
+        }
+      }
+    }
+    if (param_value == nullptr) {
+      MS_LOG(ERROR) << "No valid input param node !";
+      return RET_ERROR;
+    }
+    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
+    if (primitive_c == nullptr) {
+      MS_LOG(ERROR) << "primitive_c is nullptr";
+      return RET_ERROR;
+    }
+    auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0, bitNum, true, false);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "QuantFilter failed : " << status;
+      return status;
+    }
+    param_value->set_tensor_type(kNumberTypeUInt8);
+    primitive_c->SetQuantType(schema::QuantType_WeightQuant);
+  }
+  return RET_OK;
+}
+
+STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
+  auto ret = RET_OK;
+  auto cnodes = funcGraph->GetOrderedCnodes();
+  ret = DoConvQuantize(cnodes);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "DoConvQuantize failed :" << ret;
+    return ret;
+  }
+  ret = DoMulQuantize(cnodes);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "DoMulQuantize failed :" << ret;
+    return ret;
+  }
+  return ret;
+}
+}  // namespace quant
+}  // namespace lite
+}  // namespace mindspore

@@ -0,0 +1,53 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef WEIGHT_QUANTIZER_H
+#define WEIGHT_QUANTIZER_H
+
+#include <memory>
+#include <list>
+#include <string>
+#include "tools/converter/quantizer/quantizer.h"
+#include "tools/converter/quantizer/quantize_util.h"
+#include "ir/func_graph.h"
+#include "ir/anf.h"
+#include "include/model.h"
+#include "base/base.h"
+#include "abstract/dshape.h"
+
+namespace mindspore {
+namespace lite {
+namespace quant {
+class WeightQuantizer : public Quantizer {
+ public:
+  WeightQuantizer(FuncGraphPtr graph, const std::string &weightSize,
+                  const std::string &convWeightChannelThreshold, const std::string &bitNum);
+  ~WeightQuantizer() = default;
+
+  STATUS DoQuantize(FuncGraphPtr funcGraph) override;
+  STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
+  STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
+
+ private:
+  std::unique_ptr<QuantStrategy> mStrategy;
+  size_t bitNum;
+};
+}  // namespace quant
+}  // namespace lite
+}  // namespace mindspore
+#endif