|
|
|
@@ -130,7 +130,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan |
|
|
|
} |
|
|
|
template <typename T> |
|
|
|
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType, |
|
|
|
int quant_max, int quant_min, size_t bitNum, bool per_channel) { |
|
|
|
int quant_max, int quant_min, size_t bitNum, bool per_channel, bool k_means = false) { |
|
|
|
auto dims = weight->tensor_shape(); |
|
|
|
auto op_type = (schema::PrimitiveType)primitive_c->Type(); |
|
|
|
if (per_channel) { |
|
|
|
@@ -208,7 +208,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti |
|
|
|
average_raw += raw_data; |
|
|
|
} |
|
|
|
} |
|
|
|
if (quantType == QuantType_WeightQuant && quant_param.clusters.size() == 0) { |
|
|
|
if (quantType == QuantType_WeightQuant && !k_means) { |
|
|
|
// mean |
|
|
|
average_dequant = average_dequant / one_filter_size; |
|
|
|
average_raw = average_raw / one_filter_size; |
|
|
|
@@ -256,7 +256,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti |
|
|
|
} |
|
|
|
|
|
|
|
schema::QuantParamT quant_param; |
|
|
|
if (quant_param.clusters.size() == 0) { |
|
|
|
if (!k_means) { |
|
|
|
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "CalQuantizationParams failed" << status; |
|
|
|
@@ -267,7 +267,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti |
|
|
|
// update data and datatype |
|
|
|
for (uint32_t i = 0; i < elem_count; i++) { |
|
|
|
float raw_data = raw_datas[i]; |
|
|
|
if (quant_param.clusters.size() == 0) { |
|
|
|
if (!k_means) { |
|
|
|
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min); |
|
|
|
quant_datas[i] = quant_data; |
|
|
|
} |
|
|
|
|