| @@ -0,0 +1,183 @@ | |||
| /** | |||
| * \file dnn/src/rocm/argsort/argsort.cpp.hip | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "hcc_detail/hcc_defs_prologue.h" | |||
| #include "src/rocm/utils.h.hip" | |||
| #include "./argsort.h.hip" | |||
| #include "./bitonic_sort.h.hip" | |||
| #include "megdnn/basic_types.h" | |||
| #include "hipcub/device/device_radix_sort.hpp" | |||
| #include "hipcub/device/device_segmented_radix_sort.hpp" | |||
| using namespace megdnn; | |||
| using namespace rocm; | |||
| namespace { | |||
| struct StridedOffsetIterator { | |||
| int bias, stride; | |||
| StridedOffsetIterator(int bias_, int stride_) | |||
| : bias(bias_), stride(stride_) {} | |||
| __device__ __forceinline__ int operator[](int i) const { | |||
| return stride * i + bias; | |||
| } | |||
| }; | |||
| bool use_bitonic(uint32_t /*M*/, uint32_t N) { | |||
| // bitonic sort is preferred when N is small (alwyas faster than radix sort) | |||
| return N <= BITONIC_SORT_MAX_LENGTH; | |||
| } | |||
//! whether the segmented radix sort should be used for M rows; N is ignored
bool use_segmented(uint32_t M, uint32_t /*N*/) {
    // The threshold is empirical:
    //      sort(1, 1e6):               0.574ms
    //      segsort({1,2,8,16}, 1e6):   7-8ms
    //      sort(1, 1e7):               3.425ms
    //      segsort({1,2,8,16}, 1e7):   71-84ms
    //
    // segsort is about 7x-10x slower than sort on small batches, so it is
    // only expected to win over per-row sorts once the batch is large enough.
    const uint32_t min_segmented_batch = 8;
    return !(M < min_segmented_batch);
}
| __global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) { | |||
| uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; | |||
| if (i < n) { | |||
| dst[i] = i % mod; | |||
| } | |||
| } | |||
| template <typename ctype> | |||
| size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) { | |||
| if (use_bitonic(M, N)) { | |||
| return 0; | |||
| } | |||
| return argsort::cub_sort_pairs<ctype, int>(is_ascending, NULL, 0, NULL, NULL, NULL, NULL, | |||
| M, N, 0, sizeof(float)*8, NULL); | |||
| } | |||
| } // anonymous namespace | |||
| template <typename KeyType, typename ValueType> | |||
| MEGDNN_NOINLINE size_t argsort::cub_sort_pairs( | |||
| bool is_ascending, void* workspace, size_t workspace_size, | |||
| const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, | |||
| ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,hipStream_t stream){ | |||
| hipError_t err; | |||
| if (use_segmented(M, N)) { | |||
| if (is_ascending) { | |||
| err = hipcub::DeviceSegmentedRadixSort::SortPairs( | |||
| workspace, workspace_size, keys_in, keys_out, values_in, | |||
| values_out, N * M, M, StridedOffsetIterator(0, N), | |||
| StridedOffsetIterator(N, N), begin_bit, end_bit, stream); | |||
| hip_check(err); | |||
| } else { | |||
| err = hipcub::DeviceSegmentedRadixSort::SortPairsDescending( | |||
| workspace, workspace_size, keys_in, keys_out, values_in, | |||
| values_out, N * M, M, StridedOffsetIterator(0, N), | |||
| StridedOffsetIterator(N, N), begin_bit, end_bit, stream); | |||
| hip_check(err); | |||
| } | |||
| } else { | |||
| if (is_ascending) { | |||
| for (size_t i = 0; i < M; ++i) { | |||
| err = hipcub::DeviceRadixSort::SortPairs( | |||
| workspace, workspace_size, keys_in + N * i, | |||
| keys_out + N * i, values_in + N * i, values_out + N * i, | |||
| N, begin_bit, end_bit, stream); | |||
| hip_check(err); | |||
| if (!keys_in) { | |||
| return workspace_size; | |||
| } | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < M; ++i) { | |||
| err = hipcub::DeviceRadixSort::SortPairsDescending( | |||
| workspace, workspace_size, keys_in + N * i, | |||
| keys_out + N * i, values_in + N * i, values_out + N * i, | |||
| N, begin_bit, end_bit, stream); | |||
| hip_check(err); | |||
| if (!keys_in) { | |||
| return workspace_size; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return workspace_size; | |||
| } | |||
| size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, | |||
| bool is_ascending, | |||
| bool iptr_src_given) { | |||
| size_t size = 0; | |||
| switch (dtype.enumv().ev) { | |||
| #define cb(ctype) \ | |||
| case DTypeTrait<ctype>::enumv: \ | |||
| size = get_sort_workspace<ctype>(M, N, is_ascending); \ | |||
| break; | |||
| ARGSORT_FOREACH_CTYPE(cb) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("argsort only supports float, int32 and float16"); | |||
| } | |||
| if (!iptr_src_given) { | |||
| size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int); | |||
| } | |||
| return size; | |||
| } | |||
| template <typename dtype> | |||
| void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr, | |||
| void* workspace, uint32_t M, uint32_t N, | |||
| bool is_ascending, hipStream_t stream, | |||
| const int* iptr_src) { | |||
| size_t wk_size = get_sort_workspace<dtype>(M, N, is_ascending); | |||
| if (!iptr_src) { | |||
| int* ptr = reinterpret_cast<int*>(static_cast<uint8_t*>(workspace) + | |||
| DIVUP(wk_size, sizeof(float)) * | |||
| sizeof(float)); | |||
| kern_arange<<<DIVUP(N * M, 512), 512, 0, stream>>>(ptr, M * N, N); | |||
| iptr_src = ptr; | |||
| } | |||
| if (use_bitonic(M, N)) { | |||
| hip_check(bitonic_sort(M, N, sptr, iptr_src, dptr, iptr, is_ascending, | |||
| stream)); | |||
| } else { | |||
| cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src, | |||
| iptr, M, N, 0, sizeof(float)*8, stream); | |||
| } | |||
| } | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| #define INST_CUB_SORT(dtype) \ | |||
| template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs<dtype, dtype>(bool, \ | |||
| void*, size_t, const dtype*, dtype*, \ | |||
| const dtype*, dtype*, uint32_t, uint32_t,\ | |||
| int, int, hipStream_t); | |||
| #define INST_FORWARD(dtype) \ | |||
| template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \ | |||
| uint32_t, uint32_t, bool, hipStream_t, \ | |||
| const int*); | |||
| ARGSORT_FOREACH_CTYPE(INST_FORWARD) | |||
| INST_CUB_SORT(uint32_t) | |||
| // INST_CUB_SORT(uint64_t) | |||
| #undef INST_CUB_SORT | |||
| #undef INST_FORWARD | |||
| } | |||
| } // namespace megdnn | |||
| // vim: ft=rocm syntax=rocm.doxygen | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * \file dnn/src/rocm/argsort/argsort.h.hip | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "hcc_detail/hcc_defs_prologue.h" | |||
| #include "hip_header.h" | |||
| #include <stddef.h> | |||
| #include <stdint.h> | |||
| #include "megdnn/dtype.h" | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| namespace argsort { | |||
| size_t get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, | |||
| bool is_ascending, | |||
| bool iptr_src_given = false); | |||
| template <typename KeyType, typename ValueType> | |||
| size_t cub_sort_pairs( | |||
| bool is_ascending, void* workspace, size_t workspace_size, | |||
| const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, | |||
| ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,hipStream_t stream); | |||
| /*! | |||
| * \param iptr_src pointer to indices; a range would be generated if it is null | |||
| */ | |||
| template <typename dtype> | |||
| void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace, | |||
| uint32_t M, uint32_t N, bool is_ascending, hipStream_t stream, | |||
| const int* iptr_src = NULL); | |||
| //! iterate over all supported data types | |||
| #define ARGSORT_FOREACH_CTYPE(cb) \ | |||
| cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) | |||
| } // namespace argsort | |||
| } // namespace rocm | |||
| } // namespace megdnn | |||
| // vim: ft=cpp syntax=cpp.doxygen | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * \file dnn/src/rocm/argsort/backward.cpp.hip | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "hcc_detail/hcc_defs_prologue.h" | |||
| #include "src/rocm/utils.h.hip" | |||
| #include "./argsort.h.hip" | |||
| #include "./backward.h.hip" | |||
| // #include "src/rocm/utils.h" | |||
| using namespace megdnn; | |||
| using namespace rocm; | |||
| using namespace argsort; | |||
| namespace { | |||
| template <typename T> | |||
| __global__ void backward_kernel(uint32_t dst_w, uint32_t src_w, | |||
| uint32_t src_size, T* dst, const T* src_data, | |||
| const int* src_idx) { | |||
| uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x; | |||
| if (idx < src_size) { | |||
| uint32_t r = idx / src_w; | |||
| dst[r * dst_w + src_idx[idx]] = src_data[idx]; | |||
| } | |||
| } | |||
| } // namespace | |||
| template <typename T> | |||
| void argsort::backward_proxy(uint32_t dst_h, uint32_t dst_w, uint32_t src_w, | |||
| T* dst, const T* src_data, const int* src_idx, | |||
| hipStream_t stream) { | |||
| if (dst_w != src_w) { | |||
| hipMemsetAsync(dst, 0, dst_h * dst_w * sizeof(T), stream); | |||
| } | |||
| uint32_t src_size = dst_h * src_w; | |||
| backward_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>( | |||
| dst_w, src_w, src_size, dst, src_data, src_idx); | |||
| after_kernel_launch(); | |||
| } | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| namespace argsort { | |||
| #define INST(T) \ | |||
| template void backward_proxy(uint32_t dst_h, uint32_t dst_w, \ | |||
| uint32_t src_w, T* dst, const T* src_data, \ | |||
| const int* src_idx, hipStream_t stream); | |||
| ARGSORT_FOREACH_CTYPE(INST) | |||
| #undef INST | |||
| } // namespace argsort | |||
| } // namespace rocm | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * \file dnn/src/rocm/argsort/backward.h.hip | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "hip_header.h" | |||
| #include <stdint.h> | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| namespace argsort { | |||
| template <typename T> | |||
| void backward_proxy(uint32_t dst_h, uint32_t dst_w, uint32_t src_w, T* dst, | |||
| const T* src_data, const int* src_idx, hipStream_t stream); | |||
| } // namespace argsort | |||
| } // namespace rocm | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,320 @@ | |||
| /** | |||
| * \file dnn/src/rocm/argsort/bitonic_sort.cpp.hip | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "hcc_detail/hcc_defs_prologue.h" | |||
| #include "./bitonic_sort.h.hip" | |||
| // #include "src/cuda/query_blocksize.cuh" | |||
| // #include "megdnn/dtype.h" | |||
| // #if __CUDACC_VER_MAJOR__ < 9 | |||
| // #pragma message "warp sync disabled due to insufficient cuda version" | |||
| #define __syncwarp __syncthreads | |||
| // #endif | |||
| #include <algorithm> | |||
| #include <cmath> | |||
| using namespace megdnn; | |||
| using namespace rocm; | |||
| namespace bitonic_sort_impl { | |||
| //! load keys and init idx | |||
| template <class CompareLess, typename T> | |||
| __device__ __forceinline__ void safe_load0(T* dst, uint16_t* idx, const T* src, | |||
| uint32_t id, uint32_t size) { | |||
| dst[id] = id < size ? src[id] : CompareLess::template max<T>(); | |||
| idx[id] = id; | |||
| } | |||
| //! load values | |||
| template <typename T> | |||
| __device__ __forceinline__ void safe_load1(T* dst, const T* src, uint32_t id, | |||
| uint32_t size) { | |||
| // broadcast last value to avoid out-of-bound values (for example, when | |||
| // input contains NaN) | |||
| dst[id] = src[min(id, size - 1)]; | |||
| } | |||
| //! write keys | |||
| template <typename T> | |||
| __device__ __forceinline__ void safe_write0(T* dst, const T* src, uint32_t id, | |||
| uint32_t size) { | |||
| if (id < size) { | |||
| dst[id] = src[id]; | |||
| } | |||
| } | |||
| //! write values | |||
| template <typename T> | |||
| __device__ __forceinline__ void safe_write1(T* dst, const T* src, | |||
| const uint16_t* remap, uint32_t id, | |||
| uint32_t size) { | |||
| if (id < size) { | |||
| dst[id] = src[remap[id]]; | |||
| } | |||
| } | |||
| struct SyncWarp { | |||
| static __device__ __forceinline__ void s() { __syncwarp(); } | |||
| }; | |||
| struct SyncBlock { | |||
| static __device__ __forceinline__ void s() { __syncthreads(); } | |||
| }; | |||
| template <typename T> | |||
| struct NumTrait; | |||
| template <> | |||
| struct NumTrait<float> { | |||
| static __device__ __forceinline__ float max() { return INFINITY; } | |||
| static __device__ __forceinline__ float min() { return -INFINITY; } | |||
| }; | |||
| template <> | |||
| struct NumTrait<int32_t> { | |||
| static __device__ __forceinline__ int32_t max() { return INT_MAX; } | |||
| static __device__ __forceinline__ int32_t min() { return INT_MIN; } | |||
| }; | |||
| // #if !MEGDNN_DISABLE_FLOAT16 | |||
| // template <> | |||
| // struct NumTrait<dt_float16> { | |||
| // static __device__ __forceinline__ dt_float16 max() { | |||
| // return std::numeric_limits<dt_float16>::max(); | |||
| // } | |||
| // static __device__ __forceinline__ dt_float16 min() { | |||
| // return std::numeric_limits<dt_float16>::lowest(); | |||
| // } | |||
| // }; | |||
| // #endif | |||
| struct LessThan { | |||
| template <typename Key, typename Value> | |||
| static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, | |||
| Value v1) { | |||
| return (k0 < k1) | ((k0 == k1) & (v0 < v1)); | |||
| } | |||
| template <typename T> | |||
| static __device__ __forceinline__ T max() { | |||
| return NumTrait<T>::max(); | |||
| } | |||
| }; | |||
| struct GreaterThan { | |||
| template <typename Key, typename Value> | |||
| static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, | |||
| Value v1) { | |||
| return (k0 > k1) | ((k0 == k1) & (v0 < v1)); | |||
| } | |||
| template <typename T> | |||
| static __device__ __forceinline__ T max() { | |||
| return NumTrait<T>::min(); | |||
| } | |||
| }; | |||
//! storage slot large enough to hold either a key or a value
template <typename Key, typename Value>
union KVUnion {
    Key key;
    Value value;
};

//! dynamic shared memory bytes for \p block_size threads: every thread owns
//! four slots, each holding one KVUnion plus one uint16_t index
template <typename Key, typename Value>
static int get_shmem(int block_size, void* = NULL) {
    const size_t bytes_per_slot =
            sizeof(KVUnion<Key, Value>) + sizeof(uint16_t);
    return bytes_per_slot * block_size * 4;
}
| /*! | |||
| * \brief batched bitonic sort (M, N) for small N | |||
| * | |||
| * launch configuration: | |||
| * grid(X) | |||
| * block(N/4, Y) | |||
| * | |||
| * where N / 4 == 1 << nr_th_log2 | |||
| */ | |||
| template <class Sync, typename Key, typename Value, class CompareLess, | |||
| uint32_t nr_th_log2> | |||
| static __global__ void kern(uint32_t batch, uint32_t length, const Key* key_inp, | |||
| const Value* value_inp, Key* key_out, | |||
| Value* value_out) { | |||
| const uint32_t nr_th = 1 << nr_th_log2; | |||
| // 24KiB shared memory for 4-byte keys for 1024 threads | |||
| extern __shared__ uint8_t smem_storage[]; | |||
| uint16_t* idx_storage = reinterpret_cast<uint16_t*>(smem_storage); | |||
| KVUnion<Key, Value>* keys_storage = reinterpret_cast<KVUnion<Key, Value>*>( | |||
| idx_storage + blockDim.y * (nr_th * 4)); | |||
| uint32_t cur_batch = blockIdx.x * blockDim.y + threadIdx.y, | |||
| off = cur_batch * length; | |||
| key_inp += off; | |||
| key_out += off; | |||
| value_inp += off; | |||
| value_out += off; | |||
| uint32_t storage_offset = threadIdx.y * (nr_th * 4); | |||
| uint16_t* values = idx_storage + storage_offset; | |||
| Key* keys = reinterpret_cast<Key*>(keys_storage + storage_offset); | |||
| uint32_t tid0 = threadIdx.x, tid1 = tid0 + nr_th, | |||
| cur_length = cur_batch < batch ? length : 0; | |||
| safe_load0<CompareLess>(keys, values, key_inp, tid0, cur_length); | |||
| safe_load0<CompareLess>(keys, values, key_inp, tid0 + nr_th, cur_length); | |||
| safe_load0<CompareLess>(keys, values, key_inp, tid0 + nr_th * 2, | |||
| cur_length); | |||
| safe_load0<CompareLess>(keys, values, key_inp, tid0 + nr_th * 3, | |||
| cur_length); | |||
| Sync::s(); | |||
| #define WORK(_idx, _asc) \ | |||
| do { \ | |||
| uint32_t _id0 = (_idx), _id1 = _id0 + step; \ | |||
| Key _k0 = keys[_id0], _k1 = keys[_id1]; \ | |||
| uint16_t _v0 = values[_id0], _v1 = values[_id1]; \ | |||
| if (CompareLess::cmp(_k0, _v0, _k1, _v1) != _asc) { \ | |||
| keys[_id0] = _k1; \ | |||
| keys[_id1] = _k0; \ | |||
| values[_id0] = _v1; \ | |||
| values[_id1] = _v0; \ | |||
| } \ | |||
| } while (0) | |||
| #pragma unroll | |||
| for (uint32_t slen_log = 0; slen_log <= (nr_th_log2 + 1); ++slen_log) { | |||
| // log2 of half of current bitonic sequence (i.e. length of its | |||
| // monotonic part) | |||
| uint32_t asc0 = !((tid0 >> slen_log) & 1), | |||
| asc1 = !((tid1 >> slen_log) & 1); | |||
| #pragma unroll | |||
| for (uint32_t j = 0; j <= slen_log; ++j) { | |||
| uint32_t step = 1 << (slen_log - j), xmask = step - 1, | |||
| ymask = ~xmask; | |||
| WORK((tid0 & xmask) + ((tid0 & ymask) << 1), asc0); | |||
| WORK((tid1 & xmask) + ((tid1 & ymask) << 1), asc1); | |||
| Sync::s(); | |||
| } | |||
| } | |||
| #undef WORK | |||
| if (cur_batch < batch) { | |||
| safe_write0(key_out, keys, tid0, length); | |||
| safe_write0(key_out, keys, tid0 + nr_th, length); | |||
| safe_write0(key_out, keys, tid0 + nr_th * 2, length); | |||
| safe_write0(key_out, keys, tid0 + nr_th * 3, length); | |||
| // permute values according to sorted indices | |||
| Value* copied_values = reinterpret_cast<Value*>(keys); | |||
| safe_load1(copied_values, value_inp, tid0, cur_length); | |||
| safe_load1(copied_values, value_inp, tid0 + nr_th, cur_length); | |||
| safe_load1(copied_values, value_inp, tid0 + nr_th * 2, cur_length); | |||
| safe_load1(copied_values, value_inp, tid0 + nr_th * 3, cur_length); | |||
| Sync::s(); | |||
| safe_write1(value_out, copied_values, values, tid0, length); | |||
| safe_write1(value_out, copied_values, values, tid0 + nr_th, length); | |||
| safe_write1(value_out, copied_values, values, tid0 + nr_th * 2, length); | |||
| safe_write1(value_out, copied_values, values, tid0 + nr_th * 3, length); | |||
| } | |||
| } | |||
| } // namespace bitonic_sort_impl | |||
| template <typename Key, typename Value> | |||
| hipError_t rocm::bitonic_sort(uint32_t batch, uint32_t length, | |||
| const Key* key_inp, const Value* value_inp, | |||
| Key* key_out, Value* value_out, bool ascending, | |||
| hipStream_t stream) { | |||
| using namespace bitonic_sort_impl; | |||
| if (length == 1) { | |||
| if (key_inp != key_out) { | |||
| hipMemcpyAsync(key_out, key_inp, sizeof(Key) * batch, | |||
| hipMemcpyDeviceToDevice, stream); | |||
| } | |||
| if (value_inp != value_out) { | |||
| hipMemcpyAsync(value_out, value_inp, sizeof(Value) * batch, | |||
| hipMemcpyDeviceToDevice, stream); | |||
| } | |||
| return hipGetLastError(); | |||
| } | |||
| void (*kptr)(uint32_t, uint32_t, const Key*, const Value*, Key*, Value*) = | |||
| NULL; | |||
| uint32_t l4 = (length + 3) / 4; | |||
| dim3 block; | |||
| #define chk(s) \ | |||
| do { \ | |||
| if (!kptr && l4 <= (1 << s)) { \ | |||
| block.x = 1 << s; \ | |||
| if ((1 << s) <= 32) { \ | |||
| if (ascending) { \ | |||
| kptr = kern<SyncWarp, Key, Value, LessThan, s>; \ | |||
| } else { \ | |||
| kptr = kern<SyncWarp, Key, Value, GreaterThan, s>; \ | |||
| } \ | |||
| } else { \ | |||
| if (ascending) { \ | |||
| kptr = kern<SyncBlock, Key, Value, LessThan, s>; \ | |||
| } else { \ | |||
| kptr = kern<SyncBlock, Key, Value, GreaterThan, s>; \ | |||
| } \ | |||
| } \ | |||
| } \ | |||
| } while (0) | |||
| chk(0); | |||
| chk(1); | |||
| chk(2); | |||
| chk(3); | |||
| chk(4); | |||
| chk(5); | |||
| chk(6); | |||
| chk(7); | |||
| chk(8); | |||
| chk(9); | |||
| if (!kptr) { | |||
| return hipErrorInvalidConfiguration; | |||
| } | |||
| // TODO: this is randomly choosed | |||
| int suggested_block_size = 128; | |||
| // query_launch_config_for_kernel(reinterpret_cast<void*>(kptr), | |||
| // get_shmem<Key, Value>) | |||
| // .block_size; | |||
| block.y = std::max<int>(suggested_block_size / block.x, 1); | |||
| int shmem = get_shmem<Key, Value>(block.y * block.x); | |||
| kptr<<<(batch - 1) / block.y + 1, block, shmem, stream>>>( | |||
| batch, length, key_inp, value_inp, key_out, value_out); | |||
| return hipGetLastError(); | |||
| } | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| #define INST(k, v) \ | |||
| template hipError_t bitonic_sort<k, v>(uint32_t, uint32_t, const k*, \ | |||
| const v*, k*, v*, bool, \ | |||
| hipStream_t) | |||
| INST(float, int); | |||
| INST(int32_t, int); | |||
| // DNN_INC_FLOAT16(INST(dt_float16, int)); | |||
| #undef INST | |||
| } // namespace megdnn | |||
| } // namespace megdnn | |||
| // vim: ft=rocm syntax=rocm.doxygen | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * \file dnn/src/rocm/argsort/bitonic_sort.h.hip | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "hip_header.h" | |||
| #include <stdint.h> | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| const uint32_t BITONIC_SORT_MAX_LENGTH = 1024; | |||
| // cub radix sort seems to be faster with lengths > 1024 | |||
| /*! | |||
| * \brief bitonic sort for k/v pairs | |||
| * | |||
| * Requires \p length no larger than 4 times of cuda thread num. \p key_inp | |||
| * and \p key_out can be identical, and so are \p value_inp and \p value_out. | |||
| */ | |||
| template <typename Key, typename Value> | |||
| hipError_t bitonic_sort(uint32_t batch, uint32_t length, const Key* key_inp, | |||
| const Value* value_inp, Key* key_out, Value* value_out, | |||
| bool ascending, hipStream_t stream); | |||
| } // namespace rocm | |||
| } // namespace megdnn | |||
| // vim: ft=cpp syntax=cpp.doxygen | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * \file dnn/src/rocm/argsort/opr_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./opr_impl.h" | |||
| #include "./argsort.h.hip" | |||
| #include "./backward.h.hip" | |||
| #include "src/common/utils.h" | |||
| #include "src/rocm/utils.h" | |||
| using namespace megdnn; | |||
| using namespace rocm; | |||
| void ArgsortForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_tensor_out indices, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(src.layout, dst.layout, indices.layout, workspace.size); | |||
| auto M = src.layout.shape[0], N = src.layout.shape[1]; | |||
| auto iptr = indices.ptr<dt_int32>(); | |||
| auto wptr = static_cast<void*>(workspace.raw_ptr); | |||
| bool is_ascending = (param().order == Order::ASCENDING); | |||
| auto stream = hip_stream(handle()); | |||
| switch (src.layout.dtype.enumv()) { | |||
| #define cb(t) \ | |||
| case DTypeTrait<t>::enumv: \ | |||
| argsort::forward(src.ptr<t>(), dst.ptr<t>(), iptr, wptr, M, N, \ | |||
| is_ascending, stream); \ | |||
| break; | |||
| ARGSORT_FOREACH_CTYPE(cb); | |||
| #undef cb | |||
| default: | |||
| megdnn_throw(ssprintf("unsupported argsort dtype on cuda: %s", | |||
| src.layout.dtype.name())); | |||
| } | |||
| } | |||
| size_t ArgsortForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| megdnn_assert(src.ndim == 2, "invalid src layout: %s", | |||
| src.to_string().c_str()); | |||
| auto M = src.shape[0], N = src.shape[1]; | |||
| auto&& dtype = src.dtype; | |||
| megdnn_assert(std::max(M, N) <= | |||
| static_cast<size_t>(std::numeric_limits<int>::max())); | |||
| return argsort::get_fwd_workspace_in_bytes( | |||
| M, N, dtype, param().order == Param::Order::ASCENDING); | |||
| } | |||
| void ArgsortBackwardImpl::exec(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_in indices, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(diff.layout, indices.layout, grad.layout, workspace.size); | |||
| auto stream = hip_stream(handle()); | |||
| switch (diff.layout.dtype.enumv()) { | |||
| #define cb(t) \ | |||
| case DTypeTrait<t>::enumv: \ | |||
| argsort::backward_proxy(grad.layout[0], grad.layout[1], \ | |||
| diff.layout[1], grad.ptr<t>(), diff.ptr<t>(), \ | |||
| indices.ptr<int>(), stream); \ | |||
| break; | |||
| ARGSORT_FOREACH_CTYPE(cb); | |||
| #undef cb | |||
| default: | |||
| megdnn_throw(ssprintf("unsupported argsort dtype on cuda: %s", | |||
| diff.layout.dtype.name())); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * \file dnn/src/rocm/argsort/opr_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| class ArgsortForwardImpl final: public ArgsortForward { | |||
| public: | |||
| using ArgsortForward::ArgsortForward; | |||
| void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_tensor_out indices, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout &src, | |||
| const TensorLayout &dst, | |||
| const TensorLayout &indices) override; | |||
| }; | |||
| class ArgsortBackwardImpl final: public ArgsortBackward { | |||
| public: | |||
| using ArgsortBackward::ArgsortBackward; | |||
| void exec(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_in indices, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout &, | |||
| const TensorLayout &, | |||
| const TensorLayout &) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| } // namespace rocm | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -33,6 +33,7 @@ | |||
| #include "src/rocm/powc/opr_impl.h" | |||
| #include "src/rocm/indexing_multi_axis_vec/opr_impl.h" | |||
| #include "src/rocm/linspace/opr_impl.h" | |||
| #include "src/rocm/argsort/opr_impl.h" | |||
| #include "src/rocm/argmxx/opr_impl.h" | |||
| #include "src/rocm/sleep/opr_impl.h" | |||
| #include "src/rocm/batch_normalization/opr_impl.h" | |||
| @@ -148,6 +149,8 @@ bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) { | |||
| return src.is_contiguous() || src.stride[src.ndim - 1] == 1; | |||
| } | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortForward); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortBackward); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter); | |||
| @@ -0,0 +1,124 @@ | |||
| /** | |||
| * \file dnn/test/rocm/argsort.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "test/rocm/fixture.h" | |||
| #include "test/common/checker.h" | |||
| #include "test/common/rng.h" | |||
| #include "test/common/tensor.h" | |||
| #include "../src/rocm/argsort/opr_impl.h" | |||
| using namespace megdnn; | |||
| using namespace test; | |||
| namespace { | |||
| class ArgsortRNG final : public RNG { | |||
| bool m_rev_order = false; | |||
| DType m_dtype; | |||
| template <typename T> | |||
| void fill(T* ptr, int n) { | |||
| if (m_rev_order) { | |||
| for (int i = 0; i < n; ++i) | |||
| ptr[i] = static_cast<T>(n / 2 - i); | |||
| } else { | |||
| for (int i = 0; i < n; ++i) | |||
| ptr[i] = static_cast<T>(i - n / 2); | |||
| COMPAT_RANDOM(ptr, ptr + n); | |||
| } | |||
| } | |||
| void gen(const TensorND& tensor) override { | |||
| auto n = tensor.layout.total_nr_elems(); | |||
| if (m_dtype == dtype::Float32{}) { | |||
| fill(tensor.ptr<dt_float32>(), n); | |||
| } else { | |||
| megdnn_assert(m_dtype == dtype::Int32{}); | |||
| fill(tensor.ptr<dt_int32>(), n); | |||
| } | |||
| } | |||
| public: | |||
| ArgsortRNG(DType dt) : m_dtype{dt} {} | |||
| void set_rev_order(bool flag) { m_rev_order = flag; } | |||
| }; | |||
| void run_forward_test(Handle* handle, DType dtype) { | |||
| Checker<ArgsortForward> checker(handle); | |||
| using Param = Argsort::Param; | |||
| using Order = Param::Order; | |||
| ArgsortRNG rng{dtype}; | |||
| checker.set_dtype(2, dtype::Int32()); | |||
| checker.set_dtype(0, dtype).set_rng(0, &rng); | |||
| for (size_t i = 3; i < 10240; i *= 2) { | |||
| Param param; | |||
| param.order = Order::ASCENDING; | |||
| checker.set_param(param).execs({{3, i + 1}, {}, {}}); | |||
| param.order = Order::DESCENDING; | |||
| checker.set_param(param).execs({{3, i - 1}, {}, {}}); | |||
| checker.set_param(param).execs({{13, i + 3}, {}, {}}); | |||
| } | |||
| { | |||
| // reverse sort large array | |||
| constexpr size_t N = 200003; | |||
| rng.set_rev_order(true); | |||
| Param param; | |||
| param.order = Order::ASCENDING; | |||
| checker.set_param(param).execs({{1, N}, {}, {}}); | |||
| } | |||
| } | |||
| void run_backward_test(Handle* handle, DType dtype) { | |||
| class IdxRng final : public RNG { | |||
| void gen(const TensorND& tensor) override { | |||
| auto ptr = tensor.ptr<dt_int32>(); | |||
| auto m = tensor.layout[0], n = tensor.layout[1]; | |||
| for (size_t i = 0; i < m; ++i) { | |||
| for (size_t j = 0; j < n; ++j) { | |||
| ptr[j] = j; | |||
| } | |||
| COMPAT_RANDOM(ptr, ptr + n); | |||
| ptr += n; | |||
| } | |||
| } | |||
| } rng; | |||
| Checker<ArgsortBackward> checker(handle); | |||
| checker.set_dtype(1, dtype::Int32()).set_rng(1, &rng); | |||
| checker.set_dtype(0, dtype); | |||
| checker.set_dtype(2, dtype); | |||
| for (size_t i = 16; i < 4096; i *= 2) { | |||
| checker.execs({{3, i}, {3, i}, {3, i}}); | |||
| checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 3}}); | |||
| checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 7}}); | |||
| } | |||
| } | |||
| } // anonymous namespace | |||
| TEST_F(ROCM, ARGSORT_FORWARD_F32) { | |||
| run_forward_test(handle_rocm(), dtype::Float32{}); | |||
| } | |||
| TEST_F(ROCM, ARGSORT_FORWARD_I32) { | |||
| run_forward_test(handle_rocm(), dtype::Int32{}); | |||
| } | |||
| TEST_F(ROCM, ARGSORT_BACKWARD_F32) { | |||
| run_backward_test(handle_rocm(), dtype::Float32{}); | |||
| } | |||
| TEST_F(ROCM, ARGSORT_BACKWARD_I32) { | |||
| run_backward_test(handle_rocm(), dtype::Int32{}); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||