|
|
|
@@ -204,11 +204,9 @@ __global__ void TopKBlock(int outer_size, int inner_size, const T *input, T *out |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T, typename S> |
|
|
|
void FastTopK(const int outer_size, const int inner_size, const T *input, const S *k, T *output, S *output_index, |
|
|
|
void FastTopK(const int outer_size, const int inner_size, const T *input, S k_cut, T *output, S *output_index, |
|
|
|
const T init_K, cudaStream_t stream) { |
|
|
|
int block_num_limit = outer_size < 128 ? outer_size : 128; |
|
|
|
S k_cut = 0; |
|
|
|
cudaMemcpy(&k_cut, k, sizeof(S), cudaMemcpyDeviceToHost); |
|
|
|
if (k_cut > inner_size) k_cut = inner_size; |
|
|
|
|
|
|
|
if (k_cut <= 32) { |
|
|
|
@@ -223,5 +221,5 @@ void FastTopK(const int outer_size, const int inner_size, const T *input, const |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template void FastTopK(const int outer_size, const int inner_size, const float *input, const int *k, float *output, |
|
|
|
template void FastTopK(const int outer_size, const int inner_size, const float *input, int k_cut, float *output, |
|
|
|
int *output_index, const float init_K, cudaStream_t stream); |