| @@ -402,11 +402,13 @@ static std::vector<float> InitClusters(float *data, size_t elem_count, size_t k) | |||||
| std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam) { | std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam) { | ||||
| std::vector<float> clusters = InitClusters(data, elem_count, k); | std::vector<float> clusters = InitClusters(data, elem_count, k); | ||||
| std::vector<int8_t> clusters_index{}; | std::vector<int8_t> clusters_index{}; | ||||
| double error{0}; | |||||
| if (clusters.size() < k) { | if (clusters.size() < k) { | ||||
| MS_LOG(WARNING) << "K is less than the size of data so KMeans function is not executed."; | MS_LOG(WARNING) << "K is less than the size of data so KMeans function is not executed."; | ||||
| return clusters_index; | return clusters_index; | ||||
| } | } | ||||
| for (size_t epoch = 0; epoch < epochs; epoch++) { | for (size_t epoch = 0; epoch < epochs; epoch++) { | ||||
| double error_cur{0}; | |||||
| clusters_index.clear(); | clusters_index.clear(); | ||||
| std::vector<std::vector<float>> clusters_data(clusters.size()); | std::vector<std::vector<float>> clusters_data(clusters.size()); | ||||
| for (size_t i = 0; i < elem_count; i++) { | for (size_t i = 0; i < elem_count; i++) { | ||||
| @@ -426,6 +428,15 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc | |||||
| clusters[j] = std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size(); | clusters[j] = std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size(); | ||||
| } | } | ||||
| } | } | ||||
| // compare error | |||||
| for (size_t j = 0; j < elem_count; j++) { | |||||
| error_cur += pow(data[j] - clusters[clusters_index[j]], 2); | |||||
| } | |||||
| error_cur = pow(error_cur / elem_count, 0.5); | |||||
| if (std::abs((error_cur - error) / error_cur) < 1e-6) { | |||||
| break; | |||||
| } | |||||
| error = error_cur; | |||||
| } | } | ||||
| // update data | // update data | ||||
| quantParam->clusters = clusters; | quantParam->clusters = clusters; | ||||