| @@ -227,10 +227,10 @@ struct DivergInfo { | |||||
| int zero_point = 0; | int zero_point = 0; | ||||
| if (quant_min == 0 && quant_max == 255) { | if (quant_min == 0 && quant_max == 255) { | ||||
| zero_point = 128; | zero_point = 128; | ||||
| } else if (quant_min == -128 && quant_max == 127) { | |||||
| } else if (quant_min == -127 && quant_max == 127) { | |||||
| zero_point = 0; | zero_point = 0; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max; | |||||
| MS_LOG(WARNING) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max; | |||||
| } | } | ||||
| return std::make_pair(this->cnode, zero_point); | return std::make_pair(this->cnode, zero_point); | ||||
| } | } | ||||
| @@ -486,7 +486,7 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in | |||||
| this->target_type_ = target_type; | this->target_type_ = target_type; | ||||
| if (target_type == kNumberTypeInt8) { | if (target_type == kNumberTypeInt8) { | ||||
| quant_max = (1 << (this->bit_num - 1)) - 1; // 127 | quant_max = (1 << (this->bit_num - 1)) - 1; // 127 | ||||
| quant_min = -(1 << (this->bit_num - 1)); // -128 | |||||
| quant_min = -quant_max; // -127 | |||||
| } else if (target_type == kNumberTypeUInt8) { | } else if (target_type == kNumberTypeUInt8) { | ||||
| quant_max = (1 << this->bit_num) - 1; // 255 | quant_max = (1 << this->bit_num) - 1; // 255 | ||||
| quant_min = 0; | quant_min = 0; | ||||
| @@ -100,7 +100,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||||
| } | } | ||||
| std::vector<AnfNodePtr> op_inputs = {value_node, input_cnode}; | std::vector<AnfNodePtr> op_inputs = {value_node, input_cnode}; | ||||
| auto quant_cast_cnode = graph->NewCNode(op_inputs); | auto quant_cast_cnode = graph->NewCNode(op_inputs); | ||||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); | |||||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" + std::to_string(i)); | |||||
| cnode->set_input(i, quant_cast_cnode); | cnode->set_input(i, quant_cast_cnode); | ||||
| MS_LOG(DEBUG) << "Add quant cast. " | MS_LOG(DEBUG) << "Add quant cast. " | ||||
| << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type | << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type | ||||
| @@ -220,11 +220,11 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl | |||||
| bool narrowRange, int numBits) { | bool narrowRange, int numBits) { | ||||
| MS_ASSERT(quantParam != nullptr); | MS_ASSERT(quantParam != nullptr); | ||||
| if (mMin > 0.0f) { | if (mMin > 0.0f) { | ||||
| MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision"; | |||||
| MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision"; | |||||
| mMin = 0.0f; | mMin = 0.0f; | ||||
| } | } | ||||
| if (mMax < 0.0f) { | if (mMax < 0.0f) { | ||||
| MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision"; | |||||
| MS_LOG(DEBUG) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision"; | |||||
| mMax = 0.0f; | mMax = 0.0f; | ||||
| } | } | ||||
| if (mMin > mMax) { | if (mMin > mMax) { | ||||