
!6487 add weight quant input param check

Merge pull request !6487 from wangchangkai/master
Tag: v1.0.0
Merged by mindspore-ci-bot · 5 years ago
Commit: aef2eb6caa
4 changed files with 32 additions and 4 deletions:
  1. mindspore/lite/tools/converter/anf_transform.cc (+2, -3)
  2. mindspore/lite/tools/converter/quantizer/quantize_util.h (+1, -1)
  3. mindspore/lite/tools/converter/quantizer/weight_quantizer.cc (+27, -0)
  4. mindspore/lite/tools/converter/quantizer/weight_quantizer.h (+2, -0)

mindspore/lite/tools/converter/anf_transform.cc (+2, -3)

@@ -80,9 +80,8 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
       return nullptr;
     }
   } else if (config->quantType == schema::QuantType_WeightQuant) {
-    auto bitNum = static_cast<size_t>(std::stoull(config->bitNum));
-    if (bitNum != quant::UINT8_QUANTIZATION) {
-      MS_LOG(ERROR) << "Current Only Support 8 bit weight quant";
+    if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
+      MS_LOG(ERROR) << "weight quant input param error";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
       return nullptr;
     }
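
A note on why the conversion moved behind a string check: std::stoull throws std::invalid_argument when the string contains no leading digits at all, and silently accepts trailing garbage (stoull("8x") returns 8), so the removed code could either abort the converter on a malformed bitNum value or quietly accept one. A minimal standalone sketch (not converter code) demonstrating both standard-library behaviors:

#include <iostream>
#include <stdexcept>
#include <string>

int main() {
  // No leading digits: no conversion is performed and stoull throws.
  try {
    std::stoull("abc");
  } catch (const std::invalid_argument &) {
    std::cout << "stoull(\"abc\") throws std::invalid_argument" << std::endl;
  }
  // Trailing garbage: stoull parses the leading "8" and ignores the rest.
  std::cout << "stoull(\"8x\") = " << std::stoull("8x") << std::endl;
  return 0;
}

Validating the flag as a string first lets the converter report a clean MS_LOG error instead of relying on stoull's behavior.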


mindspore/lite/tools/converter/quantizer/quantize_util.h (+1, -1)

@@ -124,7 +124,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
   auto dims = weight->tensor_shape();
   if (per_channel) {
     if (dims.size() != 4) {
-      MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
+      MS_LOG(ERROR) << "weight dims size: " << dims.size() << " switch to per-layer quant mode.";
       per_channel = false;
     } else {
       uint32_t channels = dims[0];
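
The reworded message makes the behavior explicit: a weight tensor that is not 4-D is not a hard failure, it just disables per-channel mode for that tensor. A hypothetical standalone sketch of the fallback decision (names are illustrative, not the converter's API):

#include <iostream>
#include <vector>

// Hypothetical helper mirroring the fallback in QuantFilter: per-channel
// weight quant assumes 4-D conv weights; anything else drops to per-layer.
bool ChoosePerChannel(const std::vector<int> &dims, bool requested_per_channel) {
  if (!requested_per_channel) {
    return false;
  }
  if (dims.size() != 4) {
    std::cout << "weight dims size: " << dims.size()
              << " switch to per-layer quant mode." << std::endl;
    return false;
  }
  return true;
}

int main() {
  std::cout << ChoosePerChannel({16, 3, 3, 3}, true) << std::endl;  // 1: per-channel
  std::cout << ChoosePerChannel({16, 3}, true) << std::endl;        // 0: falls back
  return 0;
}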


mindspore/lite/tools/converter/quantizer/weight_quantizer.cc (+27, -0)

@@ -27,6 +27,33 @@ using std::vector;
 namespace mindspore {
 namespace lite {
 namespace quant {
+bool WeightQuantizer::IsPosNum(const std::string &str) {
+  for (size_t i = 0; i < str.size(); i++) {
+    if (str.at(i) < '0' || str.at(i) > '9') {
+      return false;
+    }
+    if (str.at(i) == '0' && i == 0 && str.size() != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
+  MS_ASSERT(config != nullptr);
+  if (!WeightQuantizer::IsPosNum(config->convWeightQuantChannelThreshold)) {
+    MS_LOG(ERROR) << "convWeightQuantChannelThreshold must be valid pos num.";
+    return RET_ERROR;
+  }
+  if (!WeightQuantizer::IsPosNum(config->quantSize)) {
+    MS_LOG(ERROR) << "quantSize must be valid pos num.";
+    return RET_ERROR;
+  }
+  if (!WeightQuantizer::IsPosNum(config->bitNum) || config->bitNum != "8") {
+    MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 bit weight quant.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
 WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
                                  const std::string &convWeightChannelThreshold, const std::string &bitNum)
     : Quantizer(graph) {
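
IsPosNum accepts only decimal-digit strings with no leading zero (except the single string "0"); note that an empty string falls through the loop and is also accepted. A standalone copy of the check, runnable outside the converter, with a few sample inputs:

#include <iostream>
#include <string>

// Standalone copy of the IsPosNum check from the diff above.
bool IsPosNum(const std::string &str) {
  for (size_t i = 0; i < str.size(); i++) {
    if (str.at(i) < '0' || str.at(i) > '9') {
      return false;  // non-digit character
    }
    if (str.at(i) == '0' && i == 0 && str.size() != 1) {
      return false;  // leading zero on a multi-digit string
    }
  }
  return true;  // note: an empty string never enters the loop and passes
}

int main() {
  for (const std::string s : {"8", "0", "100", "08", "8x", "-1", ""}) {
    std::cout << '"' << s << "\" -> " << (IsPosNum(s) ? "ok" : "rejected") << std::endl;
  }
  return 0;
}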


mindspore/lite/tools/converter/quantizer/weight_quantizer.h (+2, -0)

@@ -41,6 +41,8 @@ class WeightQuantizer : public Quantizer {
   STATUS DoQuantize(FuncGraphPtr funcGraph) override;
   STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
   STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
+  static STATUS WeightQuantInputCheck(const converter::Flags *config);
+  static bool IsPosNum(const std::string &str);
   int quant_max{INT8_MAX};
   int quant_min{INT8_MIN};
  private:

