| @@ -142,6 +142,7 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||||
| if (!mStrategy->CanMulOpQuantized(node)) { | if (!mStrategy->CanMulOpQuantized(node)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto already_quant = false; | |||||
| ParamValueLitePtr param_value = nullptr; | ParamValueLitePtr param_value = nullptr; | ||||
| ParameterPtr param_node = nullptr; | ParameterPtr param_node = nullptr; | ||||
| for (size_t i = 1; i < node->size(); i++) { | for (size_t i = 1; i < node->size(); i++) { | ||||
| @@ -151,8 +152,16 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||||
| if ((param_node != nullptr) && param_node->has_default()) { | if ((param_node != nullptr) && param_node->has_default()) { | ||||
| param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | ||||
| if ((param_value == nullptr) || (param_value->tensor_size() == 0) || | if ((param_value == nullptr) || (param_value->tensor_size() == 0) || | ||||
| (param_value->tensor_addr() == nullptr) || | |||||
| (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) { | |||||
| (param_value->tensor_addr() == nullptr)) { | |||||
| param_value = nullptr; | |||||
| continue; | |||||
| } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 || | |||||
| param_value->tensor_type() == mindspore::kNumberTypeInt16) { | |||||
| MS_LOG(INFO) << "the node: " << node->fullname_with_scope() << " input_i: " << i << "has been " | |||||
| << " quantized"; | |||||
| already_quant = true; | |||||
| break; | |||||
| } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { | |||||
| param_value = nullptr; | param_value = nullptr; | ||||
| continue; | continue; | ||||
| } else { | } else { | ||||
| @@ -161,6 +170,11 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (already_quant) { | |||||
| continue; | |||||
| } | |||||
| if (param_value == nullptr) { | if (param_value == nullptr) { | ||||
| MS_LOG(ERROR) << "No valid input param node !"; | MS_LOG(ERROR) << "No valid input param node !"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||