diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 24993e95c5..4e393f9a06 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -402,11 +402,13 @@ static std::vector InitClusters(float *data, size_t elem_count, size_t k) std::vector KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam) { std::vector clusters = InitClusters(data, elem_count, k); std::vector clusters_index{}; + double error{0}; if (clusters.size() < k) { MS_LOG(WARNING) << "K is less than the size of data so KMeans function is not executed."; return clusters_index; } for (size_t epoch = 0; epoch < epochs; epoch++) { + double error_cur{0}; clusters_index.clear(); std::vector> clusters_data(clusters.size()); for (size_t i = 0; i < elem_count; i++) { @@ -426,6 +428,15 @@ std::vector KMeans(float *data, size_t elem_count, size_t k, size_t epoc clusters[j] = std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size(); } } + // compare error + for (size_t j = 0; j < elem_count; j++) { + error_cur += pow(data[j] - clusters[clusters_index[j]], 2); + } + error_cur = pow(error_cur / elem_count, 0.5); + if (std::abs((error_cur - error) / error_cur) < 1e-6) { + break; + } + error = error_cur; } // update data quantParam->clusters = clusters;