diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 15e97b7f91..927789cac9 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -123,10 +123,15 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti int quant_max, int quant_min, size_t bitNum, bool per_channel) { auto dims = weight->tensor_shape(); if (per_channel) { - if (dims.size() != 4) { + if (dims.size() != 4 && dims.size() != 2) { MS_LOG(INFO) << "weight dims size: " << dims.size() << " switch to per-layer quant mode."; per_channel = false; } else { + auto op_type = (schema::PrimitiveType)primitive_c->Type(); + if (dims.size() == 2 && op_type != schema::PrimitiveType_FullConnection) { + MS_LOG(INFO) << "weight dims size is 2 but op_type is not FullConnection, switch to per-layer quant mode."; + per_channel = false; + } uint32_t channels = dims[0]; if (channels == 0) { MS_LOG(ERROR) << "channels is 0";