|
- /**
- * \file dnn/src/rocm/argsort/argsort.cpp.hip
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
- #include "hcc_detail/hcc_defs_prologue.h"
-
- #include "src/rocm/utils.h.hip"
- #include "./argsort.h.hip"
- #include "./bitonic_sort.h.hip"
- #include "megdnn/basic_types.h"
-
- #include "hipcub/device/device_radix_sort.hpp"
- #include "hipcub/device/device_segmented_radix_sort.hpp"
-
- using namespace megdnn;
- using namespace rocm;
-
- namespace {
- struct StridedOffsetIterator {
- int bias, stride;
-
- StridedOffsetIterator(int bias_, int stride_)
- : bias(bias_), stride(stride_) {}
-
- __device__ __forceinline__ int operator[](int i) const {
- return stride * i + bias;
- }
- };
-
- bool use_bitonic(uint32_t /*M*/, uint32_t N) {
- // bitonic sort is preferred when N is small (alwyas faster than radix sort)
- return N <= BITONIC_SORT_MAX_LENGTH;
- }
-
- bool use_segmented(uint32_t M, uint32_t /*N*/) {
- // an empirical value:
- // sort(1, 1e6): 0.574ms
- // segsort({1,2,8,16}, 1e6): 7-8ms
- // sort(1, 1e7): 3.425ms
- // segsort({1,2,8,16}, 1e7): 71-84ms
- //
- // segsort is about 7x-10x slower than sort on small batches, so we can
- // expect it to be faster than sort when batch is large enough.
- return M >= 8;
- }
-
- __global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) {
- uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
- if (i < n) {
- dst[i] = i % mod;
- }
- }
-
- template <typename ctype>
- size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) {
- if (use_bitonic(M, N)) {
- return 0;
- }
- return argsort::cub_sort_pairs<ctype, int>(is_ascending, NULL, 0, NULL, NULL, NULL, NULL,
- M, N, 0, sizeof(float)*8, NULL);
- }
- } // anonymous namespace
-
- template <typename KeyType, typename ValueType>
- MEGDNN_NOINLINE size_t argsort::cub_sort_pairs(
- bool is_ascending, void* workspace, size_t workspace_size,
- const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in,
- ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,hipStream_t stream){
- hipError_t err;
- if (use_segmented(M, N)) {
- if (is_ascending) {
- err = hipcub::DeviceSegmentedRadixSort::SortPairs(
- workspace, workspace_size, keys_in, keys_out, values_in,
- values_out, N * M, M, StridedOffsetIterator(0, N),
- StridedOffsetIterator(N, N), begin_bit, end_bit, stream);
- hip_check(err);
- } else {
- err = hipcub::DeviceSegmentedRadixSort::SortPairsDescending(
- workspace, workspace_size, keys_in, keys_out, values_in,
- values_out, N * M, M, StridedOffsetIterator(0, N),
- StridedOffsetIterator(N, N), begin_bit, end_bit, stream);
- hip_check(err);
- }
- } else {
- if (is_ascending) {
- for (size_t i = 0; i < M; ++i) {
- err = hipcub::DeviceRadixSort::SortPairs(
- workspace, workspace_size, keys_in + N * i,
- keys_out + N * i, values_in + N * i, values_out + N * i,
- N, begin_bit, end_bit, stream);
- hip_check(err);
- if (!keys_in) {
- return workspace_size;
- }
- }
- } else {
- for (size_t i = 0; i < M; ++i) {
- err = hipcub::DeviceRadixSort::SortPairsDescending(
- workspace, workspace_size, keys_in + N * i,
- keys_out + N * i, values_in + N * i, values_out + N * i,
- N, begin_bit, end_bit, stream);
- hip_check(err);
- if (!keys_in) {
- return workspace_size;
- }
- }
- }
- }
- return workspace_size;
- }
-
- size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
- bool is_ascending,
- bool iptr_src_given) {
- size_t size = 0;
- switch (dtype.enumv().ev) {
- #define cb(ctype) \
- case DTypeTrait<ctype>::enumv: \
- size = get_sort_workspace<ctype>(M, N, is_ascending); \
- break;
- ARGSORT_FOREACH_CTYPE(cb)
- #undef cb
- default:
- megdnn_throw("argsort only supports float, int32 and float16");
- }
- if (!iptr_src_given) {
- size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int);
- }
- return size;
- }
-
- template <typename dtype>
- void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr,
- void* workspace, uint32_t M, uint32_t N,
- bool is_ascending, hipStream_t stream,
- const int* iptr_src) {
- size_t wk_size = get_sort_workspace<dtype>(M, N, is_ascending);
- if (!iptr_src) {
- int* ptr = reinterpret_cast<int*>(static_cast<uint8_t*>(workspace) +
- DIVUP(wk_size, sizeof(float)) *
- sizeof(float));
- kern_arange<<<DIVUP(N * M, 512), 512, 0, stream>>>(ptr, M * N, N);
- iptr_src = ptr;
- }
-
- if (use_bitonic(M, N)) {
- hip_check(bitonic_sort(M, N, sptr, iptr_src, dptr, iptr, is_ascending,
- stream));
- } else {
- cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src,
- iptr, M, N, 0, sizeof(float)*8, stream);
- }
- }
-
- namespace megdnn {
- namespace rocm {
-
- #define INST_CUB_SORT(dtype) \
- template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs<dtype, dtype>(bool, \
- void*, size_t, const dtype*, dtype*, \
- const dtype*, dtype*, uint32_t, uint32_t,\
- int, int, hipStream_t);
-
- #define INST_FORWARD(dtype) \
- template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \
- uint32_t, uint32_t, bool, hipStream_t, \
- const int*);
-
- ARGSORT_FOREACH_CTYPE(INST_FORWARD)
- INST_CUB_SORT(uint32_t)
- // INST_CUB_SORT(uint64_t)
- #undef INST_CUB_SORT
- #undef INST_FORWARD
- }
- } // namespace megdnn
- // vim: ft=rocm syntax=rocm.doxygen
|