Browse Source

!11695 [MS][LITE]huffman code support 1~8 bit && change it to internal interface

From: @jianghui58
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
5412b6ba3f
14 changed files with 379 additions and 344 deletions
  1. +17
    -2
      mindspore/lite/src/dequant.cc
  2. +16
    -1
      mindspore/lite/src/dequant.h
  3. +4
    -4
      mindspore/lite/src/huffman_decode.cc
  4. +3
    -3
      mindspore/lite/src/huffman_decode.h
  5. +11
    -26
      mindspore/lite/src/lite_session.cc
  6. +2
    -2
      mindspore/lite/test/run_benchmark_nets.sh
  7. +12
    -19
      mindspore/lite/tools/converter/anf_transform.cc
  8. +1
    -1
      mindspore/lite/tools/converter/anf_transform.h
  9. +43
    -29
      mindspore/lite/tools/converter/converter_flags.cc
  10. +10
    -7
      mindspore/lite/tools/converter/converter_flags.h
  11. +65
    -42
      mindspore/lite/tools/converter/quantizer/huffman_encode.cc
  12. +9
    -4
      mindspore/lite/tools/converter/quantizer/huffman_encode.h
  13. +172
    -198
      mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
  14. +14
    -6
      mindspore/lite/tools/converter/quantizer/weight_quantizer.h

+ 17
- 2
mindspore/lite/src/dequant.cc View File

@@ -14,7 +14,10 @@
* limitations under the License. * limitations under the License.
*/ */
#include <cmath> #include <cmath>
#include <string>
#include <memory>
#include "src/dequant.h" #include "src/dequant.h"
#include "src/huffman_decode.h"


namespace mindspore::lite { namespace mindspore::lite {
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
@@ -34,13 +37,24 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
} }
} }


void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
MS_ASSERT(input_tensor != nullptr); MS_ASSERT(input_tensor != nullptr);
MS_ASSERT(unpack_int_data != nullptr); MS_ASSERT(unpack_int_data != nullptr);
auto quant_params = input_tensor->quantParams(); auto quant_params = input_tensor->quantParams();
if (quant_params == nullptr) { if (quant_params == nullptr) {
MS_LOG(ERROR) << "low bits quantparams is empty."; MS_LOG(ERROR) << "low bits quantparams is empty.";
return;
return RET_ERROR;
}
auto enable_huffman_code = input_tensor->enableHuffmanCode();
if (enable_huffman_code) {
std::string encode_str(input_tensor->data()->begin(), input_tensor->data()->end());
auto huffman_decode = std::make_unique<lite::HuffmanDecode>();
auto ret = huffman_decode->DoHuffmanDecode(encode_str, unpack_int_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoHuffmanDecode failed.";
return ret;
}
return RET_OK;
} }
int origin_bit = quant_params->Get(0)->numBits(); int origin_bit = quant_params->Get(0)->numBits();
if (origin_bit < 8 && origin_bit > 0) { if (origin_bit < 8 && origin_bit > 0) {
@@ -48,6 +62,7 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
} else if (origin_bit < 16 && origin_bit > 8) { } else if (origin_bit < 16 && origin_bit > 8) {
UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data); UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data);
} }
return RET_OK;
} }


std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors, std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,


+ 16
- 1
mindspore/lite/src/dequant.h View File

@@ -31,7 +31,7 @@ class DequantUtil {
public: public:
static float *DequantWeight(lite::Tensor *input_tensor); static float *DequantWeight(lite::Tensor *input_tensor);


static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);


static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors, static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
TypeId data_type, bool need_restore = true); TypeId data_type, bool need_restore = true);
@@ -110,6 +110,21 @@ class DequantUtil {
return dequant_datas; return dequant_datas;
} }


// Walks `pack_size` packed words of type T2 read from `weight_data` and hands each one to
// UnPackData<T1, T2>, which writes into `unpack_int_data`.
//   weight_data     - packed input buffer, reinterpreted as an array of T2.
//   pack_size       - number of T2 words to read from weight_data.
//   origin_bit      - forwarded unchanged to UnPackData (presumably the quantized bit width;
//                     confirm against UnPackData).
//   unpack_int_data - output buffer forwarded to UnPackData.
// Logs an error and returns without doing anything when either pointer is null.
template <typename T1, typename T2>
static void UnpackUtil(const T1 *weight_data, int pack_size, int origin_bit, void *unpack_int_data) {
  if (weight_data == nullptr || unpack_int_data == nullptr) {
    MS_LOG(ERROR) << "data is nullptr";
    return;
  }
  const auto *packed_words = static_cast<const T2 *>(static_cast<const void *>(weight_data));
  // Both accumulators are shared across every UnPackData call for this buffer.
  std::queue<bool> pending_bits;
  size_t unpacked_count = 0;
  for (int idx = 0; idx < pack_size; ++idx) {
    const T2 word = packed_words[idx];
    // The final word is flagged so UnPackData can tell it is the last call.
    const bool final_word = (idx + 1 == pack_size);
    UnPackData<T1, T2>(origin_bit, word, &pending_bits, unpack_int_data, &unpacked_count, final_word);
  }
}

private: private:
template <typename T1, typename T2> template <typename T1, typename T2>
static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int, static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int,


+ 4
- 4
mindspore/lite/src/huffman_decode.cc View File

@@ -19,7 +19,7 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {


STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
if (decoded_data == nullptr) { if (decoded_data == nullptr) {
MS_LOG(ERROR) << "decoded_data is nullptr."; MS_LOG(ERROR) << "decoded_data is nullptr.";
return RET_ERROR; return RET_ERROR;
@@ -64,7 +64,7 @@ STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decod
return RET_OK; return RET_OK;
} }


STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
STATUS HuffmanDecode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
HuffmanNodePtr cur_node, tmp_node, new_node; HuffmanNodePtr cur_node, tmp_node, new_node;


auto huffman_keys = Str2Vec(std::move(keys)); auto huffman_keys = Str2Vec(std::move(keys));
@@ -121,7 +121,7 @@ STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, c
return RET_OK; return RET_OK;
} }


STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
STATUS HuffmanDecode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
HuffmanNodePtr cur_node = root; HuffmanNodePtr cur_node = root;
bool pseudo_eof = false; bool pseudo_eof = false;
size_t pos = 0; size_t pos = 0;
@@ -157,7 +157,7 @@ STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string enco
return RET_OK; return RET_OK;
} }


huffman_decode::~huffman_decode() {
HuffmanDecode::~HuffmanDecode() {
for (auto &node : this->huffman_nodes_) { for (auto &node : this->huffman_nodes_) {
delete node; delete node;
} }


+ 3
- 3
mindspore/lite/src/huffman_decode.h View File

@@ -38,11 +38,11 @@ struct HuffmanNode {
}; };
using HuffmanNodePtr = HuffmanNode *; using HuffmanNodePtr = HuffmanNode *;


class huffman_decode {
class HuffmanDecode {
public: public:
huffman_decode() = default;
HuffmanDecode() = default;


~huffman_decode();
~HuffmanDecode();


STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data); STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data);




+ 11
- 26
mindspore/lite/src/lite_session.cc View File

@@ -28,7 +28,6 @@
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/lite_model.h" #include "src/lite_model.h"
#include "src/dequant.h" #include "src/dequant.h"
#include "src/huffman_decode.h"
#if SUPPORT_NPU #if SUPPORT_NPU
#include "src/runtime/agent/npu/npu_manager.h" #include "src/runtime/agent/npu/npu_manager.h"
#include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h"
@@ -96,13 +95,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
int org_size = dst_tensor->Size(); int org_size = dst_tensor->Size();
return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16); return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16);
}; };
auto NeedHuffmanDecode = [&src_tensor, &dst_tensor]() -> bool {
auto data_type = src_tensor->dataType();
auto enable_huffman_code = src_tensor->enableHuffmanCode();
int pack_size = src_tensor->data()->size();
int org_size = dst_tensor->Size();
return (pack_size != org_size) && (data_type == kNumberTypeInt8) && enable_huffman_code;
};
auto src_category = TensorCategory(src_tensor); auto src_category = TensorCategory(src_tensor);
if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) && if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) &&
src_tensor->data() != nullptr && src_tensor->data()->size() > 0) { src_tensor->data() != nullptr && src_tensor->data()->size() > 0) {
@@ -116,21 +108,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
return RET_ERROR; return RET_ERROR;
} }
} else { } else {
if (NeedHuffmanDecode()) {
auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Data from tensor is nullptr";
return RET_NULL_PTR;
}
std::string encode_str(src_tensor->data()->begin(), src_tensor->data()->end());
auto huffman_decode = std::make_unique<lite::huffman_decode>();
auto ret = huffman_decode->DoHuffmanDecode(encode_str, dst_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoHuffmanDecode failed.";
return ret;
}
copyed_tensor_idxes_.emplace_back(tensor_index);
}
if (WeightTensorNeedCopy(model, tensor_index)) { if (WeightTensorNeedCopy(model, tensor_index)) {
auto dst_data = dst_tensor->MutableData(); auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) { if (dst_data == nullptr) {
@@ -138,7 +115,11 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
return RET_NULL_PTR; return RET_NULL_PTR;
} }
if (NeedUnPack()) { if (NeedUnPack()) {
DequantUtil::UnPackToInt(src_tensor, dst_data);
auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "unpack to int failed.";
return RET_NULL_PTR;
}
} else { } else {
memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size()); memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
} }
@@ -148,9 +129,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
auto dst_data = dst_tensor->MutableData(); auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) { if (dst_data == nullptr) {
MS_LOG(ERROR) << "Data from tensor is nullptr"; MS_LOG(ERROR) << "Data from tensor is nullptr";
return RET_NULL_PTR;
return RET_ERROR;
}
auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "unpack to int failed.";
return RET_ERROR;
} }
DequantUtil::UnPackToInt(src_tensor, dst_data);
copyed_tensor_idxes_.emplace_back(tensor_index); copyed_tensor_idxes_.emplace_back(tensor_index);
} else { } else {
dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data())); dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));


+ 2
- 2
mindspore/lite/test/run_benchmark_nets.sh View File

@@ -227,8 +227,8 @@ function Run_Converter() {
fi fi
model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'`
echo ${model_name} >> "${run_converter_log_file}" echo ${model_name} >> "${run_converter_log_file}"
echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true' >> "${run_converter_log_file}"
./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true
echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0' >> "${run_converter_log_file}"
./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0
if [ $? = 0 ]; then if [ $? = 0 ]; then
converter_result='converter weight_quant '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} converter_result='converter weight_quant '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
else else


+ 12
- 19
mindspore/lite/tools/converter/anf_transform.cc View File

@@ -217,26 +217,14 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
const FuncGraphPtr &new_graph) { const FuncGraphPtr &new_graph) {
// quant // quant
if (config->quantType == schema::QuantType_PostTraining) { if (config->quantType == schema::QuantType_PostTraining) {
if (!quant::WeightQuantizer::IsPosNum(config->bitNum)) {
MS_LOG(ERROR) << "bitNum must be valid pos num.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
this->mQuantizer =
std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, std::stoi(config->bitNum));
this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, config->bitNum);
if (mQuantizer == nullptr) { if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR; return RET_ERROR;
} }
} else if (config->quantType == schema::QuantType_WeightQuant) { } else if (config->quantType == schema::QuantType_WeightQuant) {
if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
MS_LOG(ERROR) << "weight quant input param error";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize,
config->quantWeightChannel, config->bitNum);
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, *config);
if (mQuantizer == nullptr) { if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed"; MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
@@ -255,10 +243,15 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
return RET_OK; return RET_OK;
} }


int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph) {
if (config->quantType == schema::QuantType_WeightQuant && config->bitNum == "8" && config->enableHuffmanCode) {
auto huffman_encode = std::make_unique<lite::huffman_encode>();
auto status = huffman_encode->DoHuffmanEncode(new_graph);
int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph,
bool enableHuffmanCode) {
if (config->quantType == schema::QuantType_WeightQuant && enableHuffmanCode) {
if (config->bitNum < 16 && config->bitNum > 8) {
MS_LOG(WARNING) << "don't support huffman encode when 8 < bitNum < 16 currently.";
return RET_OK;
}
auto huffman_encode = std::make_unique<lite::HuffmanEncode>();
auto status = huffman_encode->DoHuffmanEncode(new_graph, config->bitNum);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Huffman encode failed."; MS_LOG(ERROR) << "Huffman encode failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -322,7 +315,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
return nullptr; return nullptr;
} }


status = DoHuffmanEncode(config, new_graph);
status = DoHuffmanEncode(config, new_graph, false);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Do HuffmanCode failed."; MS_LOG(ERROR) << "Do HuffmanCode failed.";
return nullptr; return nullptr;


+ 1
- 1
mindspore/lite/tools/converter/anf_transform.h View File

@@ -59,7 +59,7 @@ class AnfTransform {


int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph); int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph);


int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph);
int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, bool enableHuffmanCode);
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 43
- 29
mindspore/lite/tools/converter/converter_flags.cc View File

@@ -38,16 +38,12 @@ Flags::Flags() {
"UINT8 | DEFAULT", "UINT8 | DEFAULT",
"DEFAULT"); "DEFAULT");
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", ""); AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8");
AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0");
AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16");
AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8");
AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0");
AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16");
AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", ""); AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", "");
AddFlag(&Flags::enableHuffmanCodeIn, "enableHuffmanCode",
"whether the weight quant model is going to use huffman code."
"true | false",
"false");
AddFlag(&Flags::trainModelIn, "trainModel", AddFlag(&Flags::trainModelIn, "trainModel",
"whether the model is going to be trained on device."
"whether the model is going to be trained on device. "
"true | false", "true | false",
"false"); "false");
} }
@@ -107,7 +103,41 @@ int Flags::InitFmk() {
return RET_OK; return RET_OK;
} }


int Flags::InitQuantType() {
// Parses `str` as a base-10 integer and stores it in `*num`.
//   str - text to parse; the whole string must be consumed by strtol.
//   num - output; written only when the function returns true.
// Returns true only when `str` is non-empty, is fully consumed by strtol, and the parsed value
// survives narrowing to int unchanged. An empty string is rejected explicitly: strtol consumes
// nothing in that case, so the end pointer alone would report success.
bool Flags::IsValidNum(const std::string &str, int *num) {
  if (num == nullptr || str.empty()) {
    return false;
  }
  char *end = nullptr;
  const auto parsed = strtol(str.c_str(), &end, 10);
  if (end != str.c_str() + str.size()) {
    return false;
  }
  // strtol yields a wider type than int on LP64 targets; refuse values that would be truncated.
  const int narrowed = static_cast<int>(parsed);
  if (narrowed != parsed) {
    return false;
  }
  *num = narrowed;
  return true;
}

// Converts the string-valued quantization flags into their integer members and range-checks them.
//   quantWeightChannelIn -> quantWeightChannel, must be >= 0
//   quantWeightSizeIn    -> quantWeightSize,    must be >= 0
//   bitNumIn             -> bitNum,             must be in [1, 16]
// Returns RET_INPUT_PARAM_INVALID (after writing a message to std::cerr) on the first flag that
// fails to parse or is out of range, RET_OK otherwise.
int Flags::QuantParamInputCheck() {
  if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) {
    std::cerr << "quantWeightChannel should be a valid number.";
    return RET_INPUT_PARAM_INVALID;
  }
  if (this->quantWeightChannel < 0) {
    std::cerr << "quantWeightChannel should be greater than or equal to zero.";
    return RET_INPUT_PARAM_INVALID;
  }
  if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) {
    std::cerr << "quantWeightSize should be a valid number.";
    return RET_INPUT_PARAM_INVALID;
  }
  if (this->quantWeightSize < 0) {
    std::cerr << "quantWeightSize should be greater than or equal to zero.";
    return RET_INPUT_PARAM_INVALID;
  }
  if (!Flags::IsValidNum(this->bitNumIn, &this->bitNum)) {
    std::cerr << "bitNum should be a valid number.";
    return RET_INPUT_PARAM_INVALID;
  }
  // 16 itself is accepted by this condition, so the message states an inclusive upper bound.
  if (this->bitNum <= 0 || this->bitNum > 16) {
    std::cerr << "bitNum should be greater than zero and not greater than 16 currently.";
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int Flags::InitQuantParam() {
if (this->quantTypeIn == "WeightQuant") { if (this->quantTypeIn == "WeightQuant") {
this->quantType = QuantType_WeightQuant; this->quantType = QuantType_WeightQuant;
} else if (this->quantTypeIn == "PostTraining") { } else if (this->quantTypeIn == "PostTraining") {
@@ -118,19 +148,9 @@ int Flags::InitQuantType() {
std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining"; std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
return RET_OK;
}


int Flags::InitHuffmanCode() {
if (this->enableHuffmanCodeIn == "true") {
this->enableHuffmanCode = true;
} else if (this->enableHuffmanCodeIn == "false") {
this->enableHuffmanCode = false;
} else {
std::cerr << "INPUT ILLEGAL: trainModel must be true|false ";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
auto ret = QuantParamInputCheck();
return ret;
} }


int Flags::InitTrainModel() { int Flags::InitTrainModel() {
@@ -218,15 +238,9 @@ int Flags::Init(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }


ret = InitQuantType();
if (ret != RET_OK) {
std::cerr << "Init quant type failed.";
return RET_INPUT_PARAM_INVALID;
}

ret = InitHuffmanCode();
ret = InitQuantParam();
if (ret != RET_OK) { if (ret != RET_OK) {
std::cerr << "Init huffman code failed.";
std::cerr << "Init quant param failed.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }




+ 10
- 7
mindspore/lite/tools/converter/converter_flags.h View File

@@ -49,9 +49,11 @@ class Flags : public virtual mindspore::lite::FlagParser {


int InitFmk(); int InitFmk();


int InitQuantType();
bool IsValidNum(const std::string &str, int *num);


int InitHuffmanCode();
int QuantParamInputCheck();

int InitQuantParam();


int InitTrainModel(); int InitTrainModel();


@@ -76,12 +78,13 @@ class Flags : public virtual mindspore::lite::FlagParser {
TypeId inputDataType; TypeId inputDataType;
TypeId outputDataType; TypeId outputDataType;
// used for post-training-weight // used for post-training-weight
std::string quantWeightSize;
std::string bitNum;
std::string quantWeightSizeIn;
int quantWeightSize;
std::string bitNumIn;
int bitNum;
std::string configFile; std::string configFile;
std::string quantWeightChannel;
std::string enableHuffmanCodeIn;
bool enableHuffmanCode = false;
std::string quantWeightChannelIn;
int quantWeightChannel;
std::string trainModelIn; std::string trainModelIn;
bool trainModel = false; bool trainModel = false;
}; };


+ 65
- 42
mindspore/lite/tools/converter/quantizer/huffman_encode.cc View File

@@ -18,18 +18,51 @@


#include <utility> #include <utility>
#include <iostream> #include <iostream>
#include <memory>
#include <vector>


#include "securec/include/securec.h"
#include "src/param_value_lite.h"
#include "src/dequant.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {


STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
// Fetches the default ParamValueLite of `input_node` when it is a Parameter holding int8 data.
//   input_node  - graph node to inspect.
//   param_value - output; set only when RET_OK is returned.
// Returns:
//   RET_OK       - *param_value now holds the parameter's default value.
//   RET_CONTINUE - the node should be skipped: it is not a Parameter, its element type is not
//                  kNumberTypeInt8, or it has no default value.
//   RET_ERROR    - a null argument, or the node's abstract / element / type information is
//                  missing or is not an abstract tensor.
// Failures deliberately use RET_ERROR rather than a more specific code, because that is the only
// failure value the visible caller in DoHuffmanEncode checks for.
STATUS HuffmanEncode::GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value) {
  if (input_node == nullptr || param_value == nullptr) {
    MS_LOG(ERROR) << "input_node or param_value is nullptr";
    return RET_ERROR;
  }
  if (!input_node->isa<Parameter>()) {
    return RET_CONTINUE;
  }
  auto abstract_base = input_node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
    return RET_ERROR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
  if (abstract_tensor->element() == nullptr) {
    MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
    return RET_ERROR;
  }
  auto tensor_type = abstract_tensor->element()->GetTypeTrack();
  // Checked explicitly: tensor_type is dereferenced on the next statement.
  if (tensor_type == nullptr) {
    MS_LOG(ERROR) << "tensor type is nullptr, " << input_node->fullname_with_scope();
    return RET_ERROR;
  }
  auto tensor_type_id = tensor_type->type_id();
  if (tensor_type_id != kNumberTypeInt8) {
    return RET_CONTINUE;
  }
  auto param_node = input_node->cast<ParameterPtr>();
  if (param_node == nullptr) {
    MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
    return RET_ERROR;
  }
  if (!param_node->has_default()) {
    MS_LOG(WARNING) << "param_node don't have default: " << input_node->fullname_with_scope();
    return RET_CONTINUE;
  }
  *param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
  return RET_OK;
}

STATUS HuffmanEncode::DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num) {
auto cnodes = func_graph->GetOrderedCnodes(); auto cnodes = func_graph->GetOrderedCnodes();
STATUS status;
for (auto &cnode : cnodes) { for (auto &cnode : cnodes) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) { if (primitive_c == nullptr) {
@@ -41,45 +74,33 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
} }
for (size_t i = 1; i < cnode->inputs().size(); i++) { for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i); auto input_node = cnode->input(i);
if (!input_node->isa<Parameter>()) {
continue;
}
auto abstract_base = input_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
if (abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto tensor_type = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(tensor_type != nullptr);
auto tensor_type_id = tensor_type->type_id();
if (tensor_type_id != kNumberTypeInt8) {
ParamValueLitePtr param_value;
auto status = GetParamValueLitePtr(input_node, &param_value);
if (status == RET_CONTINUE) {
continue; continue;
}
auto param_node = input_node->cast<ParameterPtr>();
if (param_node == nullptr) {
MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
} else if (status == RET_ERROR) {
MS_LOG(ERROR) << "Get param value lite ptr failed. " << cnode->fullname_with_scope();
return RET_ERROR; return RET_ERROR;
} }
if (!param_node->has_default()) {
MS_LOG(WARNING) << "param_node don't have default: " << cnode->fullname_with_scope();
continue;
}
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
size_t elem_count = param_value->tensor_shape_size(); size_t elem_count = param_value->tensor_shape_size();
size_t packed_size = param_value->tensor_size();
auto *raw_datas = static_cast<int8_t *>(param_value->tensor_addr()); auto *raw_datas = static_cast<int8_t *>(param_value->tensor_addr());
if (raw_datas == nullptr) { if (raw_datas == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr"; MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR; return RET_ERROR;
} }
if (bit_num < 8 && bit_num > 0) {
auto dst_data = new (std::nothrow) int8_t[elem_count];
if (dst_data == nullptr) {
MS_LOG(ERROR) << "new int8_t[] failed";
return RET_ERROR;
}
DequantUtil::UnpackUtil<int8_t, uint8_t>(raw_datas, packed_size, bit_num, dst_data);
if (memcpy_s(raw_datas, elem_count, dst_data, elem_count) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_MEMORY_FAILED;
}
}
HuffmanPriorityQueue pq; HuffmanPriorityQueue pq;
status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq); status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq);
if (status != RET_OK) { if (status != RET_OK) {
@@ -97,12 +118,14 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
return status; return status;
} }
size_t ch_size = huffman_encoded_str_.length(); size_t ch_size = huffman_encoded_str_.length();
if (ch_size < elem_count) {
if (ch_size < packed_size) {
auto encode_data = new (std::nothrow) char[ch_size]; auto encode_data = new (std::nothrow) char[ch_size];
if (encode_data == nullptr) { if (encode_data == nullptr) {
MS_LOG(ERROR) << "new char[] failed."; MS_LOG(ERROR) << "new char[] failed.";
delete[] raw_datas;
return RET_MEMORY_FAILED; return RET_MEMORY_FAILED;
} }
delete[] raw_datas;
if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) { if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed."; MS_LOG(ERROR) << "memcpy_s failed.";
delete[] encode_data; delete[] encode_data;
@@ -118,7 +141,7 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
return RET_SUCCESS; return RET_SUCCESS;
} }


STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
STATUS HuffmanEncode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
MS_ASSERT(data != nullptr); MS_ASSERT(data != nullptr);


std::map<int8_t, size_t> freq_map; std::map<int8_t, size_t> freq_map;
@@ -166,7 +189,7 @@ STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t
return RET_OK; return RET_OK;
} }


void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
void HuffmanEncode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
if (is_left_node) { if (is_left_node) {
node->code = node->parent->code + "0"; node->code = node->parent->code + "0";
} else { } else {
@@ -185,7 +208,7 @@ void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_lef
} }
} }


STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
STATUS HuffmanEncode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
HuffmanNodePtr root = nullptr; HuffmanNodePtr root = nullptr;


while (!pq->empty()) { while (!pq->empty()) {
@@ -228,7 +251,7 @@ STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
return RET_OK; return RET_OK;
} }


STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
STATUS HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
unsigned char out_c; unsigned char out_c;
string code_str; string code_str;
std::map<int, string>::iterator iter; std::map<int, string>::iterator iter;
@@ -270,7 +293,7 @@ STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t
return RET_OK; return RET_OK;
} }


huffman_encode::~huffman_encode() {
HuffmanEncode::~HuffmanEncode() {
for (auto &node : this->huffman_nodes_) { for (auto &node : this->huffman_nodes_) {
delete node; delete node;
} }


+ 9
- 4
mindspore/lite/tools/converter/quantizer/huffman_encode.h View File

@@ -23,9 +23,12 @@
#include <vector> #include <vector>
#include <queue> #include <queue>
#include <map> #include <map>
#include <memory>
#include <fstream> #include <fstream>
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "src/ops/primitive_c.h" #include "src/ops/primitive_c.h"
#include "securec/include/securec.h"
#include "src/param_value_lite.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"


namespace mindspore { namespace mindspore {
@@ -49,13 +52,15 @@ struct cmp {
}; };
using HuffmanPriorityQueue = std::priority_queue<HuffmanNodePtr, std::vector<HuffmanNodePtr>, cmp>; using HuffmanPriorityQueue = std::priority_queue<HuffmanNodePtr, std::vector<HuffmanNodePtr>, cmp>;


class huffman_encode {
class HuffmanEncode {
public: public:
huffman_encode() = default;
HuffmanEncode() = default;


~huffman_encode();
~HuffmanEncode();


STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph);
STATUS GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value);

STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num);


private: private:
std::map<int, std::string> huffman_table_; std::map<int, std::string> huffman_table_;


+ 172
- 198
mindspore/lite/tools/converter/quantizer/weight_quantizer.cc View File

@@ -25,52 +25,16 @@ using std::string;
using std::vector; using std::vector;


namespace mindspore::lite::quant { namespace mindspore::lite::quant {
// Checks whether `str` is a canonical, non-negative decimal integer literal.
//
// Accepted: one or more characters in '0'..'9' with no leading zero, except
// for the single-character string "0".
// Rejected: the empty string, any non-digit character (including signs and
// whitespace) and multi-digit strings that start with '0'.
//
// Returns true when `str` is accepted, false otherwise.
bool WeightQuantizer::IsPosNum(const std::string &str) {
  // An empty string holds no digits. Accepting it would let callers forward ""
  // to std::stoi / std::stoull, which throw std::invalid_argument.
  if (str.empty()) {
    return false;
  }
  // A leading zero is only legal when the whole string is "0".
  if (str.size() != 1 && str.front() == '0') {
    return false;
  }
  for (const char ch : str) {
    if (ch < '0' || ch > '9') {
      return false;
    }
  }
  return true;
}

// Validates the weight-quantization options held as strings in `config`.
//
// Checks performed:
//   * quantWeightChannel, quantWeightSize and bitNum must each be a non-empty
//     canonical decimal number (see IsPosNum);
//   * bitNum must lie in the inclusive range [1, 16].
//
// Returns RET_OK when every check passes and RET_ERROR (after logging the
// reason) otherwise, including when `config` is null.
STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
  MS_ASSERT(config != nullptr);
  // MS_ASSERT is used as a debug aid; keep an explicit guard so a null config
  // is reported instead of dereferenced.
  if (config == nullptr) {
    MS_LOG(ERROR) << "config is nullptr.";
    return RET_ERROR;
  }

  // Emptiness is checked here as well, so that this function never forwards an
  // empty string to std::stoi below.
  const auto is_valid_number = [](const std::string &text) {
    return !text.empty() && WeightQuantizer::IsPosNum(text);
  };

  if (!is_valid_number(config->quantWeightChannel)) {
    MS_LOG(ERROR) << "quantWeightChannel must be valid pos num.";
    return RET_ERROR;
  }
  if (!is_valid_number(config->quantWeightSize)) {
    MS_LOG(ERROR) << "quantWeightSize must be valid pos num.";
    return RET_ERROR;
  }
  if (!is_valid_number(config->bitNum)) {
    MS_LOG(ERROR) << "bitNum must be valid pos num.";
    return RET_ERROR;
  }

  // Every legal value (1..16) has at most two digits. Longer strings are out of
  // range by construction, and skipping std::stoi for them avoids
  // std::out_of_range on very long digit sequences.
  constexpr size_t kMaxBitNumDigits = 2;
  const bool short_enough = config->bitNum.size() <= kMaxBitNumDigits;
  const int bit_num = short_enough ? std::stoi(config->bitNum) : 0;
  if (bit_num <= 0 || bit_num > 16) {
    // 16 itself is accepted, so the message states the inclusive upper bound.
    MS_LOG(ERROR) << "bitNum should be more than 0 and no more than 16 currently.";
    return RET_ERROR;
  }
  return RET_OK;
}

WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) { WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) {
quant_strategy_ = std::make_unique<QuantStrategy>(0, 0); quant_strategy_ = std::make_unique<QuantStrategy>(0, 0);
config_param_ = config; config_param_ = config;
} }


WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize,
const std::string &convWeightChannelThreshold, const std::string &bitNum)
: Quantizer(graph) {
this->config_file_ = config_file;
auto quantSize = static_cast<size_t>(std::stoull(weightSize));
this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(graph) {
this->config_file_ = config.configFile;
auto quantSize = config.quantWeightSize;
this->bit_num_ = config.bitNum;
auto convQuantWeightChannelThreshold = config.quantWeightChannel;
quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1)); quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
@@ -222,7 +186,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
return RET_OK; return RET_OK;
} }


STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
STATUS WeightQuantizer::DoLstmQuantize(CNodePtr cnode) {
MS_ASSERT(cnode != nullptr); MS_ASSERT(cnode != nullptr);
auto op_name = cnode->fullname_with_scope(); auto op_name = cnode->fullname_with_scope();


@@ -233,110 +197,29 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size(); MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size();
return RET_ERROR; return RET_ERROR;
} }
{
auto weight_i = cnode->input(2);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(weight_i, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_OK;
}
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
<< quant_strategy_->mWeightSize;
return RET_OK;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 1);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}

auto status = ProcessLstmWeightByIndex(cnode, primitive_c, 2);
if (status != RET_OK) {
MS_LOG(ERROR) << "Process lstm weight i failed.";
return RET_ERROR;
} }
{
auto weight_h = cnode->input(3);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(weight_h, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 2);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 2);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
status = ProcessLstmWeightByIndex(cnode, primitive_c, 3);
if (status != RET_OK) {
MS_LOG(ERROR) << "Process lstm weight h failed.";
return RET_ERROR;
}
if (cnode->inputs().size() > 4) {
status = ProcessLstmWeightByIndex(cnode, primitive_c, 4);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
MS_LOG(ERROR) << "Process lstm bias failed.";
return RET_ERROR; return RET_ERROR;
} }
} }
{
if (cnode->inputs().size() > 4) {
auto bias = cnode->input(4);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(bias, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 3);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 3);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
}
}
return RET_OK;

return status;
} }


STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
MS_ASSERT(primitive_c != nullptr); MS_ASSERT(primitive_c != nullptr);


@@ -375,6 +258,46 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
return RET_OK; return RET_OK;
} }


// Quantizes one constant input of an LSTM node and refreshes its abstract.
//
// cnode        LSTM node whose input is processed.
// primitive_c  primitive of `cnode`, forwarded to QuantFilter and SetAbstract.
// index        position of the input inside `cnode` (the callers visible in
//              this file pass 2, 3 and 4).
//
// Returns RET_OK when the input was quantized, or when it was skipped because
// it is not float32 or holds fewer elements than the configured weight-size
// threshold. Returns an error status when the parameter cannot be fetched, the
// target type is unsupported, or QuantFilter / SetAbstract fails.
STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
                                                 const int &index) {
  MS_ASSERT(cnode != nullptr);
  const auto op_name = cnode->fullname_with_scope();
  const auto input_node = cnode->input(index);
  ParameterPtr param_node;
  ParamValueLitePtr param_value;
  GetLiteParameter(input_node, &param_node, &param_value);
  if (param_node == nullptr || param_value == nullptr) {
    MS_LOG(ERROR) << "GetLiteParameter error";
    return RET_ERROR;
  }
  if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
    MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
    return RET_OK;
  }
  // The tensor is float32 at this point, so its element count is the byte size
  // divided by sizeof(float).
  const auto element_cnt = param_value->tensor_size() / sizeof(float);
  if (element_cnt < quant_strategy_->mWeightSize) {
    // Report the actual input index: this helper serves every LSTM input, not only weight_i.
    MS_LOG(INFO) << op_name << " input[" << index << "] cnt: " << element_cnt << " < "
                 << quant_strategy_->mWeightSize;
    return RET_OK;
  }
  STATUS status = RET_ERROR;
  // NOTE(review): the last QuantFilter argument is `index - 1`; its meaning is
  // defined by QuantFilter, which is not visible here - confirm before changing.
  if (type_id_ == kNumberTypeInt8) {
    status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                 false, index - 1);
  } else if (type_id_ == kNumberTypeInt16) {
    status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                  false, index - 1);
  } else {
    // Name the real cause instead of reporting it as a QuantFilter failure.
    MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
    return RET_ERROR;
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "QuantFilter failed : " << status;
    return status;
  }
  status = SetAbstract(param_value, param_node, primitive_c);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "SetAbstract failed : " << status;
    return RET_ERROR;
  }
  return RET_OK;
}

constexpr float relative_tolerance = 1e-5; constexpr float relative_tolerance = 1e-5;
constexpr float abs_tolerance = 1e-4; constexpr float abs_tolerance = 1e-4;


@@ -510,37 +433,28 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) {
return RET_OK; return RET_OK;
} }


STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
// 0.2 Parse input calib files
auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
if (status != RET_OK) {
MS_LOG(ERROR) << "CollectCalibInputs fail";
return RET_ERROR;
}

MS_LOG(DEBUG) << "run fp32 model";
status = RunFp32Graph(func_graph);
if (status != RET_OK) {
return RET_ERROR;
}

STATUS WeightQuantizer::DoMixedQuantize(const FuncGraphPtr &func_graph) {
auto cnodes = func_graph->GetOrderedCnodes(); auto cnodes = func_graph->GetOrderedCnodes();
int status = RET_OK;
for (auto &cnode : cnodes) { for (auto &cnode : cnodes) {
auto op_type = NodePrimitiveType(cnode); auto op_type = NodePrimitiveType(cnode);
if (op_type == schema::PrimitiveType_Lstm) { if (op_type == schema::PrimitiveType_Lstm) {
status = DoLstmQuntize(cnode);
status = DoLstmQuantize(cnode);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "DoLstmQuntize error";
MS_LOG(ERROR) << "DoLstmQuantize error";
return RET_ERROR; return RET_ERROR;
} }
} else if (op_type == schema::PrimitiveType_Gather) { } else if (op_type == schema::PrimitiveType_Gather) {
status = DoGatherQuntize(cnode);
status = DoGatherQuantize(cnode);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "DoGatherQuntize error";
MS_LOG(ERROR) << "DoGatherQuantize error";
return RET_ERROR; return RET_ERROR;
} }
} }
} }
return status;
}
STATUS WeightQuantizer::CheckImageCnt() {
auto image_cnt = images_.at(0).size(); auto image_cnt = images_.at(0).size();
if (!config_param_.input_shapes.empty()) { if (!config_param_.input_shapes.empty()) {
if (config_param_.input_shapes.size() != image_cnt) { if (config_param_.input_shapes.size() != image_cnt) {
@@ -548,7 +462,62 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
return RET_ERROR; return RET_ERROR;
} }
} }

return RET_OK;
}
// Resolves `input_node` to its Parameter and that parameter's default
// ParamValueLite, writing them to `*param_node` and `*param_value`.
//
// Returns RET_OK when the input is a Parameter with a non-null float32 default
// value. In every other case a warning prefixed with `op_name` is logged and
// RET_CONTINUE is returned; outputs written before the failing check keep
// whatever was stored in them.
STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
                                             ParameterPtr *param_node, ParamValueLitePtr *param_value) {
  if (input_node->isa<Parameter>()) {
    *param_node = input_node->cast<ParameterPtr>();
  } else {
    MS_LOG(WARNING) << op_name << " the second input is not parameter";
    return RET_CONTINUE;
  }
  ParameterPtr &node_out = *param_node;

  if (node_out->has_default()) {
    *param_value = std::static_pointer_cast<ParamValueLite>(node_out->default_param());
  } else {
    MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
    return RET_CONTINUE;
  }
  ParamValueLitePtr &value_out = *param_value;

  if (value_out == nullptr) {
    MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
    return RET_CONTINUE;
  }

  // Only float32 weights are handed on to the caller.
  const bool is_float32 = value_out->tensor_type() == TypeId::kNumberTypeFloat32;
  if (!is_float32) {
    MS_LOG(WARNING) << op_name << " the second input type is not float";
    return RET_CONTINUE;
  }
  return RET_OK;
}
// Quantizes `param_value` with a trial bit width and refreshes the abstract of
// `param_node`.
//
// bit_num_t    trial bit width; the symmetric range
//              [-(1 << (bit_num_t - 1)), (1 << (bit_num_t - 1)) - 1] is used.
// param_node   parameter node whose abstract is updated after quantization.
// param_value  weight value quantized in place by QuantFilter.
// primitive_c  primitive forwarded to QuantFilter and SetAbstract.
//
// Side effect: type_id_ is set to kNumberTypeInt8 on every call.
// Returns RET_OK on success and RET_ERROR when QuantFilter or SetAbstract fails.
STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr &param_node,
                                 const ParamValueLitePtr &param_value, const std::shared_ptr<PrimitiveC> &primitive_c) {
  // The target type is forced to int8 here, so the int16 and default branches
  // below are not reached in the current code.
  type_id_ = TypeId::kNumberTypeInt8;

  const auto shift = static_cast<unsigned int>(bit_num_t - 1);
  const int quant_max_t = (1 << shift) - 1;
  const int quant_min_t = -(1 << shift);

  int filter_ret;
  switch (type_id_) {
    case TypeId::kNumberTypeInt8:
      filter_ret = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
                                       quant_min_t, bit_num_t, true);
      break;
    case TypeId::kNumberTypeInt16:
      filter_ret = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
                                        quant_min_t, bit_num_t, true);
      break;
    default:
      MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
      return RET_ERROR;
  }
  if (filter_ret != RET_OK) {
    MS_LOG(ERROR) << "quant filter failed.";
    return RET_ERROR;
  }

  const int abstract_ret = SetAbstract(param_value, param_node, primitive_c);
  if (abstract_ret != RET_OK) {
    MS_LOG(ERROR) << "SetAbstract failed : " << abstract_ret;
    return RET_ERROR;
  }
  // Equal to RET_OK at this point, as in the original flow.
  return abstract_ret;
}
STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) {
auto cnodes = func_graph->GetOrderedCnodes();
auto image_cnt = images_.at(0).size();
int status = RET_OK;
for (auto iter = cnodes.end(); iter != cnodes.begin();) { for (auto iter = cnodes.end(); iter != cnodes.begin();) {
auto cnode = *(--iter); auto cnode = *(--iter);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
@@ -561,22 +530,10 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
<< " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type()); << " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type());
if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) { if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) {
auto input_node = cnode->input(2); auto input_node = cnode->input(2);
if (!input_node->isa<Parameter>()) {
MS_LOG(WARNING) << op_name << " the second input is not parameter";
continue;
}
auto param_node = input_node->cast<ParameterPtr>();
if (!param_node->has_default()) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
continue;
}
auto param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
if (param_value == nullptr) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
continue;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << op_name << " the second input type is not float";
ParameterPtr param_node;
ParamValueLitePtr param_value;
status = GetParamNodeAndValue(input_node, op_name, &param_node, &param_value);
if (status == RET_CONTINUE) {
continue; continue;
} }
// copy origin data in case to recover // copy origin data in case to recover
@@ -591,27 +548,9 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
} }
// 1. try quant // 1. try quant
for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
type_id_ = TypeId::kNumberTypeInt8;
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));

if (type_id_ == TypeId::kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true);
} else if (type_id_ == TypeId::kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true);
} else {
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
return RET_ERROR;
}
status = TryQuant(bit_num_t, param_node, param_value, primitive_c);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "quant filter fail.";
return RET_ERROR;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
MS_LOG(ERROR) << "TryQuant failed.";
return RET_ERROR; return RET_ERROR;
} }
// 2. evaluate the quant // 2. evaluate the quant
@@ -679,6 +618,41 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
free(origin_data); free(origin_data);
} // if: conv and matmul } // if: conv and matmul
} // end loop: all cnode } // end loop: all cnode
return status;
}

STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
// 0.2 Parse input calib files
auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
if (status != RET_OK) {
MS_LOG(ERROR) << "CollectCalibInputs failed.";
return RET_ERROR;
}

status = RunFp32Graph(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "RunFp32Graph failed.";
return RET_ERROR;
}

status = DoMixedQuantize(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoMixedQuantize failed.";
return RET_ERROR;
}

status = CheckImageCnt();
if (status != RET_OK) {
MS_LOG(ERROR) << "CheckImageCnt failed.";
return RET_ERROR;
}

status = DoQuantSearch(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantSearch failed.";
return RET_ERROR;
}

for (const auto &kv : opname_bit_) { for (const auto &kv : opname_bit_) {
MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second; MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second;
} }
@@ -709,15 +683,15 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
return RET_ERROR; return RET_ERROR;
} }
} else if (op_type == schema::PrimitiveType_Lstm) { } else if (op_type == schema::PrimitiveType_Lstm) {
auto status = DoLstmQuntize(cnode);
auto status = DoLstmQuantize(cnode);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "DoLstmQuntize error";
MS_LOG(ERROR) << "DoLstmQuantize error";
return RET_ERROR; return RET_ERROR;
} }
} else if (op_type == schema::PrimitiveType_Gather) { } else if (op_type == schema::PrimitiveType_Gather) {
auto status = DoGatherQuntize(cnode);
auto status = DoGatherQuantize(cnode);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "DoGatherQuntize error";
MS_LOG(ERROR) << "DoGatherQuantize error";
return RET_ERROR; return RET_ERROR;
} }
} else { } else {


+ 14
- 6
mindspore/lite/tools/converter/quantizer/weight_quantizer.h View File

@@ -36,18 +36,18 @@
namespace mindspore::lite::quant { namespace mindspore::lite::quant {
class WeightQuantizer : public Quantizer { class WeightQuantizer : public Quantizer {
public: public:
WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize,
const std::string &covWeightChannelThreshold, const std::string &bitNum);
WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config);
WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config); WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config);
~WeightQuantizer(); ~WeightQuantizer();


STATUS DoQuantize(FuncGraphPtr func_graph) override; STATUS DoQuantize(FuncGraphPtr func_graph) override;
STATUS DoConvQuantize(CNodePtr); STATUS DoConvQuantize(CNodePtr);
STATUS DoMulQuantize(CNodePtr); STATUS DoMulQuantize(CNodePtr);
STATUS DoLstmQuntize(CNodePtr cnode);
STATUS DoGatherQuntize(CNodePtr cnode);
static STATUS WeightQuantInputCheck(const converter::Flags *config);
static bool IsPosNum(const std::string &str);
STATUS DoLstmQuantize(CNodePtr cnode);
STATUS DoGatherQuantize(CNodePtr cnode);

STATUS ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
const int &index);


int quant_max_{127}; int quant_max_{127};
int quant_min_{-128}; int quant_min_{-128};
@@ -66,6 +66,14 @@ class WeightQuantizer : public Quantizer {
STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c); STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
STATUS DoFixedQuant(FuncGraphPtr); STATUS DoFixedQuant(FuncGraphPtr);
STATUS RunFp32Graph(FuncGraphPtr); STATUS RunFp32Graph(FuncGraphPtr);

STATUS DoMixedQuantize(const FuncGraphPtr &func_graph);
STATUS CheckImageCnt();
STATUS GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
ParameterPtr *param_node, ParamValueLitePtr *param_value);
STATUS TryQuant(const int &bit_num_t, const ParameterPtr &param_node, const ParamValueLitePtr &param_value,
const std::shared_ptr<PrimitiveC> &primitive_c);
STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
}; };
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant
#endif #endif

Loading…
Cancel
Save