|
|
|
@@ -50,11 +50,11 @@ void TopKCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const st |
|
|
|
std::vector<size_t> idx(inner_size_);
|
|
|
|
auto base_input = i * inner_size_;
|
|
|
|
std::iota(idx.begin(), idx.end(), base_input);
|
|
|
|
std::sort(idx.begin(), idx.end(),
|
|
|
|
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
|
|
|
|
std::stable_sort(idx.begin(), idx.end(),
|
|
|
|
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
|
|
|
|
auto base_output = i * k_num;
|
|
|
|
if (!sorted_) {
|
|
|
|
std::sort(idx.begin(), idx.begin() + k_num);
|
|
|
|
std::stable_sort(idx.begin(), idx.begin() + k_num);
|
|
|
|
}
|
|
|
|
for (int j = 0; j < k_num; ++j) {
|
|
|
|
indices[base_output + j] = idx[j] - base_input;
|
|
|
|
|