|
- /**
- * \file dnn/src/rocm/topk/opr_impl.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
- #include "./opr_impl.h"
- #include "./topk_radix.h.hip"
- #include "src/common/utils.h"
- #include "src/rocm/argsort/argsort.h.hip"
- #include "src/rocm/utils.h"
-
- using namespace megdnn;
- using namespace rocm;
-
- /*!
-  * \brief Run the top-k kernels for one element type according to param().mode.
-  *
-  * All work is issued on the stream of this operator's handle. The arguments
-  * k, m, n and lda are forwarded unchanged to the topk:: kernels.
-  *
-  *  - KTH_ONLY: topk::find_kth_radix writes its result straight into \p values;
-  *    \p indices is not touched.
-  *  - VALUE_IDX_NOSORT: find_kth_radix fills a threshold buffer of m elements
-  *    carved from \p workspace, then topk::topk_select writes \p values and
-  *    \p indices.
-  *  - VALUE_IDX_SORTED: same two steps, but the selection lands in temporary
-  *    buffers of m * |k| elements (also carved from \p workspace) which
-  *    argsort::forward then writes into \p values and \p indices.
-  *
-  * \param workspace scratch memory; the chunk layout used here has to stay in
-  *        sync with get_workspace_in_bytes().
-  *
-  * Ends in megdnn_throw("bad topk mode") if the mode is none of the above.
-  */
- template <typename ctype>
- void TopKImpl::dispatch_with_ctype(
-         int k, size_t m, size_t n, ptrdiff_t lda, const ctype* data,
-         ctype* values, int* indices, void* workspace) {
-     auto rocm_handle = concrete_handle(handle());
-     auto stream = rocm_handle->stream();
-     size_t grid_dim_y_limit = rocm_handle->device_prop().maxGridSize[1];
-     const auto mode = param().mode;
-     // NOTE(review): identifiers used inside hip_check(...) are deliberately
-     // left as they were; the macro is defined elsewhere and presumably
-     // embeds the expression text in its error message -- confirm before
-     // renaming any of them.
-     if (mode == Param::Mode::KTH_ONLY) {
-         hip_check(topk::find_kth_radix<ctype>(
-                 data, values, workspace, m, n, lda, k, grid_dim_y_limit,
-                 stream));
-         return;
-     }
-     if (mode == Param::Mode::VALUE_IDX_NOSORT) {
-         // chunk 0: m threshold values; chunk 1: kernel scratch. The trailing
-         // chunk is declared with size 1 because only its start address is
-         // taken here.
-         WorkspaceBundle chunks{workspace, {m * sizeof(ctype), 1}};
-         auto thresh = static_cast<ctype*>(chunks.get(0));
-         auto real_wk = chunks.get(1);
-         hip_check(topk::find_kth_radix<ctype>(
-                 data, thresh, real_wk, m, n, lda, k, grid_dim_y_limit,
-                 stream));
-         hip_check(topk::topk_select<ctype>(
-                 data, thresh, values, indices, real_wk, m, n, lda, k,
-                 grid_dim_y_limit, stream));
-         return;
-     }
-     if (mode == Param::Mode::VALUE_IDX_SORTED) {
-         const size_t kabs = std::abs(k);
-         // chunk 0: m threshold values; chunk 1: unsorted values;
-         // chunk 2: unsorted indices; chunk 3: kernel / argsort scratch.
-         WorkspaceBundle chunks{
-                 workspace,
-                 {m * sizeof(ctype), m * kabs * sizeof(ctype),
-                  m * kabs * sizeof(int32_t), 1}};
-         auto thresh = static_cast<ctype*>(chunks.get(0));
-         auto nosort_values = static_cast<ctype*>(chunks.get(1));
-         auto nosort_idx = static_cast<int32_t*>(chunks.get(2));
-         auto real_wk = chunks.get(3);
-         hip_check(topk::find_kth_radix<ctype>(
-                 data, thresh, real_wk, m, n, lda, k, grid_dim_y_limit,
-                 stream));
-         hip_check(topk::topk_select<ctype>(
-                 data, thresh, nosort_values, nosort_idx, real_wk, m, n, lda,
-                 k, grid_dim_y_limit, stream));
-         // The `k > 0` flag presumably selects ascending order -- confirm
-         // against the argsort::forward declaration.
-         argsort::forward(
-                 nosort_values, values, indices, real_wk, m, std::abs(k),
-                 k > 0, stream, nosort_idx);
-         return;
-     }
-     megdnn_throw("bad topk mode");
- }
-
- /*!
-  * \brief Forward an exec request to dispatch_with_ctype for the dtype of
-  *        \p data.
-  *
-  * data.layout[0], data.layout[1] and data.layout.stride[0] are passed on as
-  * m, n and lda respectively. Only Float32 and Int32 are handled; the Float16
-  * branch is commented out below, so every other dtype ends in megdnn_throw.
-  *
-  * \param k         forwarded unchanged to dispatch_with_ctype
-  * \param data      input tensor; its dtype selects the instantiation
-  * \param values    output tensor, accessed with the same element type as
-  *                  \p data
-  * \param indices   output index buffer, forwarded unchanged
-  * \param workspace scratch memory; only its raw pointer is forwarded
-  */
- void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
-                        int32_t* indices, _megdnn_workspace workspace) {
-     switch (data.layout.dtype.enumv()) {
-         case DTypeEnum::Float32:
-             dispatch_with_ctype<float>(k, data.layout[0], data.layout[1],
-                                        data.layout.stride[0], data.ptr<float>(),
-                                        values.ptr<float>(), indices,
-                                        workspace.raw_ptr);
-             return;
-         case DTypeEnum::Int32:
-             dispatch_with_ctype<int32_t>(k, data.layout[0], data.layout[1],
-                                          data.layout.stride[0], data.ptr<int32_t>(),
-                                          values.ptr<int32_t>(), indices,
-                                          workspace.raw_ptr);
-             return;
-         // Float16 is currently disabled for this backend. If the branch below
-         // is re-enabled, extend the error message in the default case as well.
-         // #if !MEGDNN_DISABLE_FLOAT16
-         // case DTypeEnum::Float16:
-         // dispatch_with_ctype<dt_float16>(k, data.layout[0], data.layout[1],
-         // data.layout.stride[0], data.ptr<dt_float16>(),
-         // values.ptr<dt_float16>(), indices,
-         // workspace.raw_ptr);
-         // return;
-         // #endif
-         default:
-             // The message names exactly the dtypes and backend handled above.
-             megdnn_throw(
-                     ssprintf("only float32 and int32 supported for "
-                              "rocm topk, got: %s",
-                              data.layout.dtype.name()));
-     }
- }
-
- /*!
-  * \brief Compute the workspace size in bytes for the current param().mode.
-  *
-  * m and n are read from data[0] and data[1]; both must fit in an int. The
-  * chunk sizes mirror the WorkspaceBundle layouts built in
-  * dispatch_with_ctype, with the trailing scratch chunk sized to the largest
-  * requirement of the steps that share it.
-  *
-  * \param k       only |k| and the sign of k are used
-  * \param data    input layout; supplies m, n and the element size
-  * \param values  not inspected
-  * \param indices not inspected
-  * \return total workspace size in bytes
-  *
-  * Ends in megdnn_throw("bad topk mode") for an unknown mode.
-  */
- size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data,
-                                         const TensorLayout& values,
-                                         const TensorLayout& indices) {
-     MEGDNN_MARK_USED_VAR(values);
-     MEGDNN_MARK_USED_VAR(indices);
-     size_t m = data[0];
-     size_t n = data[1];
-     const size_t kabs = std::abs(k);
-     const size_t grid_dim_y_limit =
-             concrete_handle(handle())->device_prop().maxGridSize[1];
-     // Checked before the kernel workspace queries below, as in the original
-     // ordering.
-     megdnn_assert(std::max(m, n) <=
-                   static_cast<size_t>(std::numeric_limits<int>::max()));
-     const size_t kth_bytes =
-             topk::find_kth_radix_workspace(m, n, grid_dim_y_limit);
-     const size_t select_bytes = topk::topk_select_workspace(m, n);
-     const auto elem_bytes = data.dtype.size();
-     // find_kth_radix and topk_select run one after another on the same
-     // scratch chunk, so the larger of the two requirements is enough.
-     const size_t kernel_scratch = std::max(kth_bytes, select_bytes);
-     switch (param().mode) {
-         case Param::Mode::KTH_ONLY: {
-             return kth_bytes;
-         }
-         case Param::Mode::VALUE_IDX_NOSORT: {
-             WorkspaceBundle layout{nullptr, {m * elem_bytes, kernel_scratch}};
-             return layout.total_size_in_bytes();
-         }
-         case Param::Mode::VALUE_IDX_SORTED: {
-             // Only queried in this mode; argsort shares the scratch chunk too.
-             const size_t sort_bytes = argsort::get_fwd_workspace_in_bytes(
-                     m, kabs, data.dtype, k > 0, true);
-             WorkspaceBundle layout{
-                     nullptr,
-                     {m * elem_bytes, m * kabs * elem_bytes,
-                      m * kabs * sizeof(int32_t),
-                      std::max(kernel_scratch, sort_bytes)}};
-             return layout.total_size_in_bytes();
-         }
-     }
-     megdnn_throw("bad topk mode");
- }
-
- // vim: syntax=cpp.doxygen
|