Browse Source

add condition stop for kmeans

tags/v1.1.0
guohongzilong 5 years ago
parent
commit
b12c8564b5
1 changed files with 11 additions and 0 deletions
  1. +11
    -0
      mindspore/lite/tools/converter/quantizer/quantize_util.cc

+ 11
- 0
mindspore/lite/tools/converter/quantizer/quantize_util.cc View File

@@ -402,11 +402,13 @@ static std::vector<float> InitClusters(float *data, size_t elem_count, size_t k)
std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam) { std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam) {
std::vector<float> clusters = InitClusters(data, elem_count, k); std::vector<float> clusters = InitClusters(data, elem_count, k);
std::vector<int8_t> clusters_index{}; std::vector<int8_t> clusters_index{};
double error{0};
if (clusters.size() < k) { if (clusters.size() < k) {
MS_LOG(WARNING) << "K is less than the size of data so KMeans function is not executed."; MS_LOG(WARNING) << "K is less than the size of data so KMeans function is not executed.";
return clusters_index; return clusters_index;
} }
for (size_t epoch = 0; epoch < epochs; epoch++) { for (size_t epoch = 0; epoch < epochs; epoch++) {
double error_cur{0};
clusters_index.clear(); clusters_index.clear();
std::vector<std::vector<float>> clusters_data(clusters.size()); std::vector<std::vector<float>> clusters_data(clusters.size());
for (size_t i = 0; i < elem_count; i++) { for (size_t i = 0; i < elem_count; i++) {
@@ -426,6 +428,15 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
clusters[j] = std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size(); clusters[j] = std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size();
} }
} }
// compare error
for (size_t j = 0; j < elem_count; j++) {
error_cur += pow(data[j] - clusters[clusters_index[j]], 2);
}
error_cur = pow(error_cur / elem_count, 0.5);
if (std::abs((error_cur - error) / error_cur) < 1e-6) {
break;
}
error = error_cur;
} }
// update data // update data
quantParam->clusters = clusters; quantParam->clusters = clusters;


Loading…
Cancel
Save