|
|
|
@@ -27,6 +27,33 @@ using std::vector; |
|
|
|
namespace mindspore { |
|
|
|
namespace lite { |
|
|
|
namespace quant { |
|
|
|
bool WeightQuantizer::IsPosNum(const std::string &str) { |
|
|
|
for (size_t i = 0; i < str.size(); i++) { |
|
|
|
if (str.at(i) < '0' || str.at(i) > '9') { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (str.at(i) == '0' && i == 0 && str.size() != 1) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { |
|
|
|
MS_ASSERT(config != nullptr); |
|
|
|
if (!WeightQuantizer::IsPosNum(config->convWeightQuantChannelThreshold)) { |
|
|
|
MS_LOG(ERROR) << "convWeightQuantChannelThreshold must be valid pos num."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (!WeightQuantizer::IsPosNum(config->quantSize)) { |
|
|
|
MS_LOG(ERROR) << "quantSize must be valid pos num."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (!WeightQuantizer::IsPosNum(config->bitNum) || config->bitNum != "8") { |
|
|
|
MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 bit weight quant."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, |
|
|
|
const std::string &convWeightChannelThreshold, const std::string &bitNum) |
|
|
|
: Quantizer(graph) { |
|
|
|
|