GitOrigin-RevId: 6aa35928c8
tags/v1.2.0
| @@ -22,20 +22,25 @@ template <typename ctype> | |||
| void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda, | |||
| const ctype* data, ctype* values, | |||
| int* indices, void* workspace) { | |||
| auto stream = concrete_handle(handle())->stream(); | |||
| auto _handle = concrete_handle(handle()); | |||
| auto stream = _handle->stream(); | |||
| size_t grid_dim_y_limit = _handle->device_prop().maxGridSize[1]; | |||
| switch (param().mode) { | |||
| case Param::Mode::KTH_ONLY: | |||
| cuda_check(topk::find_kth_radix<ctype>(data, values, workspace, m, | |||
| n, lda, k, stream)); | |||
| n, lda, k, grid_dim_y_limit, | |||
| stream)); | |||
| return; | |||
| case Param::Mode::VALUE_IDX_NOSORT: { | |||
| WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}}; | |||
| auto thresh = static_cast<ctype*>(wk_bundle.get(0)); | |||
| auto real_wk = wk_bundle.get(1); | |||
| cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n, | |||
| lda, k, stream)); | |||
| lda, k, grid_dim_y_limit, | |||
| stream)); | |||
| cuda_check(topk::topk_select<ctype>(data, thresh, values, indices, | |||
| real_wk, m, n, lda, k, stream)); | |||
| real_wk, m, n, lda, k, | |||
| grid_dim_y_limit, stream)); | |||
| return; | |||
| } | |||
| case Param::Mode::VALUE_IDX_SORTED: { | |||
| @@ -48,10 +53,11 @@ void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda, | |||
| auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2)); | |||
| auto real_wk = wk_bundle.get(3); | |||
| cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n, | |||
| lda, k, stream)); | |||
| lda, k, grid_dim_y_limit, | |||
| stream)); | |||
| cuda_check(topk::topk_select<ctype>(data, thresh, nosort_values, | |||
| nosort_idx, real_wk, m, n, lda, | |||
| k, stream)); | |||
| k, grid_dim_y_limit, stream)); | |||
| argsort::forward(nosort_values, values, indices, real_wk, m, | |||
| std::abs(k), k > 0, stream, nosort_idx); | |||
| return; | |||
| @@ -89,9 +95,11 @@ size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data, | |||
| MEGDNN_MARK_USED_VAR(indices); | |||
| size_t m = data[0], n = data[1]; | |||
| size_t kabs = std::abs(k); | |||
| size_t grid_dim_y_limit = | |||
| concrete_handle(handle())->device_prop().maxGridSize[1]; | |||
| megdnn_assert(std::max(m, n) <= | |||
| static_cast<size_t>(std::numeric_limits<int>::max())); | |||
| size_t kth = topk::find_kth_radix_workspace(m, n), | |||
| size_t kth = topk::find_kth_radix_workspace(m, n, grid_dim_y_limit), | |||
| sel = topk::topk_select_workspace(m, n); | |||
| auto ctsize = data.dtype.size(); | |||
| switch (param().mode) { | |||
| @@ -468,17 +468,9 @@ static size_t get_scan_workspace(uint32_t size) { | |||
| } // namespace select | |||
| } // namespace cuda_topk_impl | |||
| uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length) { | |||
| uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length, | |||
| uint32_t grid_dim_y_limit) { | |||
| using namespace cuda_topk_impl::kth; | |||
| int device_id; | |||
| if (cudaGetDevice(&device_id) != cudaSuccess) { | |||
| megdnn_trap(); | |||
| } | |||
| cudaDeviceProp prop; | |||
| if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { | |||
| megdnn_trap(); | |||
| } | |||
| uint32_t grid_dim_y_limit = prop.maxGridSize[1]; | |||
| uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch; | |||
| return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) * | |||
| sizeof(uint32_t); | |||
| @@ -488,6 +480,7 @@ template <typename ctype> | |||
| cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, | |||
| void* workspace, uint32_t batch, | |||
| uint32_t length, int32_t lda, int32_t k, | |||
| uint32_t grid_dim_y_limit, | |||
| cudaStream_t stream) { | |||
| using namespace cuda_topk_impl::kth; | |||
| if (!k) { | |||
| @@ -502,16 +495,6 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, | |||
| megdnn_trap(); | |||
| } | |||
| int device_id; | |||
| if (cudaGetDevice(&device_id) != cudaSuccess) { | |||
| megdnn_trap(); | |||
| } | |||
| cudaDeviceProp prop; | |||
| if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { | |||
| megdnn_trap(); | |||
| } | |||
| uint32_t grid_dim_y_limit = prop.maxGridSize[1]; | |||
| uint32_t batch_idx = 0; | |||
| uint32_t grid_dim_x = get_grid_dim_x(length); | |||
| uint32_t grid_dim_y = 1; | |||
| @@ -567,20 +550,11 @@ template <typename ctype> | |||
| cudaError_t topk::topk_select(const ctype* input, const ctype* thresh, | |||
| ctype* output_value, int32_t* output_idx, | |||
| void* workspace, uint32_t batch, uint32_t length, | |||
| int32_t lda, int32_t k, cudaStream_t stream) { | |||
| int32_t lda, int32_t k, | |||
| uint32_t batch_upper_limit, cudaStream_t stream) { | |||
| using namespace cuda_topk_impl; | |||
| using namespace cuda_topk_impl::select; | |||
| int device_id; | |||
| if (cudaGetDevice(&device_id) != cudaSuccess) { | |||
| megdnn_trap(); | |||
| } | |||
| cudaDeviceProp prop; | |||
| if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { | |||
| megdnn_trap(); | |||
| } | |||
| uint32_t batch_upper_limit = prop.maxGridSize[1]; | |||
| uint32_t length_split = DIVUP(length, REDUCE_SIZE); | |||
| void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t, | |||
| @@ -688,10 +662,10 @@ namespace topk { | |||
| #define INST(t) \ | |||
| template cudaError_t find_kth_radix<t>(const t*, t*, void*, uint32_t, \ | |||
| uint32_t, int32_t, int32_t, \ | |||
| cudaStream_t); \ | |||
| uint32_t, cudaStream_t); \ | |||
| template cudaError_t topk_select<t>(const t*, const t*, t*, int32_t*, \ | |||
| void*, uint32_t, uint32_t, int32_t, \ | |||
| int32_t, cudaStream_t) | |||
| int32_t, uint32_t, cudaStream_t) | |||
| INST(float); | |||
| INST(int32_t); | |||
| #undef INST | |||
| @@ -76,10 +76,12 @@ struct RadixConverter<int32_t> { | |||
| template <typename ctype> | |||
| cudaError_t find_kth_radix(const ctype* input, ctype* output, void* workspace, | |||
| uint32_t batch, uint32_t length, int32_t lda, | |||
| int32_t k, cudaStream_t stream); | |||
| int32_t k, uint32_t grid_dim_y_limit, | |||
| cudaStream_t stream); | |||
| //! get workspace in bytes | |||
| uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length); | |||
| uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length, | |||
| uint32_t grid_dim_y_limit); | |||
| /*! | |||
| * \brief select values from rows of input that compare to thresh as specified | |||
| @@ -90,7 +92,8 @@ template <typename ctype> | |||
| cudaError_t topk_select(const ctype* input, const ctype* thresh, | |||
| ctype* output_value, int32_t* output_idx, | |||
| void* workspace, uint32_t batch, uint32_t length, | |||
| int32_t lda, int32_t k, cudaStream_t stream); | |||
| int32_t lda, int32_t k, uint32_t batch_upper_limit, | |||
| cudaStream_t stream); | |||
| uint32_t topk_select_workspace(uint32_t batch, uint32_t length); | |||