Merge pull request !6487 from wangchangkai/mastertags/v1.0.0
| @@ -80,9 +80,8 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } else if (config->quantType == schema::QuantType_WeightQuant) { | } else if (config->quantType == schema::QuantType_WeightQuant) { | ||||
| auto bitNum = static_cast<size_t>(std::stoull(config->bitNum)); | |||||
| if (bitNum != quant::UINT8_QUANTIZATION) { | |||||
| MS_LOG(ERROR) << "Current Only Support 8 bit weight quant"; | |||||
| if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { | |||||
| MS_LOG(ERROR) << "weight quant input param error"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -124,7 +124,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||||
| auto dims = weight->tensor_shape(); | auto dims = weight->tensor_shape(); | ||||
| if (per_channel) { | if (per_channel) { | ||||
| if (dims.size() != 4) { | if (dims.size() != 4) { | ||||
| MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer."; | |||||
| MS_LOG(ERROR) << "weight dims size: " << dims.size() << " switch to per-layer quant mode."; | |||||
| per_channel = false; | per_channel = false; | ||||
| } else { | } else { | ||||
| uint32_t channels = dims[0]; | uint32_t channels = dims[0]; | ||||
| @@ -27,6 +27,33 @@ using std::vector; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace quant { | namespace quant { | ||||
| bool WeightQuantizer::IsPosNum(const std::string &str) { | |||||
| for (size_t i = 0; i < str.size(); i++) { | |||||
| if (str.at(i) < '0' || str.at(i) > '9') { | |||||
| return false; | |||||
| } | |||||
| if (str.at(i) == '0' && i == 0 && str.size() != 1) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { | |||||
| MS_ASSERT(config != nullptr); | |||||
| if (!WeightQuantizer::IsPosNum(config->convWeightQuantChannelThreshold)) { | |||||
| MS_LOG(ERROR) << "convWeightQuantChannelThreshold must be valid pos num."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!WeightQuantizer::IsPosNum(config->quantSize)) { | |||||
| MS_LOG(ERROR) << "quantSize must be valid pos num."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!WeightQuantizer::IsPosNum(config->bitNum) || config->bitNum != "8") { | |||||
| MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 bit weight quant."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, | WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, | ||||
| const std::string &convWeightChannelThreshold, const std::string &bitNum) | const std::string &convWeightChannelThreshold, const std::string &bitNum) | ||||
| : Quantizer(graph) { | : Quantizer(graph) { | ||||
| @@ -41,6 +41,8 @@ class WeightQuantizer : public Quantizer { | |||||
| STATUS DoQuantize(FuncGraphPtr funcGraph) override; | STATUS DoQuantize(FuncGraphPtr funcGraph) override; | ||||
| STATUS DoConvQuantize(const std::list<CNodePtr> &nodes); | STATUS DoConvQuantize(const std::list<CNodePtr> &nodes); | ||||
| STATUS DoMulQuantize(const std::list<CNodePtr> &nodes); | STATUS DoMulQuantize(const std::list<CNodePtr> &nodes); | ||||
| static STATUS WeightQuantInputCheck(const converter::Flags *config); | |||||
| static bool IsPosNum(const std::string &str); | |||||
| int quant_max{INT8_MAX}; | int quant_max{INT8_MAX}; | ||||
| int quant_min{INT8_MIN}; | int quant_min{INT8_MIN}; | ||||
| private: | private: | ||||