From: @jianghui58 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -14,7 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <cmath> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "src/dequant.h" | |||
| #include "src/huffman_decode.h" | |||
| namespace mindspore::lite { | |||
| float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | |||
| @@ -34,13 +37,24 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { | |||
| } | |||
| } | |||
| void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { | |||
| int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { | |||
| MS_ASSERT(input_tensor != nullptr); | |||
| MS_ASSERT(unpack_int_data != nullptr); | |||
| auto quant_params = input_tensor->quantParams(); | |||
| if (quant_params == nullptr) { | |||
| MS_LOG(ERROR) << "low bits quantparams is empty."; | |||
| return; | |||
| return RET_ERROR; | |||
| } | |||
| auto enable_huffman_code = input_tensor->enableHuffmanCode(); | |||
| if (enable_huffman_code) { | |||
| std::string encode_str(input_tensor->data()->begin(), input_tensor->data()->end()); | |||
| auto huffman_decode = std::make_unique<lite::HuffmanDecode>(); | |||
| auto ret = huffman_decode->DoHuffmanDecode(encode_str, unpack_int_data); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "DoHuffmanDecode failed."; | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int origin_bit = quant_params->Get(0)->numBits(); | |||
| if (origin_bit < 8 && origin_bit > 0) { | |||
| @@ -48,6 +62,7 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i | |||
| } else if (origin_bit < 16 && origin_bit > 8) { | |||
| UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors, | |||
| @@ -31,7 +31,7 @@ class DequantUtil { | |||
| public: | |||
| static float *DequantWeight(lite::Tensor *input_tensor); | |||
| static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); | |||
| static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); | |||
| static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors, | |||
| TypeId data_type, bool need_restore = true); | |||
| @@ -110,6 +110,21 @@ class DequantUtil { | |||
| return dequant_datas; | |||
| } | |||
| template <typename T1, typename T2> | |||
| static void UnpackUtil(const T1 *weight_data, int pack_size, int origin_bit, void *unpack_int_data) { | |||
| if (weight_data == nullptr || unpack_int_data == nullptr) { | |||
| MS_LOG(ERROR) << "data is nullptr"; | |||
| return; | |||
| } | |||
| std::queue<bool> unpack_bit_data; | |||
| size_t count = 0; | |||
| for (int i = 0; i < pack_size; ++i) { | |||
| T2 pack_data = (static_cast<const T2 *>(static_cast<const void *>(weight_data)))[i]; | |||
| bool is_last = i == pack_size - 1; | |||
| UnPackData<T1, T2>(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last); | |||
| } | |||
| } | |||
| private: | |||
| template <typename T1, typename T2> | |||
| static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int, | |||
| @@ -19,7 +19,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) { | |||
| STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) { | |||
| if (decoded_data == nullptr) { | |||
| MS_LOG(ERROR) << "decoded_data is nullptr."; | |||
| return RET_ERROR; | |||
| @@ -64,7 +64,7 @@ STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decod | |||
| return RET_OK; | |||
| } | |||
| STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) { | |||
| STATUS HuffmanDecode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) { | |||
| HuffmanNodePtr cur_node, tmp_node, new_node; | |||
| auto huffman_keys = Str2Vec(std::move(keys)); | |||
| @@ -121,7 +121,7 @@ STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, c | |||
| return RET_OK; | |||
| } | |||
| STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) { | |||
| STATUS HuffmanDecode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) { | |||
| HuffmanNodePtr cur_node = root; | |||
| bool pseudo_eof = false; | |||
| size_t pos = 0; | |||
| @@ -157,7 +157,7 @@ STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string enco | |||
| return RET_OK; | |||
| } | |||
| huffman_decode::~huffman_decode() { | |||
| HuffmanDecode::~HuffmanDecode() { | |||
| for (auto &node : this->huffman_nodes_) { | |||
| delete node; | |||
| } | |||
| @@ -38,11 +38,11 @@ struct HuffmanNode { | |||
| }; | |||
| using HuffmanNodePtr = HuffmanNode *; | |||
| class huffman_decode { | |||
| class HuffmanDecode { | |||
| public: | |||
| huffman_decode() = default; | |||
| HuffmanDecode() = default; | |||
| ~huffman_decode(); | |||
| ~HuffmanDecode(); | |||
| STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data); | |||
| @@ -28,7 +28,6 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "src/lite_model.h" | |||
| #include "src/dequant.h" | |||
| #include "src/huffman_decode.h" | |||
| #if SUPPORT_NPU | |||
| #include "src/runtime/agent/npu/npu_manager.h" | |||
| #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" | |||
| @@ -96,13 +95,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde | |||
| int org_size = dst_tensor->Size(); | |||
| return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16); | |||
| }; | |||
| auto NeedHuffmanDecode = [&src_tensor, &dst_tensor]() -> bool { | |||
| auto data_type = src_tensor->dataType(); | |||
| auto enable_huffman_code = src_tensor->enableHuffmanCode(); | |||
| int pack_size = src_tensor->data()->size(); | |||
| int org_size = dst_tensor->Size(); | |||
| return (pack_size != org_size) && (data_type == kNumberTypeInt8) && enable_huffman_code; | |||
| }; | |||
| auto src_category = TensorCategory(src_tensor); | |||
| if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) && | |||
| src_tensor->data() != nullptr && src_tensor->data()->size() > 0) { | |||
| @@ -116,21 +108,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| if (NeedHuffmanDecode()) { | |||
| auto dst_data = dst_tensor->MutableData(); | |||
| if (dst_data == nullptr) { | |||
| MS_LOG(ERROR) << "Data from tensor is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::string encode_str(src_tensor->data()->begin(), src_tensor->data()->end()); | |||
| auto huffman_decode = std::make_unique<lite::huffman_decode>(); | |||
| auto ret = huffman_decode->DoHuffmanDecode(encode_str, dst_data); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "DoHuffmanDecode failed."; | |||
| return ret; | |||
| } | |||
| copyed_tensor_idxes_.emplace_back(tensor_index); | |||
| } | |||
| if (WeightTensorNeedCopy(model, tensor_index)) { | |||
| auto dst_data = dst_tensor->MutableData(); | |||
| if (dst_data == nullptr) { | |||
| @@ -138,7 +115,11 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (NeedUnPack()) { | |||
| DequantUtil::UnPackToInt(src_tensor, dst_data); | |||
| auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "unpack to int failed."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| } else { | |||
| memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size()); | |||
| } | |||
| @@ -148,9 +129,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde | |||
| auto dst_data = dst_tensor->MutableData(); | |||
| if (dst_data == nullptr) { | |||
| MS_LOG(ERROR) << "Data from tensor is nullptr"; | |||
| return RET_NULL_PTR; | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "unpack to int failed."; | |||
| return RET_ERROR; | |||
| } | |||
| DequantUtil::UnPackToInt(src_tensor, dst_data); | |||
| copyed_tensor_idxes_.emplace_back(tensor_index); | |||
| } else { | |||
| dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data())); | |||
| @@ -227,8 +227,8 @@ function Run_Converter() { | |||
| fi | |||
| model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` | |||
| echo ${model_name} >> "${run_converter_log_file}" | |||
| echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true' >> "${run_converter_log_file}" | |||
| ./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true | |||
| echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0' >> "${run_converter_log_file}" | |||
| ./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 | |||
| if [ $? = 0 ]; then | |||
| converter_result='converter weight_quant '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} | |||
| else | |||
| @@ -217,26 +217,14 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla | |||
| const FuncGraphPtr &new_graph) { | |||
| // quant | |||
| if (config->quantType == schema::QuantType_PostTraining) { | |||
| if (!quant::WeightQuantizer::IsPosNum(config->bitNum)) { | |||
| MS_LOG(ERROR) << "bitNum must be valid pos num."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return RET_ERROR; | |||
| } | |||
| this->mQuantizer = | |||
| std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, std::stoi(config->bitNum)); | |||
| this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, config->bitNum); | |||
| if (mQuantizer == nullptr) { | |||
| MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| return RET_ERROR; | |||
| } | |||
| } else if (config->quantType == schema::QuantType_WeightQuant) { | |||
| if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { | |||
| MS_LOG(ERROR) << "weight quant input param error"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return RET_ERROR; | |||
| } | |||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize, | |||
| config->quantWeightChannel, config->bitNum); | |||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, *config); | |||
| if (mQuantizer == nullptr) { | |||
| MS_LOG(ERROR) << "New WeightQuantizer failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| @@ -255,10 +243,15 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla | |||
| return RET_OK; | |||
| } | |||
| int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph) { | |||
| if (config->quantType == schema::QuantType_WeightQuant && config->bitNum == "8" && config->enableHuffmanCode) { | |||
| auto huffman_encode = std::make_unique<lite::huffman_encode>(); | |||
| auto status = huffman_encode->DoHuffmanEncode(new_graph); | |||
| int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, | |||
| bool enableHuffmanCode) { | |||
| if (config->quantType == schema::QuantType_WeightQuant && enableHuffmanCode) { | |||
| if (config->bitNum < 16 && config->bitNum > 8) { | |||
| MS_LOG(WARNING) << "don't support huffman encode when 8 < bitNum < 16 currently."; | |||
| return RET_OK; | |||
| } | |||
| auto huffman_encode = std::make_unique<lite::HuffmanEncode>(); | |||
| auto status = huffman_encode->DoHuffmanEncode(new_graph, config->bitNum); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Huffman encode failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| @@ -322,7 +315,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||
| return nullptr; | |||
| } | |||
| status = DoHuffmanEncode(config, new_graph); | |||
| status = DoHuffmanEncode(config, new_graph, false); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Do HuffmanCode failed."; | |||
| return nullptr; | |||
| @@ -59,7 +59,7 @@ class AnfTransform { | |||
| int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph); | |||
| int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph); | |||
| int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, bool enableHuffmanCode); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -38,16 +38,12 @@ Flags::Flags() { | |||
| "UINT8 | DEFAULT", | |||
| "DEFAULT"); | |||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", ""); | |||
| AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); | |||
| AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0"); | |||
| AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16"); | |||
| AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8"); | |||
| AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0"); | |||
| AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16"); | |||
| AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", ""); | |||
| AddFlag(&Flags::enableHuffmanCodeIn, "enableHuffmanCode", | |||
| "whether the weight quant model is going to use huffman code." | |||
| "true | false", | |||
| "false"); | |||
| AddFlag(&Flags::trainModelIn, "trainModel", | |||
| "whether the model is going to be trained on device." | |||
| "whether the model is going to be trained on device. " | |||
| "true | false", | |||
| "false"); | |||
| } | |||
| @@ -107,7 +103,41 @@ int Flags::InitFmk() { | |||
| return RET_OK; | |||
| } | |||
| int Flags::InitQuantType() { | |||
| bool Flags::IsValidNum(const std::string &str, int *num) { | |||
| char *ptr; | |||
| *num = strtol(str.c_str(), &ptr, 10); | |||
| return ptr == (str.c_str() + str.size()); | |||
| } | |||
| int Flags::QuantParamInputCheck() { | |||
| if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) { | |||
| std::cerr << "quantWeightChannel should be a valid number."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->quantWeightChannel < 0) { | |||
| std::cerr << "quantWeightChannel should be greater than or equal to zero."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) { | |||
| std::cerr << "quantWeightSize should be a valid number."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->quantWeightSize < 0) { | |||
| std::cerr << "quantWeightSize should be greater than or equal to zero."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (!Flags::IsValidNum(this->bitNumIn, &this->bitNum)) { | |||
| std::cerr << "bitNum should be a valid number."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->bitNum <= 0 || this->bitNum > 16) { | |||
| std::cerr << "bitNum should be greater than zero and lesser than 16 currently."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int Flags::InitQuantParam() { | |||
| if (this->quantTypeIn == "WeightQuant") { | |||
| this->quantType = QuantType_WeightQuant; | |||
| } else if (this->quantTypeIn == "PostTraining") { | |||
| @@ -118,19 +148,9 @@ int Flags::InitQuantType() { | |||
| std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining"; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int Flags::InitHuffmanCode() { | |||
| if (this->enableHuffmanCodeIn == "true") { | |||
| this->enableHuffmanCode = true; | |||
| } else if (this->enableHuffmanCodeIn == "false") { | |||
| this->enableHuffmanCode = false; | |||
| } else { | |||
| std::cerr << "INPUT ILLEGAL: trainModel must be true|false "; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| return RET_OK; | |||
| auto ret = QuantParamInputCheck(); | |||
| return ret; | |||
| } | |||
| int Flags::InitTrainModel() { | |||
| @@ -218,15 +238,9 @@ int Flags::Init(int argc, const char **argv) { | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| ret = InitQuantType(); | |||
| if (ret != RET_OK) { | |||
| std::cerr << "Init quant type failed."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| ret = InitHuffmanCode(); | |||
| ret = InitQuantParam(); | |||
| if (ret != RET_OK) { | |||
| std::cerr << "Init huffman code failed."; | |||
| std::cerr << "Init quant param failed."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| @@ -49,9 +49,11 @@ class Flags : public virtual mindspore::lite::FlagParser { | |||
| int InitFmk(); | |||
| int InitQuantType(); | |||
| bool IsValidNum(const std::string &str, int *num); | |||
| int InitHuffmanCode(); | |||
| int QuantParamInputCheck(); | |||
| int InitQuantParam(); | |||
| int InitTrainModel(); | |||
| @@ -76,12 +78,13 @@ class Flags : public virtual mindspore::lite::FlagParser { | |||
| TypeId inputDataType; | |||
| TypeId outputDataType; | |||
| // used for post-trainning-weight | |||
| std::string quantWeightSize; | |||
| std::string bitNum; | |||
| std::string quantWeightSizeIn; | |||
| int quantWeightSize; | |||
| std::string bitNumIn; | |||
| int bitNum; | |||
| std::string configFile; | |||
| std::string quantWeightChannel; | |||
| std::string enableHuffmanCodeIn; | |||
| bool enableHuffmanCode = false; | |||
| std::string quantWeightChannelIn; | |||
| int quantWeightChannel; | |||
| std::string trainModelIn; | |||
| bool trainModel = false; | |||
| }; | |||
| @@ -18,18 +18,51 @@ | |||
| #include <utility> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "securec/include/securec.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "src/dequant.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) { | |||
| STATUS HuffmanEncode::GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value) { | |||
| if (!input_node->isa<Parameter>()) { | |||
| return RET_CONTINUE; | |||
| } | |||
| auto abstract_base = input_node->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||
| MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||
| if (abstract_tensor->element() == nullptr) { | |||
| MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| auto tensor_type = abstract_tensor->element()->GetTypeTrack(); | |||
| MS_ASSERT(tensor_type != nullptr); | |||
| auto tensor_type_id = tensor_type->type_id(); | |||
| if (tensor_type_id != kNumberTypeInt8) { | |||
| return RET_CONTINUE; | |||
| } | |||
| auto param_node = input_node->cast<ParameterPtr>(); | |||
| if (param_node == nullptr) { | |||
| MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| if (!param_node->has_default()) { | |||
| MS_LOG(WARNING) << "param_node don't have default: " << input_node->fullname_with_scope(); | |||
| return RET_CONTINUE; | |||
| } | |||
| *param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||
| return RET_OK; | |||
| } | |||
| STATUS HuffmanEncode::DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num) { | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| STATUS status; | |||
| for (auto &cnode : cnodes) { | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| @@ -41,45 +74,33 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) { | |||
| } | |||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | |||
| auto input_node = cnode->input(i); | |||
| if (!input_node->isa<Parameter>()) { | |||
| continue; | |||
| } | |||
| auto abstract_base = input_node->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||
| MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||
| if (abstract_tensor->element() == nullptr) { | |||
| MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| auto tensor_type = abstract_tensor->element()->GetTypeTrack(); | |||
| MS_ASSERT(tensor_type != nullptr); | |||
| auto tensor_type_id = tensor_type->type_id(); | |||
| if (tensor_type_id != kNumberTypeInt8) { | |||
| ParamValueLitePtr param_value; | |||
| auto status = GetParamValueLitePtr(input_node, ¶m_value); | |||
| if (status == RET_CONTINUE) { | |||
| continue; | |||
| } | |||
| auto param_node = input_node->cast<ParameterPtr>(); | |||
| if (param_node == nullptr) { | |||
| MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope(); | |||
| } else if (status == RET_ERROR) { | |||
| MS_LOG(ERROR) << "Get param value lite ptr failed. " << cnode->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| if (!param_node->has_default()) { | |||
| MS_LOG(WARNING) << "param_node don't have default: " << cnode->fullname_with_scope(); | |||
| continue; | |||
| } | |||
| ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||
| size_t elem_count = param_value->tensor_shape_size(); | |||
| size_t packed_size = param_value->tensor_size(); | |||
| auto *raw_datas = static_cast<int8_t *>(param_value->tensor_addr()); | |||
| if (raw_datas == nullptr) { | |||
| MS_LOG(ERROR) << "rawDatas is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (bit_num < 8 && bit_num > 0) { | |||
| auto dst_data = new (std::nothrow) int8_t[elem_count]; | |||
| if (dst_data == nullptr) { | |||
| MS_LOG(ERROR) << "new int8_t[] failed"; | |||
| return RET_ERROR; | |||
| } | |||
| DequantUtil::UnpackUtil<int8_t, uint8_t>(raw_datas, packed_size, bit_num, dst_data); | |||
| if (memcpy_s(raw_datas, elem_count, dst_data, elem_count) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed."; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| } | |||
| HuffmanPriorityQueue pq; | |||
| status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq); | |||
| if (status != RET_OK) { | |||
| @@ -97,12 +118,14 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) { | |||
| return status; | |||
| } | |||
| size_t ch_size = huffman_encoded_str_.length(); | |||
| if (ch_size < elem_count) { | |||
| if (ch_size < packed_size) { | |||
| auto encode_data = new (std::nothrow) char[ch_size]; | |||
| if (encode_data == nullptr) { | |||
| MS_LOG(ERROR) << "new char[] failed."; | |||
| delete[] raw_datas; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| delete[] raw_datas; | |||
| if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed."; | |||
| delete[] encode_data; | |||
| @@ -118,7 +141,7 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) { | |||
| return RET_SUCCESS; | |||
| } | |||
| STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) { | |||
| STATUS HuffmanEncode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) { | |||
| MS_ASSERT(data != nullptr); | |||
| std::map<int8_t, size_t> freq_map; | |||
| @@ -166,7 +189,7 @@ STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t | |||
| return RET_OK; | |||
| } | |||
| void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) { | |||
| void HuffmanEncode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) { | |||
| if (is_left_node) { | |||
| node->code = node->parent->code + "0"; | |||
| } else { | |||
| @@ -185,7 +208,7 @@ void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_lef | |||
| } | |||
| } | |||
| STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) { | |||
| STATUS HuffmanEncode::BuildHuffmanTree(HuffmanPriorityQueue *pq) { | |||
| HuffmanNodePtr root = nullptr; | |||
| while (!pq->empty()) { | |||
| @@ -228,7 +251,7 @@ STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) { | |||
| return RET_OK; | |||
| } | |||
| STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) { | |||
| STATUS HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) { | |||
| unsigned char out_c; | |||
| string code_str; | |||
| std::map<int, string>::iterator iter; | |||
| @@ -270,7 +293,7 @@ STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t | |||
| return RET_OK; | |||
| } | |||
| huffman_encode::~huffman_encode() { | |||
| HuffmanEncode::~HuffmanEncode() { | |||
| for (auto &node : this->huffman_nodes_) { | |||
| delete node; | |||
| } | |||
| @@ -23,9 +23,12 @@ | |||
| #include <vector> | |||
| #include <queue> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <fstream> | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "securec/include/securec.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore { | |||
| @@ -49,13 +52,15 @@ struct cmp { | |||
| }; | |||
| using HuffmanPriorityQueue = std::priority_queue<HuffmanNodePtr, std::vector<HuffmanNodePtr>, cmp>; | |||
| class huffman_encode { | |||
| class HuffmanEncode { | |||
| public: | |||
| huffman_encode() = default; | |||
| HuffmanEncode() = default; | |||
| ~huffman_encode(); | |||
| ~HuffmanEncode(); | |||
| STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph); | |||
| STATUS GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value); | |||
| STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num); | |||
| private: | |||
| std::map<int, std::string> huffman_table_; | |||
| @@ -25,52 +25,16 @@ using std::string; | |||
| using std::vector; | |||
| namespace mindspore::lite::quant { | |||
| bool WeightQuantizer::IsPosNum(const std::string &str) { | |||
| for (size_t i = 0; i < str.size(); i++) { | |||
| if (str.at(i) < '0' || str.at(i) > '9') { | |||
| return false; | |||
| } | |||
| if (str.at(i) == '0' && i == 0 && str.size() != 1) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { | |||
| MS_ASSERT(config != nullptr); | |||
| if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) { | |||
| MS_LOG(ERROR) << "quantWeightChannel must be valid pos num."; | |||
| return RET_ERROR; | |||
| } | |||
| if (!WeightQuantizer::IsPosNum(config->quantWeightSize)) { | |||
| MS_LOG(ERROR) << "quantWeightSize must be valid pos num."; | |||
| return RET_ERROR; | |||
| } | |||
| if (!WeightQuantizer::IsPosNum(config->bitNum)) { | |||
| MS_LOG(ERROR) << "bitNum must be valid pos num."; | |||
| return RET_ERROR; | |||
| } | |||
| int bitNum = std::stoi(config->bitNum); | |||
| if (bitNum <= 0 || bitNum > 16) { | |||
| MS_LOG(ERROR) << "bitNum should be more than 0 and less than 16 currently."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) { | |||
| quant_strategy_ = std::make_unique<QuantStrategy>(0, 0); | |||
| config_param_ = config; | |||
| } | |||
| WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize, | |||
| const std::string &convWeightChannelThreshold, const std::string &bitNum) | |||
| : Quantizer(graph) { | |||
| this->config_file_ = config_file; | |||
| auto quantSize = static_cast<size_t>(std::stoull(weightSize)); | |||
| this->bit_num_ = static_cast<size_t>(std::stoull(bitNum)); | |||
| auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold)); | |||
| WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(graph) { | |||
| this->config_file_ = config.configFile; | |||
| auto quantSize = config.quantWeightSize; | |||
| this->bit_num_ = config.bitNum; | |||
| auto convQuantWeightChannelThreshold = config.quantWeightChannel; | |||
| quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); | |||
| quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; | |||
| quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1)); | |||
| @@ -222,7 +186,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) { | |||
| STATUS WeightQuantizer::DoLstmQuantize(CNodePtr cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto op_name = cnode->fullname_with_scope(); | |||
| @@ -233,110 +197,29 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) { | |||
| MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size(); | |||
| return RET_ERROR; | |||
| } | |||
| { | |||
| auto weight_i = cnode->input(2); | |||
| ParameterPtr param_node; | |||
| ParamValueLitePtr param_value; | |||
| GetLiteParameter(weight_i, ¶m_node, ¶m_value); | |||
| if (param_node == nullptr || param_value == nullptr) { | |||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; | |||
| return RET_OK; | |||
| } | |||
| if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) { | |||
| MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < " | |||
| << quant_strategy_->mWeightSize; | |||
| return RET_OK; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| false, 1); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| false, 1); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| auto status = ProcessLstmWeightByIndex(cnode, primitive_c, 2); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Process lstm weight i failed."; | |||
| return RET_ERROR; | |||
| } | |||
| { | |||
| auto weight_h = cnode->input(3); | |||
| ParameterPtr param_node; | |||
| ParamValueLitePtr param_value; | |||
| GetLiteParameter(weight_h, ¶m_node, ¶m_value); | |||
| if (param_node == nullptr || param_value == nullptr) { | |||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; | |||
| return RET_ERROR; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| false, 2); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| false, 2); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| status = ProcessLstmWeightByIndex(cnode, primitive_c, 3); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Process lstm weight h failed."; | |||
| return RET_ERROR; | |||
| } | |||
| if (cnode->inputs().size() > 4) { | |||
| status = ProcessLstmWeightByIndex(cnode, primitive_c, 4); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| MS_LOG(ERROR) << "Process lstm bias failed."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| { | |||
| if (cnode->inputs().size() > 4) { | |||
| auto bias = cnode->input(4); | |||
| ParameterPtr param_node; | |||
| ParamValueLitePtr param_value; | |||
| GetLiteParameter(bias, ¶m_node, ¶m_value); | |||
| if (param_node == nullptr || param_value == nullptr) { | |||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; | |||
| return RET_ERROR; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| false, 3); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| false, 3); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| return status; | |||
| } | |||
| STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) { | |||
| STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) { | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| MS_ASSERT(primitive_c != nullptr); | |||
| @@ -375,6 +258,46 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) { | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c, | |||
| const int &index) { | |||
| auto op_name = cnode->fullname_with_scope(); | |||
| auto weight_i = cnode->input(index); | |||
| ParameterPtr param_node; | |||
| ParamValueLitePtr param_value; | |||
| GetLiteParameter(weight_i, ¶m_node, ¶m_value); | |||
| if (param_node == nullptr || param_value == nullptr) { | |||
| MS_LOG(ERROR) << "GetLiteParameter error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; | |||
| return RET_OK; | |||
| } | |||
| if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) { | |||
| MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < " | |||
| << quant_strategy_->mWeightSize; | |||
| return RET_OK; | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id_ == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| false, index - 1); | |||
| } else if (type_id_ == kNumberTypeInt16) { | |||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, | |||
| false, index - 1); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| constexpr float relative_tolerance = 1e-5; | |||
| constexpr float abs_tolerance = 1e-4; | |||
| @@ -510,37 +433,28 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) { | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { | |||
| // 0.2 Parse input calib files | |||
| auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CollectCalibInputs fail"; | |||
| return RET_ERROR; | |||
| } | |||
| MS_LOG(DEBUG) << "run fp32 model"; | |||
| status = RunFp32Graph(func_graph); | |||
| if (status != RET_OK) { | |||
| return RET_ERROR; | |||
| } | |||
| STATUS WeightQuantizer::DoMixedQuantize(const FuncGraphPtr &func_graph) { | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| int status = RET_OK; | |||
| for (auto &cnode : cnodes) { | |||
| auto op_type = NodePrimitiveType(cnode); | |||
| if (op_type == schema::PrimitiveType_Lstm) { | |||
| status = DoLstmQuntize(cnode); | |||
| status = DoLstmQuantize(cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoLstmQuntize error"; | |||
| MS_LOG(ERROR) << "DoLstmQuantize error"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (op_type == schema::PrimitiveType_Gather) { | |||
| status = DoGatherQuntize(cnode); | |||
| status = DoGatherQuantize(cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoGatherQuntize error"; | |||
| MS_LOG(ERROR) << "DoGatherQuantize error"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| return status; | |||
| } | |||
| STATUS WeightQuantizer::CheckImageCnt() { | |||
| auto image_cnt = images_.at(0).size(); | |||
| if (!config_param_.input_shapes.empty()) { | |||
| if (config_param_.input_shapes.size() != image_cnt) { | |||
| @@ -548,7 +462,62 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name, | |||
| ParameterPtr *param_node, ParamValueLitePtr *param_value) { | |||
| if (!input_node->isa<Parameter>()) { | |||
| MS_LOG(WARNING) << op_name << " the second input is not parameter"; | |||
| return RET_CONTINUE; | |||
| } | |||
| *param_node = input_node->cast<ParameterPtr>(); | |||
| if (!(*param_node)->has_default()) { | |||
| MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; | |||
| return RET_CONTINUE; | |||
| } | |||
| *param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param()); | |||
| if (*param_value == nullptr) { | |||
| MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; | |||
| return RET_CONTINUE; | |||
| } | |||
| if ((*param_value)->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(WARNING) << op_name << " the second input type is not float"; | |||
| return RET_CONTINUE; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr ¶m_node, | |||
| const ParamValueLitePtr ¶m_value, const std::shared_ptr<PrimitiveC> &primitive_c) { | |||
| int status; | |||
| type_id_ = TypeId::kNumberTypeInt8; | |||
| int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; | |||
| int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); | |||
| if (type_id_ == TypeId::kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t, | |||
| bit_num_t, true); | |||
| } else if (type_id_ == TypeId::kNumberTypeInt16) { | |||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t, | |||
| bit_num_t, true); | |||
| } else { | |||
| MS_LOG(ERROR) << "unexpected type_id_: " << type_id_; | |||
| return RET_ERROR; | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "quant filter failed."; | |||
| return RET_ERROR; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| return status; | |||
| } | |||
| STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) { | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| auto image_cnt = images_.at(0).size(); | |||
| int status = RET_OK; | |||
| for (auto iter = cnodes.end(); iter != cnodes.begin();) { | |||
| auto cnode = *(--iter); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| @@ -561,22 +530,10 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { | |||
| << " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type()); | |||
| if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) { | |||
| auto input_node = cnode->input(2); | |||
| if (!input_node->isa<Parameter>()) { | |||
| MS_LOG(WARNING) << op_name << " the second input is not parameter"; | |||
| continue; | |||
| } | |||
| auto param_node = input_node->cast<ParameterPtr>(); | |||
| if (!param_node->has_default()) { | |||
| MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; | |||
| continue; | |||
| } | |||
| auto param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; | |||
| continue; | |||
| } | |||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(WARNING) << op_name << " the second input type is not float"; | |||
| ParameterPtr param_node; | |||
| ParamValueLitePtr param_value; | |||
| status = GetParamNodeAndValue(input_node, op_name, ¶m_node, ¶m_value); | |||
| if (status == RET_CONTINUE) { | |||
| continue; | |||
| } | |||
| // copy origin data in case to recover | |||
| @@ -591,27 +548,9 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { | |||
| } | |||
| // 1. try quant | |||
| for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { | |||
| type_id_ = TypeId::kNumberTypeInt8; | |||
| int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; | |||
| int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); | |||
| if (type_id_ == TypeId::kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | |||
| quant_min_t, bit_num_t, true); | |||
| } else if (type_id_ == TypeId::kNumberTypeInt16) { | |||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | |||
| quant_min_t, bit_num_t, true); | |||
| } else { | |||
| MS_LOG(ERROR) << "unexpected type_id_: " << type_id_; | |||
| return RET_ERROR; | |||
| } | |||
| status = TryQuant(bit_num_t, param_node, param_value, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "quant filter fail."; | |||
| return RET_ERROR; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| MS_LOG(ERROR) << "TryQuant failed."; | |||
| return RET_ERROR; | |||
| } | |||
| // 2. evaluate the quant | |||
| @@ -679,6 +618,41 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { | |||
| free(origin_data); | |||
| } // if: conv and matmul | |||
| } // end loop: all cnode | |||
| return status; | |||
| } | |||
| STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { | |||
| // 0.2 Parse input calib files | |||
| auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CollectCalibInputs failed."; | |||
| return RET_ERROR; | |||
| } | |||
| status = RunFp32Graph(func_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "RunFp32Graph failed."; | |||
| return RET_ERROR; | |||
| } | |||
| status = DoMixedQuantize(func_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoMixedQuantize failed."; | |||
| return RET_ERROR; | |||
| } | |||
| status = CheckImageCnt(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CheckImageCnt failed."; | |||
| return RET_ERROR; | |||
| } | |||
| status = DoQuantSearch(func_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoQuantSearch failed."; | |||
| return RET_ERROR; | |||
| } | |||
| for (const auto &kv : opname_bit_) { | |||
| MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second; | |||
| } | |||
| @@ -709,15 +683,15 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) { | |||
| return RET_ERROR; | |||
| } | |||
| } else if (op_type == schema::PrimitiveType_Lstm) { | |||
| auto status = DoLstmQuntize(cnode); | |||
| auto status = DoLstmQuantize(cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoLstmQuntize error"; | |||
| MS_LOG(ERROR) << "DoLstmQuantize error"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (op_type == schema::PrimitiveType_Gather) { | |||
| auto status = DoGatherQuntize(cnode); | |||
| auto status = DoGatherQuantize(cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoGatherQuntize error"; | |||
| MS_LOG(ERROR) << "DoGatherQuantize error"; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| @@ -36,18 +36,18 @@ | |||
| namespace mindspore::lite::quant { | |||
| class WeightQuantizer : public Quantizer { | |||
| public: | |||
| WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize, | |||
| const std::string &covWeightChannelThreshold, const std::string &bitNum); | |||
| WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config); | |||
| WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config); | |||
| ~WeightQuantizer(); | |||
| STATUS DoQuantize(FuncGraphPtr func_graph) override; | |||
| STATUS DoConvQuantize(CNodePtr); | |||
| STATUS DoMulQuantize(CNodePtr); | |||
| STATUS DoLstmQuntize(CNodePtr cnode); | |||
| STATUS DoGatherQuntize(CNodePtr cnode); | |||
| static STATUS WeightQuantInputCheck(const converter::Flags *config); | |||
| static bool IsPosNum(const std::string &str); | |||
| STATUS DoLstmQuantize(CNodePtr cnode); | |||
| STATUS DoGatherQuantize(CNodePtr cnode); | |||
| STATUS ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c, | |||
| const int &index); | |||
| int quant_max_{127}; | |||
| int quant_min_{-128}; | |||
| @@ -66,6 +66,14 @@ class WeightQuantizer : public Quantizer { | |||
| STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c); | |||
| STATUS DoFixedQuant(FuncGraphPtr); | |||
| STATUS RunFp32Graph(FuncGraphPtr); | |||
| STATUS DoMixedQuantize(const FuncGraphPtr &func_graph); | |||
| STATUS CheckImageCnt(); | |||
| STATUS GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name, | |||
| ParameterPtr *param_node, ParamValueLitePtr *param_value); | |||
| STATUS TryQuant(const int &bit_num_t, const ParameterPtr ¶m_node, const ParamValueLitePtr ¶m_value, | |||
| const std::shared_ptr<PrimitiveC> &primitive_c); | |||
| STATUS DoQuantSearch(const FuncGraphPtr &func_graph); | |||
| }; | |||
| } // namespace mindspore::lite::quant | |||
| #endif | |||