From: @jianghui58 (tag: v1.2.0-rc1)
@@ -14,7 +14,10 @@
  * limitations under the License.
  */
 #include <cmath>
+#include <string>
+#include <memory>
 #include "src/dequant.h"
+#include "src/huffman_decode.h"

 namespace mindspore::lite {
 float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
@@ -34,13 +37,24 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
   }
 }

-void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
+int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
   MS_ASSERT(input_tensor != nullptr);
   MS_ASSERT(unpack_int_data != nullptr);
   auto quant_params = input_tensor->quantParams();
   if (quant_params == nullptr) {
     MS_LOG(ERROR) << "low bits quantparams is empty.";
-    return;
+    return RET_ERROR;
+  }
+  auto enable_huffman_code = input_tensor->enableHuffmanCode();
+  if (enable_huffman_code) {
+    std::string encode_str(input_tensor->data()->begin(), input_tensor->data()->end());
+    auto huffman_decode = std::make_unique<lite::HuffmanDecode>();
+    auto ret = huffman_decode->DoHuffmanDecode(encode_str, unpack_int_data);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "DoHuffmanDecode failed.";
+      return ret;
+    }
+    return RET_OK;
   }
   int origin_bit = quant_params->Get(0)->numBits();
   if (origin_bit < 8 && origin_bit > 0) {
@@ -48,6 +62,7 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
   } else if (origin_bit < 16 && origin_bit > 8) {
     UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data);
   }
+  return RET_OK;
 }

 std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
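With `UnPackToInt` now returning a status instead of `void`, every call site has to propagate failures rather than continue silently. A minimal caller sketch, mirroring the `lite_session.cc` change further down in this diff (the wrapper name `UnpackWeight` is hypothetical, not part of the change):

```cpp
// Hedged sketch of the new calling convention for DequantUtil::UnPackToInt.
int UnpackWeight(const schema::Tensor *src_tensor, void *dst_data) {
  auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "unpack to int failed.";
    return ret;  // propagate the status instead of swallowing it
  }
  return RET_OK;
}
```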
@@ -31,7 +31,7 @@ class DequantUtil {
  public:
   static float *DequantWeight(lite::Tensor *input_tensor);

-  static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
+  static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);

   static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
                                                                      TypeId data_type, bool need_restore = true);
@@ -110,6 +110,21 @@ class DequantUtil {
     return dequant_datas;
   }

+  template <typename T1, typename T2>
+  static void UnpackUtil(const T1 *weight_data, int pack_size, int origin_bit, void *unpack_int_data) {
+    if (weight_data == nullptr || unpack_int_data == nullptr) {
+      MS_LOG(ERROR) << "data is nullptr";
+      return;
+    }
+    std::queue<bool> unpack_bit_data;
+    size_t count = 0;
+    for (int i = 0; i < pack_size; ++i) {
+      T2 pack_data = (static_cast<const T2 *>(static_cast<const void *>(weight_data)))[i];
+      bool is_last = i == pack_size - 1;
+      UnPackData<T1, T2>(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last);
+    }
+  }
+
  private:
   template <typename T1, typename T2>
   static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int,
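To make the bit-queue approach concrete, here is a standalone sketch of the idea that `UnpackUtil`/`UnPackData` appear to implement: stream each packed word's bits into a queue, then pop fixed-width groups off the front and sign-extend them. The helper name and the MSB-first bit order are assumptions for illustration, not the exact MindSpore implementation:

```cpp
#include <cstdint>
#include <iostream>
#include <queue>
#include <vector>

// Push each byte's bits (MSB first) onto a queue, then pop origin_bit-wide
// groups and sign-extend them into int8 values.
std::vector<int8_t> UnpackLowBit(const std::vector<uint8_t> &packed, int origin_bit, size_t value_count) {
  std::queue<bool> bits;
  for (uint8_t byte : packed) {
    for (int i = 7; i >= 0; --i) {
      bits.push(((byte >> i) & 1) != 0);
    }
  }
  std::vector<int8_t> values;
  while (values.size() < value_count && bits.size() >= static_cast<size_t>(origin_bit)) {
    int v = 0;
    for (int i = 0; i < origin_bit; ++i) {
      v = (v << 1) | (bits.front() ? 1 : 0);
      bits.pop();
    }
    if (v >= (1 << (origin_bit - 1))) {
      v -= (1 << origin_bit);  // sign-extend the origin_bit-wide value
    }
    values.push_back(static_cast<int8_t>(v));
  }
  return values;
}

int main() {
  // Two 6-bit values, 5 and -5, packed MSB-first into 12 of 16 bits.
  std::vector<uint8_t> packed = {0b00010111, 0b10110000};
  for (int8_t v : UnpackLowBit(packed, 6, 2)) {
    std::cout << static_cast<int>(v) << "\n";  // prints 5 then -5
  }
  return 0;
}
```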
@@ -19,7 +19,7 @@
 namespace mindspore {
 namespace lite {

-STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
+STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
   if (decoded_data == nullptr) {
     MS_LOG(ERROR) << "decoded_data is nullptr.";
     return RET_ERROR;
@@ -64,7 +64,7 @@ STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decod
   return RET_OK;
 }

-STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
+STATUS HuffmanDecode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
   HuffmanNodePtr cur_node, tmp_node, new_node;

   auto huffman_keys = Str2Vec(std::move(keys));
@@ -121,7 +121,7 @@ STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, c
   return RET_OK;
 }

-STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
+STATUS HuffmanDecode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
   HuffmanNodePtr cur_node = root;
   bool pseudo_eof = false;
   size_t pos = 0;
@@ -157,7 +157,7 @@ STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string enco
   return RET_OK;
 }

-huffman_decode::~huffman_decode() {
+HuffmanDecode::~HuffmanDecode() {
   for (auto &node : this->huffman_nodes_) {
     delete node;
   }
@@ -38,11 +38,11 @@ struct HuffmanNode {
 };
 using HuffmanNodePtr = HuffmanNode *;

-class huffman_decode {
+class HuffmanDecode {
  public:
-  huffman_decode() = default;
+  HuffmanDecode() = default;

-  ~huffman_decode();
+  ~HuffmanDecode();

   STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data);
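Beyond the `huffman_decode` -> `HuffmanDecode` rename, the decode path itself is unchanged. For readers new to it, here is a hedged, self-contained sketch of the bit-walk `DoHuffmanDecompress` performs: follow '0' to the left child and '1' to the right, emit the symbol at each leaf, and stop at the pseudo-EOF marker. The node layout is simplified for the demo and is not the library's `HuffmanNode`:

```cpp
#include <iostream>
#include <string>

// Simplified node: a leaf carries a symbol or marks the pseudo-EOF.
struct Node {
  int symbol = 0;
  bool is_pseudo_eof = false;
  Node *left = nullptr;
  Node *right = nullptr;
};

std::string Decompress(const Node *root, const std::string &bits) {
  std::string out;
  const Node *cur = root;
  for (char b : bits) {
    cur = (b == '0') ? cur->left : cur->right;
    if (cur->left == nullptr && cur->right == nullptr) {  // reached a leaf
      if (cur->is_pseudo_eof) break;                      // stop marker
      out.push_back(static_cast<char>(cur->symbol));
      cur = root;                                         // restart at the root
    }
  }
  return out;
}

int main() {
  // Tiny fixed tree: "0" -> 'a', "10" -> 'b', "11" -> pseudo-EOF.
  Node a{'a'}, b{'b'}, eof{0, true};
  Node inner{0, false, &b, &eof};
  Node root{0, false, &a, &inner};
  std::cout << Decompress(&root, "01011") << "\n";  // prints "ab"
  return 0;
}
```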
@@ -28,7 +28,6 @@
 #include "src/kernel_registry.h"
 #include "src/lite_model.h"
 #include "src/dequant.h"
-#include "src/huffman_decode.h"
 #if SUPPORT_NPU
 #include "src/runtime/agent/npu/npu_manager.h"
 #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h"
@@ -96,13 +95,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
     int org_size = dst_tensor->Size();
     return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16);
   };
-  auto NeedHuffmanDecode = [&src_tensor, &dst_tensor]() -> bool {
-    auto data_type = src_tensor->dataType();
-    auto enable_huffman_code = src_tensor->enableHuffmanCode();
-    int pack_size = src_tensor->data()->size();
-    int org_size = dst_tensor->Size();
-    return (pack_size != org_size) && (data_type == kNumberTypeInt8) && enable_huffman_code;
-  };
   auto src_category = TensorCategory(src_tensor);
   if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) &&
       src_tensor->data() != nullptr && src_tensor->data()->size() > 0) {
@@ -116,21 +108,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
         return RET_ERROR;
       }
     } else {
-      if (NeedHuffmanDecode()) {
-        auto dst_data = dst_tensor->MutableData();
-        if (dst_data == nullptr) {
-          MS_LOG(ERROR) << "Data from tensor is nullptr";
-          return RET_NULL_PTR;
-        }
-        std::string encode_str(src_tensor->data()->begin(), src_tensor->data()->end());
-        auto huffman_decode = std::make_unique<lite::huffman_decode>();
-        auto ret = huffman_decode->DoHuffmanDecode(encode_str, dst_data);
-        if (ret != RET_OK) {
-          MS_LOG(ERROR) << "DoHuffmanDecode failed.";
-          return ret;
-        }
-        copyed_tensor_idxes_.emplace_back(tensor_index);
-      }
       if (WeightTensorNeedCopy(model, tensor_index)) {
         auto dst_data = dst_tensor->MutableData();
         if (dst_data == nullptr) {
@@ -138,7 +115,11 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
           return RET_NULL_PTR;
         }
         if (NeedUnPack()) {
-          DequantUtil::UnPackToInt(src_tensor, dst_data);
+          auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
+          if (ret != RET_OK) {
+            MS_LOG(ERROR) << "unpack to int failed.";
+            return RET_NULL_PTR;
+          }
         } else {
           memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
         }
@@ -148,9 +129,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
         auto dst_data = dst_tensor->MutableData();
         if (dst_data == nullptr) {
           MS_LOG(ERROR) << "Data from tensor is nullptr";
-          return RET_NULL_PTR;
+          return RET_ERROR;
+        }
+        auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "unpack to int failed.";
+          return RET_ERROR;
         }
-        DequantUtil::UnPackToInt(src_tensor, dst_data);
         copyed_tensor_idxes_.emplace_back(tensor_index);
       } else {
         dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));
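The `NeedUnPack()` predicate keys off the serialized buffer being smaller than the dequantized tensor, which is exactly what bit-packing produces. A hedged back-of-the-envelope check of that size gap (plain arithmetic, no MindSpore APIs):

```cpp
#include <iostream>

int main() {
  // A tensor of elem_count low-bit values occupies ceil(elem_count * bit_num / 8)
  // bytes, so pack_size != org_size whenever bit_num < 8.
  const int elem_count = 1000;
  for (int bit_num : {4, 6, 7, 8}) {
    int packed = (elem_count * bit_num + 7) / 8;  // bytes after bit-packing
    std::cout << bit_num << "-bit: packed " << packed << " vs unpacked "
              << elem_count << " bytes\n";        // equal only at 8-bit
  }
  return 0;
}
```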
@@ -227,8 +227,8 @@ function Run_Converter() {
         fi
         model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'`
         echo ${model_name} >> "${run_converter_log_file}"
-        echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true' >> "${run_converter_log_file}"
-        ./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true
+        echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0' >> "${run_converter_log_file}"
+        ./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0
        if [ $? = 0 ]; then
            converter_result='converter weight_quant '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
        else
@@ -217,26 +217,14 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
                              const FuncGraphPtr &new_graph) {
   // quant
   if (config->quantType == schema::QuantType_PostTraining) {
-    if (!quant::WeightQuantizer::IsPosNum(config->bitNum)) {
-      MS_LOG(ERROR) << "bitNum must be valid pos num.";
-      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
-      return RET_ERROR;
-    }
-    this->mQuantizer =
-      std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, std::stoi(config->bitNum));
+    this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, config->bitNum);
     if (mQuantizer == nullptr) {
       MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
       return RET_ERROR;
     }
   } else if (config->quantType == schema::QuantType_WeightQuant) {
-    if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
-      MS_LOG(ERROR) << "weight quant input param error";
-      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
-      return RET_ERROR;
-    }
-    this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize,
-                                                                config->quantWeightChannel, config->bitNum);
+    this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, *config);
     if (mQuantizer == nullptr) {
       MS_LOG(ERROR) << "New WeightQuantizer failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
@@ -255,10 +243,15 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
   return RET_OK;
 }

-int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph) {
-  if (config->quantType == schema::QuantType_WeightQuant && config->bitNum == "8" && config->enableHuffmanCode) {
-    auto huffman_encode = std::make_unique<lite::huffman_encode>();
-    auto status = huffman_encode->DoHuffmanEncode(new_graph);
+int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph,
+                                  bool enableHuffmanCode) {
+  if (config->quantType == schema::QuantType_WeightQuant && enableHuffmanCode) {
+    if (config->bitNum < 16 && config->bitNum > 8) {
+      MS_LOG(WARNING) << "don't support huffman encode when 8 < bitNum < 16 currently.";
+      return RET_OK;
+    }
+    auto huffman_encode = std::make_unique<lite::HuffmanEncode>();
+    auto status = huffman_encode->DoHuffmanEncode(new_graph, config->bitNum);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "Huffman encode failed.";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -322,7 +315,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
     return nullptr;
   }

-  status = DoHuffmanEncode(config, new_graph);
+  status = DoHuffmanEncode(config, new_graph, false);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Do HuffmanCode failed.";
     return nullptr;
@@ -59,7 +59,7 @@ class AnfTransform {
   int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph);

-  int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph);
+  int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, bool enableHuffmanCode);
 };
 }  // namespace lite
 }  // namespace mindspore
@@ -38,16 +38,12 @@ Flags::Flags() {
           "UINT8 | DEFAULT",
           "DEFAULT");
   AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
-  AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8");
-  AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0");
-  AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16");
+  AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8");
+  AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0");
+  AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16");
   AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", "");
-  AddFlag(&Flags::enableHuffmanCodeIn, "enableHuffmanCode",
-          "whether the weight quant model is going to use huffman code."
-          "true | false",
-          "false");
   AddFlag(&Flags::trainModelIn, "trainModel",
-          "whether the model is going to be trained on device."
+          "whether the model is going to be trained on device. "
           "true | false",
           "false");
 }
@@ -107,7 +103,41 @@ int Flags::InitFmk() {
   return RET_OK;
 }

-int Flags::InitQuantType() {
+bool Flags::IsValidNum(const std::string &str, int *num) {
+  char *ptr;
+  *num = strtol(str.c_str(), &ptr, 10);
+  return ptr == (str.c_str() + str.size());
+}
+
+int Flags::QuantParamInputCheck() {
+  if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) {
+    std::cerr << "quantWeightChannel should be a valid number.";
+    return RET_INPUT_PARAM_INVALID;
+  }
+  if (this->quantWeightChannel < 0) {
+    std::cerr << "quantWeightChannel should be greater than or equal to zero.";
+    return RET_INPUT_PARAM_INVALID;
+  }
+  if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) {
+    std::cerr << "quantWeightSize should be a valid number.";
+    return RET_INPUT_PARAM_INVALID;
+  }
+  if (this->quantWeightSize < 0) {
+    std::cerr << "quantWeightSize should be greater than or equal to zero.";
+    return RET_INPUT_PARAM_INVALID;
+  }
+  if (!Flags::IsValidNum(this->bitNumIn, &this->bitNum)) {
+    std::cerr << "bitNum should be a valid number.";
+    return RET_INPUT_PARAM_INVALID;
+  }
+  if (this->bitNum <= 0 || this->bitNum > 16) {
+    std::cerr << "bitNum should be greater than zero and lesser than 16 currently.";
+    return RET_INPUT_PARAM_INVALID;
+  }
+  return RET_OK;
+}
+
+int Flags::InitQuantParam() {
   if (this->quantTypeIn == "WeightQuant") {
     this->quantType = QuantType_WeightQuant;
   } else if (this->quantTypeIn == "PostTraining") {
@@ -118,19 +148,9 @@ int Flags::InitQuantType() {
     std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining";
     return RET_INPUT_PARAM_INVALID;
   }
-  return RET_OK;
-}
-
-int Flags::InitHuffmanCode() {
-  if (this->enableHuffmanCodeIn == "true") {
-    this->enableHuffmanCode = true;
-  } else if (this->enableHuffmanCodeIn == "false") {
-    this->enableHuffmanCode = false;
-  } else {
-    std::cerr << "INPUT ILLEGAL: trainModel must be true|false ";
-    return RET_INPUT_PARAM_INVALID;
-  }
-  return RET_OK;
+  auto ret = QuantParamInputCheck();
+  return ret;
 }

 int Flags::InitTrainModel() {
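The `strtol`-with-end-pointer idiom in `IsValidNum` is what lets the converter reject inputs like `--bitNum=8x` that the previous `std::stoi` path would have half-parsed. A hedged standalone check of that behavior (note that the committed check also accepts the empty string, since the end pointer then equals the start; the later range checks are what catch a resulting 0 for `bitNum`):

```cpp
#include <cstdlib>
#include <iostream>
#include <string>

bool IsValidNum(const std::string &str, int *num) {
  char *ptr = nullptr;
  *num = static_cast<int>(std::strtol(str.c_str(), &ptr, 10));
  return ptr == str.c_str() + str.size();  // parsing must consume the whole string
}

int main() {
  int n = 0;
  std::cout << IsValidNum("16", &n) << " " << n << "\n";     // 1 16
  std::cout << IsValidNum("16abc", &n) << " " << n << "\n";  // 0 16 (trailing junk rejected)
  std::cout << IsValidNum("", &n) << " " << n << "\n";       // 1 0  (empty string passes the pointer check)
  return 0;
}
```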
@@ -218,15 +238,9 @@ int Flags::Init(int argc, const char **argv) {
     return RET_INPUT_PARAM_INVALID;
   }

-  ret = InitQuantType();
-  if (ret != RET_OK) {
-    std::cerr << "Init quant type failed.";
-    return RET_INPUT_PARAM_INVALID;
-  }
-  ret = InitHuffmanCode();
+  ret = InitQuantParam();
   if (ret != RET_OK) {
-    std::cerr << "Init huffman code failed.";
+    std::cerr << "Init quant param failed.";
     return RET_INPUT_PARAM_INVALID;
   }
@@ -49,9 +49,11 @@ class Flags : public virtual mindspore::lite::FlagParser {
   int InitFmk();

-  int InitQuantType();
+  bool IsValidNum(const std::string &str, int *num);

-  int InitHuffmanCode();
+  int QuantParamInputCheck();
+
+  int InitQuantParam();

   int InitTrainModel();
@@ -76,12 +78,13 @@ class Flags : public virtual mindspore::lite::FlagParser {
   TypeId inputDataType;
   TypeId outputDataType;
   // used for post-trainning-weight
-  std::string quantWeightSize;
-  std::string bitNum;
+  std::string quantWeightSizeIn;
+  int quantWeightSize;
+  std::string bitNumIn;
+  int bitNum;
   std::string configFile;
-  std::string quantWeightChannel;
-  std::string enableHuffmanCodeIn;
-  bool enableHuffmanCode = false;
+  std::string quantWeightChannelIn;
+  int quantWeightChannel;
   std::string trainModelIn;
   bool trainModel = false;
 };
@@ -18,18 +18,51 @@
 #include <utility>
 #include <iostream>
+#include <memory>
+#include <vector>
+#include "securec/include/securec.h"
+#include "src/param_value_lite.h"
+#include "src/dequant.h"

 namespace mindspore {
 namespace lite {

-STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
+STATUS HuffmanEncode::GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value) {
+  if (!input_node->isa<Parameter>()) {
+    return RET_CONTINUE;
+  }
+  auto abstract_base = input_node->abstract();
+  if (abstract_base == nullptr) {
+    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
+    return RET_ERROR;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
+    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
+    return RET_ERROR;
+  }
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
+  if (abstract_tensor->element() == nullptr) {
+    MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
+    return RET_ERROR;
+  }
+  auto tensor_type = abstract_tensor->element()->GetTypeTrack();
+  MS_ASSERT(tensor_type != nullptr);
+  auto tensor_type_id = tensor_type->type_id();
+  if (tensor_type_id != kNumberTypeInt8) {
+    return RET_CONTINUE;
+  }
+  auto param_node = input_node->cast<ParameterPtr>();
+  if (param_node == nullptr) {
+    MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
+    return RET_ERROR;
+  }
+  if (!param_node->has_default()) {
+    MS_LOG(WARNING) << "param_node don't have default: " << input_node->fullname_with_scope();
+    return RET_CONTINUE;
+  }
+  *param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
+  return RET_OK;
+}
+
+STATUS HuffmanEncode::DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num) {
   auto cnodes = func_graph->GetOrderedCnodes();
-  STATUS status;
   for (auto &cnode : cnodes) {
     auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
     if (primitive_c == nullptr) {
@@ -41,45 +74,33 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
     }
     for (size_t i = 1; i < cnode->inputs().size(); i++) {
       auto input_node = cnode->input(i);
-      if (!input_node->isa<Parameter>()) {
-        continue;
-      }
-      auto abstract_base = input_node->abstract();
-      if (abstract_base == nullptr) {
-        MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
-        return RET_ERROR;
-      }
-      if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
-        MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
-        return RET_ERROR;
-      }
-      auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
-      if (abstract_tensor->element() == nullptr) {
-        MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
-        return RET_ERROR;
-      }
-      auto tensor_type = abstract_tensor->element()->GetTypeTrack();
-      MS_ASSERT(tensor_type != nullptr);
-      auto tensor_type_id = tensor_type->type_id();
-      if (tensor_type_id != kNumberTypeInt8) {
+      ParamValueLitePtr param_value;
+      auto status = GetParamValueLitePtr(input_node, &param_value);
+      if (status == RET_CONTINUE) {
         continue;
-      }
-      auto param_node = input_node->cast<ParameterPtr>();
-      if (param_node == nullptr) {
-        MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
+      } else if (status == RET_ERROR) {
+        MS_LOG(ERROR) << "Get param value lite ptr failed. " << cnode->fullname_with_scope();
         return RET_ERROR;
       }
-      if (!param_node->has_default()) {
-        MS_LOG(WARNING) << "param_node don't have default: " << cnode->fullname_with_scope();
-        continue;
-      }
-      ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
       size_t elem_count = param_value->tensor_shape_size();
+      size_t packed_size = param_value->tensor_size();
       auto *raw_datas = static_cast<int8_t *>(param_value->tensor_addr());
       if (raw_datas == nullptr) {
         MS_LOG(ERROR) << "rawDatas is nullptr";
         return RET_ERROR;
       }
+      if (bit_num < 8 && bit_num > 0) {
+        auto dst_data = new (std::nothrow) int8_t[elem_count];
+        if (dst_data == nullptr) {
+          MS_LOG(ERROR) << "new int8_t[] failed";
+          return RET_ERROR;
+        }
+        DequantUtil::UnpackUtil<int8_t, uint8_t>(raw_datas, packed_size, bit_num, dst_data);
+        if (memcpy_s(raw_datas, elem_count, dst_data, elem_count) != EOK) {
+          MS_LOG(ERROR) << "memcpy_s failed.";
+          return RET_MEMORY_FAILED;
+        }
+      }
       HuffmanPriorityQueue pq;
       status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq);
       if (status != RET_OK) {
@@ -97,12 +118,14 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
       return status;
     }
     size_t ch_size = huffman_encoded_str_.length();
-    if (ch_size < elem_count) {
+    if (ch_size < packed_size) {
       auto encode_data = new (std::nothrow) char[ch_size];
       if (encode_data == nullptr) {
         MS_LOG(ERROR) << "new char[] failed.";
+        delete[] raw_datas;
         return RET_MEMORY_FAILED;
       }
+      delete[] raw_datas;
       if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) {
         MS_LOG(ERROR) << "memcpy_s failed.";
         delete[] encode_data;
@@ -118,7 +141,7 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
   return RET_SUCCESS;
 }

-STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
+STATUS HuffmanEncode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
   MS_ASSERT(data != nullptr);
   std::map<int8_t, size_t> freq_map;
@@ -166,7 +189,7 @@ STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t
   return RET_OK;
 }
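`GetHuffmanPriorityQueue` seeds the tree-building step from a symbol frequency map. A hedged standalone sketch of that idea (count int8 symbol frequencies, then fill a min-heap ordered by frequency); the value-type comparator here is simplified relative to the library's pointer-based `cmp`:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <queue>
#include <vector>

struct FreqNode {
  int8_t symbol;
  size_t freq;
};

struct FreqCmp {
  // Greater-than yields a min-heap: the rarest symbol surfaces first.
  bool operator()(const FreqNode &a, const FreqNode &b) const { return a.freq > b.freq; }
};

int main() {
  std::vector<int8_t> data = {1, 1, 2, 3, 3, 3};
  std::map<int8_t, size_t> freq_map;
  for (int8_t d : data) freq_map[d]++;
  std::priority_queue<FreqNode, std::vector<FreqNode>, FreqCmp> pq;
  for (const auto &kv : freq_map) pq.push({kv.first, kv.second});
  while (!pq.empty()) {  // pops in ascending frequency: 2, 1, 3
    std::cout << static_cast<int>(pq.top().symbol) << ":" << pq.top().freq << "\n";
    pq.pop();
  }
  return 0;
}
```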
-void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
+void HuffmanEncode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
   if (is_left_node) {
     node->code = node->parent->code + "0";
   } else {
@@ -185,7 +208,7 @@ void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_lef
   }
 }

-STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
+STATUS HuffmanEncode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
   HuffmanNodePtr root = nullptr;

   while (!pq->empty()) {
@@ -228,7 +251,7 @@ STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
   return RET_OK;
 }

-STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
+STATUS HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
   unsigned char out_c;
   string code_str;
   std::map<int, string>::iterator iter;
@@ -270,7 +293,7 @@ STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t
   return RET_OK;
 }

-huffman_encode::~huffman_encode() {
+HuffmanEncode::~HuffmanEncode() {
   for (auto &node : this->huffman_nodes_) {
     delete node;
   }
@@ -23,9 +23,12 @@
 #include <vector>
 #include <queue>
 #include <map>
+#include <memory>
 #include <fstream>
 #include "src/common/log_adapter.h"
 #include "src/ops/primitive_c.h"
+#include "securec/include/securec.h"
+#include "src/param_value_lite.h"
 #include "ir/func_graph.h"

 namespace mindspore {
@@ -49,13 +52,15 @@ struct cmp {
 };
 using HuffmanPriorityQueue = std::priority_queue<HuffmanNodePtr, std::vector<HuffmanNodePtr>, cmp>;

-class huffman_encode {
+class HuffmanEncode {
  public:
-  huffman_encode() = default;
+  HuffmanEncode() = default;

-  ~huffman_encode();
+  ~HuffmanEncode();

-  STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph);
+  STATUS GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value);
+  STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num);

  private:
   std::map<int, std::string> huffman_table_;
@@ -25,52 +25,16 @@ using std::string;
 using std::vector;

 namespace mindspore::lite::quant {

-bool WeightQuantizer::IsPosNum(const std::string &str) {
-  for (size_t i = 0; i < str.size(); i++) {
-    if (str.at(i) < '0' || str.at(i) > '9') {
-      return false;
-    }
-    if (str.at(i) == '0' && i == 0 && str.size() != 1) {
-      return false;
-    }
-  }
-  return true;
-}
-
-STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
-  MS_ASSERT(config != nullptr);
-  if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) {
-    MS_LOG(ERROR) << "quantWeightChannel must be valid pos num.";
-    return RET_ERROR;
-  }
-  if (!WeightQuantizer::IsPosNum(config->quantWeightSize)) {
-    MS_LOG(ERROR) << "quantWeightSize must be valid pos num.";
-    return RET_ERROR;
-  }
-  if (!WeightQuantizer::IsPosNum(config->bitNum)) {
-    MS_LOG(ERROR) << "bitNum must be valid pos num.";
-    return RET_ERROR;
-  }
-  int bitNum = std::stoi(config->bitNum);
-  if (bitNum <= 0 || bitNum > 16) {
-    MS_LOG(ERROR) << "bitNum should be more than 0 and less than 16 currently.";
-    return RET_ERROR;
-  }
-  return RET_OK;
-}
-
 WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) {
   quant_strategy_ = std::make_unique<QuantStrategy>(0, 0);
   config_param_ = config;
 }

-WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize,
-                                 const std::string &convWeightChannelThreshold, const std::string &bitNum)
-    : Quantizer(graph) {
-  this->config_file_ = config_file;
-  auto quantSize = static_cast<size_t>(std::stoull(weightSize));
-  this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
-  auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
+WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(graph) {
+  this->config_file_ = config.configFile;
+  auto quantSize = config.quantWeightSize;
+  this->bit_num_ = config.bitNum;
+  auto convQuantWeightChannelThreshold = config.quantWeightChannel;
   quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
   quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
   quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
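The `quant_max_`/`quant_min_` expressions in the constructor derive the symmetric signed range for a given bit width. A quick hedged check of the arithmetic (plain C++, no MindSpore types):

```cpp
#include <iostream>

int main() {
  // quant_max = (1 << (bits - 1)) - 1, quant_min = -(1 << (bits - 1)):
  // 6-bit -> [-32, 31], 8-bit -> [-128, 127], 16-bit -> [-32768, 32767].
  for (int bits : {6, 8, 16}) {
    std::cout << bits << "-bit: [" << -(1 << (bits - 1)) << ", "
              << ((1 << (bits - 1)) - 1) << "]\n";
  }
  return 0;
}
```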
@@ -222,7 +186,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
   return RET_OK;
 }

-STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
+STATUS WeightQuantizer::DoLstmQuantize(CNodePtr cnode) {
   MS_ASSERT(cnode != nullptr);
   auto op_name = cnode->fullname_with_scope();
@@ -233,110 +197,29 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
     MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size();
     return RET_ERROR;
   }
-  {
-    auto weight_i = cnode->input(2);
-    ParameterPtr param_node;
-    ParamValueLitePtr param_value;
-    GetLiteParameter(weight_i, &param_node, &param_value);
-    if (param_node == nullptr || param_value == nullptr) {
-      MS_LOG(ERROR) << "GetLiteParameter error";
-      return RET_ERROR;
-    }
-    if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
-      MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
-      return RET_OK;
-    }
-    if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
-      MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
-                   << quant_strategy_->mWeightSize;
-      return RET_OK;
-    }
-    auto status = RET_ERROR;
-    if (type_id_ == kNumberTypeInt8) {
-      status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
-                                   false, 1);
-    } else if (type_id_ == kNumberTypeInt16) {
-      status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
-                                    false, 1);
-    }
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "QuantFilter failed : " << status;
-      return status;
-    }
-    status = SetAbstract(param_value, param_node, primitive_c);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "SetAbstract failed : " << status;
-      return RET_ERROR;
-    }
+  auto status = ProcessLstmWeightByIndex(cnode, primitive_c, 2);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Process lstm weight i failed.";
+    return RET_ERROR;
   }
-  {
-    auto weight_h = cnode->input(3);
-    ParameterPtr param_node;
-    ParamValueLitePtr param_value;
-    GetLiteParameter(weight_h, &param_node, &param_value);
-    if (param_node == nullptr || param_value == nullptr) {
-      MS_LOG(ERROR) << "GetLiteParameter error";
-      return RET_ERROR;
-    }
-    if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
-      MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
-      return RET_ERROR;
-    }
-    auto status = RET_ERROR;
-    if (type_id_ == kNumberTypeInt8) {
-      status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
-                                   false, 2);
-    } else if (type_id_ == kNumberTypeInt16) {
-      status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
-                                    false, 2);
-    }
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "QuantFilter failed : " << status;
-      return status;
-    }
-    status = SetAbstract(param_value, param_node, primitive_c);
+  status = ProcessLstmWeightByIndex(cnode, primitive_c, 3);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Process lstm weight h failed.";
+    return RET_ERROR;
+  }
+  if (cnode->inputs().size() > 4) {
+    status = ProcessLstmWeightByIndex(cnode, primitive_c, 4);
     if (status != RET_OK) {
-      MS_LOG(ERROR) << "SetAbstract failed : " << status;
+      MS_LOG(ERROR) << "Process lstm bias failed.";
       return RET_ERROR;
     }
   }
-  {
-    if (cnode->inputs().size() > 4) {
-      auto bias = cnode->input(4);
-      ParameterPtr param_node;
-      ParamValueLitePtr param_value;
-      GetLiteParameter(bias, &param_node, &param_value);
-      if (param_node == nullptr || param_value == nullptr) {
-        MS_LOG(ERROR) << "GetLiteParameter error";
-        return RET_ERROR;
-      }
-      if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
-        MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
-        return RET_ERROR;
-      }
-      auto status = RET_ERROR;
-      if (type_id_ == kNumberTypeInt8) {
-        status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
-                                     false, 3);
-      } else if (type_id_ == kNumberTypeInt16) {
-        status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_,
-                                      bit_num_, false, 3);
-      }
-      if (status != RET_OK) {
-        MS_LOG(ERROR) << "QuantFilter failed : " << status;
-        return status;
-      }
-      status = SetAbstract(param_value, param_node, primitive_c);
-      if (status != RET_OK) {
-        MS_LOG(ERROR) << "SetAbstract failed : " << status;
-        return RET_ERROR;
-      }
-    }
-  }
-  return RET_OK;
+  return status;
 }

-STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
+STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) {
   auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
   MS_ASSERT(primitive_c != nullptr);
@@ -375,6 +258,46 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
   return RET_OK;
 }

+STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
+                                                 const int &index) {
+  auto op_name = cnode->fullname_with_scope();
+  auto weight_i = cnode->input(index);
+  ParameterPtr param_node;
+  ParamValueLitePtr param_value;
+  GetLiteParameter(weight_i, &param_node, &param_value);
+  if (param_node == nullptr || param_value == nullptr) {
+    MS_LOG(ERROR) << "GetLiteParameter error";
+    return RET_ERROR;
+  }
+  if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
+    MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
+    return RET_OK;
+  }
+  if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
+    MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
+                 << quant_strategy_->mWeightSize;
+    return RET_OK;
+  }
+  auto status = RET_ERROR;
+  if (type_id_ == kNumberTypeInt8) {
+    status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                 false, index - 1);
+  } else if (type_id_ == kNumberTypeInt16) {
+    status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                  false, index - 1);
+  }
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "QuantFilter failed : " << status;
+    return status;
+  }
+  status = SetAbstract(param_value, param_node, primitive_c);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "SetAbstract failed : " << status;
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 constexpr float relative_tolerance = 1e-5;
 constexpr float abs_tolerance = 1e-4;
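`ProcessLstmWeightByIndex` folds the three near-identical blocks into one helper parameterized by input index. A hedged one-screen recap of the convention the call sites rely on (deduced from this diff, not stated elsewhere):

```cpp
#include <iostream>

int main() {
  // LSTM cnode inputs 2/3/4 carry weight_i/weight_h/bias, and QuantFilter's
  // trailing index argument receives index - 1.
  const char *roles[] = {"weight_i", "weight_h", "bias"};
  for (int index = 2; index <= 4; ++index) {
    std::cout << "cnode input " << index << " = " << roles[index - 2]
              << ", QuantFilter arg = " << index - 1 << "\n";
  }
  return 0;
}
```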
@@ -510,37 +433,28 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) {
   return RET_OK;
 }

-STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
-  // 0.2 Parse input calib files
-  auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "CollectCalibInputs fail";
-    return RET_ERROR;
-  }
-  MS_LOG(DEBUG) << "run fp32 model";
-  status = RunFp32Graph(func_graph);
-  if (status != RET_OK) {
-    return RET_ERROR;
-  }
+STATUS WeightQuantizer::DoMixedQuantize(const FuncGraphPtr &func_graph) {
   auto cnodes = func_graph->GetOrderedCnodes();
+  int status = RET_OK;
   for (auto &cnode : cnodes) {
     auto op_type = NodePrimitiveType(cnode);
     if (op_type == schema::PrimitiveType_Lstm) {
-      status = DoLstmQuntize(cnode);
+      status = DoLstmQuantize(cnode);
       if (status != RET_OK) {
-        MS_LOG(ERROR) << "DoLstmQuntize error";
+        MS_LOG(ERROR) << "DoLstmQuantize error";
         return RET_ERROR;
       }
     } else if (op_type == schema::PrimitiveType_Gather) {
-      status = DoGatherQuntize(cnode);
+      status = DoGatherQuantize(cnode);
       if (status != RET_OK) {
-        MS_LOG(ERROR) << "DoGatherQuntize error";
+        MS_LOG(ERROR) << "DoGatherQuantize error";
         return RET_ERROR;
       }
     }
   }
+  return status;
+}
+
+STATUS WeightQuantizer::CheckImageCnt() {
   auto image_cnt = images_.at(0).size();
   if (!config_param_.input_shapes.empty()) {
     if (config_param_.input_shapes.size() != image_cnt) {
@@ -548,7 +462,62 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
       return RET_ERROR;
     }
   }
+  return RET_OK;
+}
+
+STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
+                                             ParameterPtr *param_node, ParamValueLitePtr *param_value) {
+  if (!input_node->isa<Parameter>()) {
+    MS_LOG(WARNING) << op_name << " the second input is not parameter";
+    return RET_CONTINUE;
+  }
+  *param_node = input_node->cast<ParameterPtr>();
+  if (!(*param_node)->has_default()) {
+    MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
+    return RET_CONTINUE;
+  }
+  *param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param());
+  if (*param_value == nullptr) {
+    MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
+    return RET_CONTINUE;
+  }
+  if ((*param_value)->tensor_type() != TypeId::kNumberTypeFloat32) {
+    MS_LOG(WARNING) << op_name << " the second input type is not float";
+    return RET_CONTINUE;
+  }
+  return RET_OK;
+}
+
+STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr &param_node,
+                                 const ParamValueLitePtr &param_value, const std::shared_ptr<PrimitiveC> &primitive_c) {
+  int status;
+  type_id_ = TypeId::kNumberTypeInt8;
+  int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
+  int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
+
+  if (type_id_ == TypeId::kNumberTypeInt8) {
+    status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
+                                 bit_num_t, true);
+  } else if (type_id_ == TypeId::kNumberTypeInt16) {
+    status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
+                                  bit_num_t, true);
+  } else {
+    MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
+    return RET_ERROR;
+  }
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "quant filter failed.";
+    return RET_ERROR;
+  }
+  status = SetAbstract(param_value, param_node, primitive_c);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "SetAbstract failed : " << status;
+    return RET_ERROR;
+  }
+  return status;
+}
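`TryQuant` is the per-bit-width attempt that the search loop below drives. A hedged standalone sketch of that search pattern (walk bit widths from 2 to 8 and keep the first one whose quantization error passes a tolerance check); `quant_error` stands in for the real quantize-and-evaluate step:

```cpp
#include <functional>
#include <iostream>

int ChooseBitWidth(const std::function<float(int)> &quant_error, float tolerance) {
  for (int bits = 2; bits <= 8; ++bits) {
    if (quant_error(bits) <= tolerance) {
      return bits;  // smallest bit width that is accurate enough
    }
  }
  return 8;  // fall back to plain 8-bit weight quantization
}

int main() {
  // Toy error model: error halves with every extra bit.
  auto error = [](int bits) { return 1.0f / (1 << bits); };
  std::cout << ChooseBitWidth(error, 0.05f) << "\n";  // prints 5 (1/32 <= 0.05)
  return 0;
}
```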
+STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) {
+  auto cnodes = func_graph->GetOrderedCnodes();
+  auto image_cnt = images_.at(0).size();
+  int status = RET_OK;
   for (auto iter = cnodes.end(); iter != cnodes.begin();) {
     auto cnode = *(--iter);
     auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
@@ -561,22 +530,10 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
                   << " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type());
     if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) {
       auto input_node = cnode->input(2);
-      if (!input_node->isa<Parameter>()) {
-        MS_LOG(WARNING) << op_name << " the second input is not parameter";
-        continue;
-      }
-      auto param_node = input_node->cast<ParameterPtr>();
-      if (!param_node->has_default()) {
-        MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
-        continue;
-      }
-      auto param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
-      if (param_value == nullptr) {
-        MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
-        continue;
-      }
-      if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
-        MS_LOG(WARNING) << op_name << " the second input type is not float";
+      ParameterPtr param_node;
+      ParamValueLitePtr param_value;
+      status = GetParamNodeAndValue(input_node, op_name, &param_node, &param_value);
+      if (status == RET_CONTINUE) {
         continue;
       }
       // copy origin data in case to recover
@@ -591,27 +548,9 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
       }
       // 1. try quant
       for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
-        type_id_ = TypeId::kNumberTypeInt8;
-        int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
-        int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
-        if (type_id_ == TypeId::kNumberTypeInt8) {
-          status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
-                                       quant_min_t, bit_num_t, true);
-        } else if (type_id_ == TypeId::kNumberTypeInt16) {
-          status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
-                                        quant_min_t, bit_num_t, true);
-        } else {
-          MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
-          return RET_ERROR;
-        }
+        status = TryQuant(bit_num_t, param_node, param_value, primitive_c);
         if (status != RET_OK) {
-          MS_LOG(ERROR) << "quant filter fail.";
-          return RET_ERROR;
-        }
-        status = SetAbstract(param_value, param_node, primitive_c);
-        if (status != RET_OK) {
-          MS_LOG(ERROR) << "SetAbstract failed : " << status;
+          MS_LOG(ERROR) << "TryQuant failed.";
           return RET_ERROR;
         }
         // 2. evaluate the quant
@@ -679,6 +618,41 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
       free(origin_data);
     }  // if: conv and matmul
   }    // end loop: all cnode
+  return status;
+}
+
+STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
+  // 0.2 Parse input calib files
+  auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "CollectCalibInputs failed.";
+    return RET_ERROR;
+  }
+  status = RunFp32Graph(func_graph);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "RunFp32Graph failed.";
+    return RET_ERROR;
+  }
+
+  status = DoMixedQuantize(func_graph);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "DoMixedQuantize failed.";
+    return RET_ERROR;
+  }
+
+  status = CheckImageCnt();
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "CheckImageCnt failed.";
+    return RET_ERROR;
+  }
+
+  status = DoQuantSearch(func_graph);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "DoQuantSearch failed.";
+    return RET_ERROR;
+  }
+
   for (const auto &kv : opname_bit_) {
     MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second;
   }
@@ -709,15 +683,15 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
       return RET_ERROR;
     }
   } else if (op_type == schema::PrimitiveType_Lstm) {
-    auto status = DoLstmQuntize(cnode);
+    auto status = DoLstmQuantize(cnode);
     if (status != RET_OK) {
-      MS_LOG(ERROR) << "DoLstmQuntize error";
+      MS_LOG(ERROR) << "DoLstmQuantize error";
       return RET_ERROR;
     }
   } else if (op_type == schema::PrimitiveType_Gather) {
-    auto status = DoGatherQuntize(cnode);
+    auto status = DoGatherQuantize(cnode);
     if (status != RET_OK) {
-      MS_LOG(ERROR) << "DoGatherQuntize error";
+      MS_LOG(ERROR) << "DoGatherQuantize error";
       return RET_ERROR;
     }
   } else {
@@ -36,18 +36,18 @@
 namespace mindspore::lite::quant {
 class WeightQuantizer : public Quantizer {
  public:
-  WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize,
-                  const std::string &covWeightChannelThreshold, const std::string &bitNum);
+  WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config);
   WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config);
   ~WeightQuantizer();

   STATUS DoQuantize(FuncGraphPtr func_graph) override;
   STATUS DoConvQuantize(CNodePtr);
   STATUS DoMulQuantize(CNodePtr);
-  STATUS DoLstmQuntize(CNodePtr cnode);
-  STATUS DoGatherQuntize(CNodePtr cnode);
-  static STATUS WeightQuantInputCheck(const converter::Flags *config);
-  static bool IsPosNum(const std::string &str);
+  STATUS DoLstmQuantize(CNodePtr cnode);
+  STATUS DoGatherQuantize(CNodePtr cnode);
+  STATUS ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
+                                  const int &index);

   int quant_max_{127};
   int quant_min_{-128};
@@ -66,6 +66,14 @@ class WeightQuantizer : public Quantizer {
   STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
   STATUS DoFixedQuant(FuncGraphPtr);
   STATUS RunFp32Graph(FuncGraphPtr);
+
+  STATUS DoMixedQuantize(const FuncGraphPtr &func_graph);
+  STATUS CheckImageCnt();
+  STATUS GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
+                              ParameterPtr *param_node, ParamValueLitePtr *param_value);
+  STATUS TryQuant(const int &bit_num_t, const ParameterPtr &param_node, const ParamValueLitePtr &param_value,
+                  const std::shared_ptr<PrimitiveC> &primitive_c);
+  STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
 };
 }  // namespace mindspore::lite::quant
 #endif