@@ -103,8 +103,9 @@ int ModelImpl::BuildOps() {
    auto cNode = meta_graph_->nodes()->GetAs<schema::CNode>(i);
    auto name = cNode->name()->str();
    auto srcPrim = cNode->primitive();
    this->ops_[name] = PrimitiveC::UnPackFromSchemaPrimitive(const_cast<schema::Primitive *>(srcPrim));
    auto prim = PrimitiveC::UnPackFromSchemaPrimitive(const_cast<schema::Primitive *>(srcPrim));
    prim->SetQuantType(cNode->quantType());
    this->ops_[name] = prim;
  }
  return 0;
}
@@ -688,6 +688,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(const schema::Primitive *primi
  }
  return nullptr;
}
void PrimitiveC::SetQuantType(schema::QuantType quant_type) {
  this->quant_type_ = quant_type;
}
schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; }
#endif
int PrimitiveC::Type() const {
@@ -145,6 +145,9 @@ class PrimitiveC {
  int Type() const;
  void SetQuantType(schema::QuantType quant_type);
  schema::QuantType GetQuantType() const;
 protected:
  template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
  static PrimitiveC *NewPrimitiveC(const schema::Primitive *primitive) {
@@ -194,6 +197,7 @@ class PrimitiveC {
  const schema::Primitive *primitive_ = nullptr;
  char *primitive_buf_ = nullptr;
  bool infer_flag_ = true;
  schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
};
#endif
}  // namespace lite
@@ -331,4 +331,46 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
  return RET_OK;
}
int ConvolutionBaseCPUKernel::RestoreFilter(lite::tensor::Tensor *input_tensor) {
  MS_ASSERT(input_tensor != nullptr);
  if (input_tensor->GetQuantParams().empty()) {
    MS_LOG(ERROR) << "no quant param";
    return RET_ERROR;
  }
  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
  if (dequant_data == nullptr) {
    MS_LOG(ERROR) << "malloc failed";
    return RET_ERROR;
  }
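  // Dequantize into the float buffer: x_float = (q - zero_point) * scale.
  // More than kPerTensor quant params means per-channel quantization, one param per
  // output channel (dim 0 / Batch()); otherwise a single param covers the whole tensor.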
  if (input_tensor->GetQuantParams().size() != kPerTensor) {
    size_t channels = static_cast<size_t>(input_tensor->Batch());
    if (input_tensor->GetQuantParams().size() != channels) {
      MS_LOG(ERROR) << "Quant param count " << input_tensor->GetQuantParams().size()
                    << " does not equal channel count " << channels;
      free(dequant_data);
      return RET_ERROR;
    }
    size_t per_channel_size = input_tensor->DataSize() / channels;
    auto quant_param = input_tensor->GetQuantParams();
    for (size_t i = 0; i < channels; i++) {
      auto param = quant_param.at(i);
      auto scale = param.scale;
      auto zero_point = param.zeroPoint;
      for (size_t j = 0; j < per_channel_size; j++) {
        dequant_data[per_channel_size * i + j] =
          static_cast<float>((quant_data[per_channel_size * i + j] - zero_point) * scale);
      }
    }
  } else {
    auto quant_param = input_tensor->GetQuantParams();
    auto param = quant_param.front();
    auto scale = param.scale;
    auto zero_point = param.zeroPoint;
    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
    }
  }
  input_tensor->SetData(dequant_data);
  return RET_OK;
}
}  // namespace mindspore::kernel
@@ -60,6 +60,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
  int SetQuantMultiplier();
  int CheckResizeValid();
  void FreeQuantParam();
  static int RestoreFilter(lite::tensor::Tensor *input_tensor);
 protected:
  int tile_num_;
@@ -239,6 +239,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
    CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
  }
  auto *weight_tensor = inputs.at(kWeightIndex);
  auto *restore_data = weight_tensor->Data();
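  // For weight-quantized models, dequantize the filter before kernel creation but keep a
  // pointer to the original quantized buffer; it is swapped back in below, after the kernel
  // has consumed (packed) the dequantized copy, so the model buffer is left untouched.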
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
  }
  kernel::LiteKernel *kernel;
  if (kernel_h == 1 && kernel_w == 1) {
    kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
@@ -263,6 +269,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
                << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
    return nullptr;
  }
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    weight_tensor->FreeData();
    weight_tensor->SetData(restore_data);
  }
  return kernel;
}
@@ -131,6 +131,13 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::T
                                               const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(opParameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
  auto *weight_tensor = inputs.at(kWeightIndex);
  auto *restore_data = weight_tensor->Data();
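  // Same dequantize-then-restore sequence as CpuConvFp32KernelCreator: the kernel packs
  // the dequantized weights, then the original quantized buffer is put back.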
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
  }
  auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
  kernel::LiteKernel *kernel;
  if (conv_param->input_channel_ < 32) {
@@ -149,6 +156,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::T
                << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    return nullptr;
  }
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    weight_tensor->FreeData();
    weight_tensor->SetData(restore_data);
  }
  return kernel;
}
@@ -64,7 +64,8 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
  MS_ASSERT(dst_node != nullptr);
  // add quant param
  dst_node->quantType = primitive->GetQuantType();
  if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining) {
  if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining
      || dst_node->quantType == schema::QuantType_WeightQuant) {
    MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
    // activation
    auto input_quant_params = primitive->GetInputQuantParams();
@@ -103,7 +104,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
      }
    } else {
      for (auto output_quant_param : output_quant_params[0]) {
        if (tensor_output->quantParams.empty()) {
        if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
          std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
            std::make_unique<schema::QuantParamT>(output_quant_param);
          MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
@@ -26,6 +26,7 @@
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
using std::string;
namespace mindspore {
@@ -57,11 +58,20 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
  FuncGraphPtr new_graph = optimizer->Optimize(old_graph);
  // quant
  if (config != nullptr && config->quantType == schema::QuantType_PostTraining) {
    this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
    if (mQuantizer == nullptr) {
      MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
      return nullptr;
  if (config != nullptr) {
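    // The two converter quantization modes are mutually exclusive: either full
    // post-training quantization or weight-only quantization, selected by quantType.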
    if (config->quantType == schema::QuantType_PostTraining) {
      this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
      if (mQuantizer == nullptr) {
        MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
        return nullptr;
      }
    } else if (config->quantType == schema::QuantType_WeightQuant) {
      this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantSize,
                                                                  config->convWeightQuantChannelThreshold, config->bitNum);
      if (mQuantizer == nullptr) {
        MS_LOG(ERROR) << "New WeightQuantizer failed";
        return nullptr;
      }
    }
  }
  if (mQuantizer != nullptr) {
@@ -71,12 +81,14 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
      MS_LOG(ERROR) << "Quant failed " << status;
      return nullptr;
    }
    quant::QuantCast quant_cast;
    quant_cast.SetInputDataDType(kNumberTypeFloat32);
    status = quant_cast.Run(new_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "add QuantCast error";
      return nullptr;
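    // QuantCast insertion is only needed for post-training (full int8) graphs;
    // weight-quantized models still execute in float32 at runtime (the weights are
    // dequantized on load), so no cast nodes are required for them.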
    if (config->quantType == schema::QuantType_PostTraining) {
      quant::QuantCast quant_cast;
      quant_cast.SetInputDataDType(kNumberTypeFloat32);
      status = quant_cast.Run(new_graph);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "add QuantCast error";
        return nullptr;
      }
    }
  }
@@ -36,6 +36,8 @@ Flags::Flags() {
  AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
  AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
  AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0");
  AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold",
          "Channel threshold for conv weight quantization", "16");
  AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
  AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
}
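// A hypothetical converter invocation exercising the new weight-quantization flags
// (flag names other than quantSize and convWeightQuantChannelThreshold are assumed,
// not confirmed by this change):
//   ./converter_lite --fmk=TFLITE --modelFile=model.tflite --outputFile=model \
//     --quantType=WeightQuant --quantSize=500 --convWeightQuantChannelThreshold=16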
@@ -191,6 +191,7 @@ STATUS WeightFormatHardCodePass::HardCodeTFLITE(const std::unique_ptr<CNodeT> &n
  switch (this->quantType) {
    case QuantType_AwareTraining:
    case QuantType_PostTraining:
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE: {
      if (opType == schema::PrimitiveType_Conv2D) {
        weightTensor->format = schema::Format_KHWC;
@@ -31,7 +31,7 @@ void WeightFormatTransformPass::SetDstFormat(Format format) { this->dstFormat =
STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  if (this->quantType == QuantType_AwareTraining) {
  if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_WeightQuant) {
    auto status = QuantDataFormatTrans(graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;
@@ -11,6 +11,7 @@ add_library(quantizer_mid OBJECT
    ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
)
if(ENABLE_ASAN)
@@ -530,7 +530,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
    return RET_ERROR;
  }
  auto status =
    QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num, perchanel, depthwise);
    QuantFilter<int8_t>(paramValue, primitive_c, QuantType_PostTraining, quant_max,
                        quant_min, bit_num, perchanel, depthwise);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "QuantFilter failed: " << status;
    return status;
@@ -279,171 +279,6 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
  return RET_OK;
}
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
                   int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
  auto dims = weight->tensor_shape();
  if (per_channel) {
    if (dims.size() != 4) {
      MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
      per_channel = false;
    } else {
      uint32_t channels = dims[0];
      if (channels == 0) {
        MS_LOG(ERROR) << "channels is 0";
        return RET_ERROR;
      }
    }
  }
  vector<schema::QuantParamT> quant_params;
  size_t elem_count = weight->tensor_shape_size();
  auto *raw_datas = static_cast<float *>(weight->tensor_addr());
  if (raw_datas == nullptr) {
    MS_LOG(ERROR) << "rawDatas is nullptr";
    return RET_ERROR;
  }
  vector<int8_t> quant_datas(elem_count);
  if (per_channel) {
    // notice:
    // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
    // if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
    if (depth_wise) {
      // channel at last
      auto channels = dims[3];
      if (channels == 0) {
        MS_LOG(ERROR) << "channels is zero";
        return RET_ERROR;
      }
      size_t one_filter_size = elem_count / channels;
      for (int i = 0; i < channels; i++) {
        float min = FLT_MAX;
        float max = -FLT_MAX;
        // find min and max
        for (size_t j = 0; j < one_filter_size; j++) {
          auto index = i + j * channels;
          if (index >= elem_count) {
            MS_LOG(ERROR) << "over flow!";
            return RET_ERROR;
          }
          min = std::min(min, raw_datas[index]);
          max = std::max(max, raw_datas[index]);
        }
        schema::QuantParamT quant_param;
        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
          return status;
        }
        quant_params.emplace_back(quant_param);
        // do quantization
        for (uint32_t j = 0; j < one_filter_size; j++) {
          auto index = i + j * channels;
          if (index >= elem_count) {
            MS_LOG(ERROR) << "over flow!";
            return RET_ERROR;
          }
          float raw_data = raw_datas[index];
          auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
          quant_datas[index] = quant_data;
        }
      }
      auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(),
                          elem_count * sizeof(int8_t));
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy error: " << ret;
        return RET_ERROR;
      }
      weight->set_tensor_size(elem_count * sizeof(int8_t));
    } else {
      // channel at first
      auto channels = dims[0];
      if (channels == 0) {
        MS_LOG(ERROR) << "channels is zero";
        return RET_ERROR;
      }
      size_t one_filter_size = elem_count / channels;
      for (int i = 0; i < channels; i++) {
        float min = FLT_MAX;
        float max = -FLT_MAX;
        // find min and max
        for (size_t j = 0; j < one_filter_size; j++) {
          auto index = j + i * one_filter_size;
          if (index >= elem_count) {
            MS_LOG(ERROR) << "over flow!";
            return RET_ERROR;
          }
          min = std::min(min, raw_datas[index]);
          max = std::max(max, raw_datas[index]);
        }
        schema::QuantParamT quant_param;
        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
          return status;
        }
        quant_params.emplace_back(quant_param);
        // do quantization
        for (uint32_t j = 0; j < one_filter_size; j++) {
          auto index = j + i * one_filter_size;
          if (index >= elem_count) {
            MS_LOG(ERROR) << "over flow!";
            return RET_ERROR;
          }
          float raw_data = raw_datas[index];
          auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
          quant_datas[index] = quant_data;
        }
      }
      auto ret =
        memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy error: " << ret;
        return RET_ERROR;
      }
      weight->set_tensor_size(elem_count * sizeof(int8_t));
    }
  } else {
    // per layer
    float min = FLT_MAX;
    float max = -FLT_MIN;
    for (uint32_t i = 0; i < elem_count; i++) {
      // find max min
      min = std::min(min, raw_datas[i]);
      max = std::max(max, raw_datas[i]);
    }
    schema::QuantParamT quant_param;
    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
      return status;
    }
    quant_params.emplace_back(quant_param);
    // update data and datatype
    for (uint32_t i = 0; i < elem_count; i++) {
      float raw_data = raw_datas[i];
      auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
      quant_datas[i] = quant_data;
    }
    auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy error: " << ret;
      return RET_ERROR;
    }
    weight->set_tensor_size(elem_count * sizeof(int8_t));
  }
  if (quant_params.empty()) {
    MS_LOG(ERROR) << "quant_params empty";
    return RET_ERROR;
  }
  primitive_c->AddInputQuantParam(quant_params);
  return RET_OK;
}
STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) {
  auto *rawDatas = reinterpret_cast<uint8_t *>(weight);
  vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize);
@@ -21,6 +21,8 @@
#include <string>
#include <cmath>
#include <array>
#include <vector>
#include <algorithm>
#include "tools/converter/quantizer/quantizer.h"
#include "src/ops/primitive_c.h"
#include "include/errorcode.h"
@@ -117,10 +119,171 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
    return static_cast<T>(quant_data);
  }();
}
template <typename T>
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
                   int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false,
                   bool depth_wise = false);
                   int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
  auto dims = weight->tensor_shape();
  if (per_channel) {
    if (dims.size() != 4) {
      MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
      per_channel = false;
    } else {
      uint32_t channels = dims[0];
      if (channels == 0) {
        MS_LOG(ERROR) << "channels is 0";
        return RET_ERROR;
      }
    }
  }
  std::vector<schema::QuantParamT> quant_params;
  size_t elem_count = weight->tensor_shape_size();
  auto *raw_datas = static_cast<float *>(weight->tensor_addr());
  if (raw_datas == nullptr) {
    MS_LOG(ERROR) << "rawDatas is nullptr";
    return RET_ERROR;
  }
  std::vector<T> quant_datas(elem_count);
  if (per_channel) {
    // Note:
    // for tflite models, Conv2D's weight format is KHWC, and so is DepthwiseConv2D's;
    // if TransWeightFormat runs before PostTrainingQuantization, the DepthwiseConv2D weight is CHWK
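    // Hence the two index layouts below: depth-wise puts channels on the last axis
    // (index = i + j * channels), while regular conv puts them on the first axis
    // (index = j + i * one_filter_size).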
    if (depth_wise) {
      // channel at last
      auto channels = dims[3];
      if (channels == 0) {
        MS_LOG(ERROR) << "channels is zero";
        return RET_ERROR;
      }
      size_t one_filter_size = elem_count / channels;
      for (int i = 0; i < channels; i++) {
        float min = FLT_MAX;
        float max = -FLT_MAX;
        // find min and max
        for (size_t j = 0; j < one_filter_size; j++) {
          auto index = i + j * channels;
          if (index >= elem_count) {
            MS_LOG(ERROR) << "over flow!";
            return RET_ERROR;
          }
          min = std::min(min, raw_datas[index]);
          max = std::max(max, raw_datas[index]);
        }
        schema::QuantParamT quant_param;
        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
          return status;
        }
        quant_params.emplace_back(quant_param);
        // do quantization
        for (uint32_t j = 0; j < one_filter_size; j++) {
          auto index = i + j * channels;
          if (index >= elem_count) {
            MS_LOG(ERROR) << "over flow!";
            return RET_ERROR;
          }
          float raw_data = raw_datas[index];
          auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
          quant_datas[index] = quant_data;
        }
      }
      auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(),
                          elem_count * sizeof(T));
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy error: " << ret;
        return RET_ERROR;
      }
      weight->set_tensor_size(elem_count * sizeof(T));
    } else {
      // channel at first
      auto channels = dims[0];
      if (channels == 0) {
        MS_LOG(ERROR) << "channels is zero";
        return RET_ERROR;
      }
      size_t one_filter_size = elem_count / channels;
      for (int i = 0; i < channels; i++) {
        float min = FLT_MAX;
        float max = -FLT_MAX;
        // find min and max
        for (size_t j = 0; j < one_filter_size; j++) {
          auto index = j + i * one_filter_size;
          if (index >= elem_count) {
            MS_LOG(ERROR) << "over flow!";
            return RET_ERROR;
          }
          min = std::min(min, raw_datas[index]);
          max = std::max(max, raw_datas[index]);
        }
        schema::QuantParamT quant_param;
        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
          return status;
        }
        quant_params.emplace_back(quant_param);
        // do quantization
        for (uint32_t j = 0; j < one_filter_size; j++) {
          auto index = j + i * one_filter_size;
          if (index >= elem_count) {
            MS_LOG(ERROR) << "over flow!";
            return RET_ERROR;
          }
          float raw_data = raw_datas[index];
          auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
          quant_datas[index] = quant_data;
        }
      }
      auto ret =
        memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy error: " << ret;
        return RET_ERROR;
      }
      weight->set_tensor_size(elem_count * sizeof(T));
    }
  } else {
    // per layer
    float min = FLT_MAX;
    float max = -FLT_MAX;
    for (uint32_t i = 0; i < elem_count; i++) {
      // find max min
      min = std::min(min, raw_datas[i]);
      max = std::max(max, raw_datas[i]);
    }
    schema::QuantParamT quant_param;
    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
      return status;
    }
    quant_params.emplace_back(quant_param);
    // update data and datatype
    for (uint32_t i = 0; i < elem_count; i++) {
      float raw_data = raw_datas[i];
      auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
      quant_datas[i] = quant_data;
    }
    auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy error: " << ret;
      return RET_ERROR;
    }
    weight->set_tensor_size(elem_count * sizeof(T));
  }
  if (quant_params.empty()) {
    MS_LOG(ERROR) << "quant_params empty";
    return RET_ERROR;
  }
  primitive_c->AddInputQuantParam(quant_params);
  return RET_OK;
}
STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);
}  // namespace quant
@@ -0,0 +1,148 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/quantizer/weight_quantizer.h"
#include <list>
#include <string>
#include <vector>
#include "src/common/common.h"
#include "ir/dtype/type_id.h"
using std::string;
using std::vector;
namespace mindspore {
namespace lite {
namespace quant {
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
                                 const std::string &convWeightChannelThreshold, const std::string &bitNum)
    : Quantizer(graph) {
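  // The threshold flags arrive from the converter CLI as strings; parse them to sizes here.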
  auto quantSize = static_cast<size_t>(std::stoull(weightSize));
  this->bitNum = static_cast<size_t>(std::stoull(bitNum));
  auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
  mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold));
}
STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
  for (auto &cnode : nodes) {
    if (!mStrategy->CanConvOpQuantized(cnode)) {
      continue;
    }
    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
    if (primitive_c == nullptr) {
      MS_LOG(ERROR) << "primitive_c is nullptr";
      return RET_ERROR;
    }
    auto inputNode = cnode->input(2);
    if (!inputNode->isa<Parameter>()) {
      return RET_ERROR;
    }
    auto paramNode = inputNode->cast<ParameterPtr>();
    if (!paramNode->has_default()) {
      return RET_ERROR;
    }
    std::vector<schema::QuantParamT> quant_params;
    primitive_c->AddInputQuantParam(quant_params);
    auto op_type = (schema::PrimitiveType)primitive_c->Type();
    bool depthwise = (op_type == schema::PrimitiveType_DepthwiseConv2D);
    ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
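    // Conv weights are quantized to uint8: quant_max = 255, quant_min = 0, per-channel
    // (the `true` argument), with the depthwise weight layout handled separately.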
    auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0,
                                       bitNum, true, depthwise);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "QuantFilter failed : " << status;
      return status;
    }
    param_value->set_tensor_type(kNumberTypeUInt8);
    primitive_c->SetQuantType(schema::QuantType_WeightQuant);
  }
  return RET_OK;
}
STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
  for (auto &node : nodes) {
    if (!mStrategy->CanMulOpQuantized(node)) {
      continue;
    }
    ParamValueLitePtr param_value = nullptr;
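    // Scan the node's inputs for the first Parameter holding a 4-D float32 tensor
    // with real data; that parameter is taken as the weight to quantize.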
    for (size_t i = 1; i < node->size(); i++) {
      auto inputNode = node->input(i);
      if (inputNode->isa<Parameter>()) {
        auto paramNode = inputNode->cast<ParameterPtr>();
        if ((paramNode != nullptr) && paramNode->has_default()) {
          param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
          if ((param_value == nullptr) || (param_value->tensor_size() == 0)
              || (param_value->tensor_shape().size() != 4)
              || (param_value->tensor_addr() == nullptr)
              || (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) {
            param_value = nullptr;
            continue;
          } else {
            break;
          }
        }
      }
    }
    if (param_value == nullptr) {
      MS_LOG(ERROR) << "No valid input param node!";
      return RET_ERROR;
    }
    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
    if (primitive_c == nullptr) {
      MS_LOG(ERROR) << "primitive_c is nullptr";
      return RET_ERROR;
    }
    auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0, bitNum, true, false);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "QuantFilter failed : " << status;
      return status;
    }
    param_value->set_tensor_type(kNumberTypeUInt8);
    primitive_c->SetQuantType(schema::QuantType_WeightQuant);
  }
  return RET_OK;
}
STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
  auto ret = RET_OK;
  auto cnodes = funcGraph->GetOrderedCnodes();
  ret = DoConvQuantize(cnodes);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DoConvQuantize failed :" << ret;
    return ret;
  }
  ret = DoMulQuantize(cnodes);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DoMulQuantize failed :" << ret;
    return ret;
  }
  return ret;
}
}  // namespace quant
}  // namespace lite
}  // namespace mindspore
@@ -0,0 +1,53 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef WEIGHT_QUANTIZER_H
#define WEIGHT_QUANTIZER_H
#include <memory>
#include <list>
#include <string>
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "include/model.h"
#include "base/base.h"
#include "abstract/dshape.h"
namespace mindspore {
namespace lite {
namespace quant {
class WeightQuantizer : public Quantizer {
 public:
  WeightQuantizer(FuncGraphPtr graph, const std::string &weightSize,
                  const std::string &convWeightChannelThreshold, const std::string &bitNum);
  ~WeightQuantizer() = default;
  STATUS DoQuantize(FuncGraphPtr funcGraph) override;
  STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
  STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
 private:
  std::unique_ptr<QuantStrategy> mStrategy;
  size_t bitNum;
};
}  // namespace quant
}  // namespace lite
}  // namespace mindspore
#endif  // WEIGHT_QUANTIZER_H