|
|
|
@@ -123,10 +123,15 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti |
|
|
|
int quant_max, int quant_min, size_t bitNum, bool per_channel) { |
|
|
|
auto dims = weight->tensor_shape(); |
|
|
|
if (per_channel) { |
|
|
|
if (dims.size() != 4) { |
|
|
|
if (dims.size() != 4 && dims.size() != 2) { |
|
|
|
MS_LOG(INFO) << "weight dims size: " << dims.size() << " switch to per-layer quant mode."; |
|
|
|
per_channel = false; |
|
|
|
} else { |
|
|
|
auto op_type = (schema::PrimitiveType)primitive_c->Type(); |
|
|
|
if (dims.size() == 2 && op_type != schema::PrimitiveType_FullConnection) { |
|
|
|
MS_LOG(INFO) << "weight dims size is 2 but op_type is not FullConnection, switch to per-layer quant mode."; |
|
|
|
per_channel = false; |
|
|
|
} |
|
|
|
uint32_t channels = dims[0]; |
|
|
|
if (channels == 0) { |
|
|
|
MS_LOG(ERROR) << "channels is 0"; |
|
|
|
|