From: @robingrosman (tag: v1.2.0-rc1)
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,9 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TOPK_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TOPK_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TOPK_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TOPK_GPU_KERNEL_H_ | |||
| #include <limits> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| @@ -27,7 +28,7 @@ namespace kernel { | |||
| template <typename T, typename S> | |||
| class TopKGpuKernel : public GpuKernel { | |||
| public: | |||
| TopKGpuKernel() : sorted_(false), outer_size_(1), inner_size_(1), k_(1), use_share_mem_(true), ceil_power2_(0) {} | |||
| TopKGpuKernel() : sorted_(false), outer_size_(1), inner_size_(1), k_(1), input_shape_size_(0) {} | |||
| ~TopKGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -40,26 +41,17 @@ class TopKGpuKernel : public GpuKernel { | |||
| S *k = GetDeviceAddress<S>(inputs, 1); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| S *indices = GetDeviceAddress<S>(outputs, 1); | |||
| T *data_buff = nullptr; | |||
| S *index_buff = nullptr; | |||
| if (use_share_mem_ == false) { | |||
| data_buff = GetDeviceAddress<T>(workspaces, 0); | |||
| index_buff = GetDeviceAddress<S>(workspaces, 1); | |||
| } | |||
| TopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, data_buff, index_buff, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| const T init_k = std::numeric_limits<T>::lowest(); | |||
| if (sorted_ == false) { | |||
| BitonicSortByKey(outer_size_, k_, output_addr, indices, data_buff, index_buff, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| FastTopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, init_k, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| input_shape_size_ = input_shapes.size(); | |||
| for (size_t i = 0; i < input_shapes.size() - 1; i++) { | |||
| outer_size_ *= input_shapes[i]; | |||
| } | |||
| @@ -68,13 +60,6 @@ class TopKGpuKernel : public GpuKernel { | |||
| sorted_ = GetAttr<bool>(kernel_node, "sorted"); | |||
| ceil_power2_ = RoundUpPower2(inner_size_); | |||
| size_t buffer_size = ceil_power2_ * (sizeof(T) + sizeof(S)); | |||
| if (buffer_size > SHARED_MEM_PER_BLOCK) { | |||
| use_share_mem_ = false; | |||
| MS_LOG(INFO) << "CUDA share memory not enough, sort with RAM"; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| @@ -85,10 +70,6 @@ class TopKGpuKernel : public GpuKernel { | |||
| input_size_list_.push_back(sizeof(S)); | |||
| output_size_list_.push_back(outer_size_ * k_ * sizeof(T)); | |||
| output_size_list_.push_back(outer_size_ * k_ * sizeof(S)); | |||
| if (use_share_mem_ == false) { | |||
| workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(T)); | |||
| workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(S)); | |||
| } | |||
| } | |||
| private: | |||
| @@ -96,8 +77,7 @@ class TopKGpuKernel : public GpuKernel { | |||
| size_t outer_size_; | |||
| size_t inner_size_; | |||
| size_t k_; | |||
| bool use_share_mem_; | |||
| size_t ceil_power2_; | |||
| int input_shape_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| @@ -106,4 +86,4 @@ class TopKGpuKernel : public GpuKernel { | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // TopKpuKernel | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TOPK_GPU_KERNEL_H_ | |||
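The removed `use_share_mem_` fallback in this header allocated two workspace buffers of `outer_size_ * ceil_power2_` elements whenever a row did not fit in shared memory; the new Launch body calls FastTopK with `init_k = std::numeric_limits<T>::lowest()` and requests no workspace at all. A minimal standalone C++ sketch of what that fallback could cost; the shape and the local RoundUpPower2 copy are illustrative assumptions, not taken from the PR.

// Hedged sketch: rough workspace the removed RAM fallback would request for one
// made-up shape, versus the new FastTopK path, which requests none.
#include <cstddef>
#include <cstdio>

static std::size_t RoundUpPower2(std::size_t v) {  // same bit trick as in topk_impl.cu
  v--;
  v |= v >> 1; v |= v >> 2; v |= v >> 4; v |= v >> 8; v |= v >> 16;
  v++;
  return v;
}

int main() {
  const std::size_t outer_size = 512, inner_size = 40960;    // illustrative shape only
  const std::size_t ceil_power2 = RoundUpPower2(inner_size);  // 65536
  // data_buff (float) + index_buff (int), per the removed InitSizeLists() branch
  const std::size_t old_workspace = outer_size * ceil_power2 * (sizeof(float) + sizeof(int));
  std::printf("old fallback workspace: %zu bytes (%zu MiB); new path: 0 bytes\n",
              old_workspace, old_workspace >> 20);  // 268435456 bytes, i.e. 256 MiB
  return 0;
}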
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -23,6 +23,10 @@ | |||
| #define BLOCKSIZE 256 | |||
| #define MAX_DIMENSION 5 | |||
| template <typename T, typename S, typename K> | |||
| void CalRandomChoiceWithMaskSmall(int input_size, int seedc, int count, K *input, S *output_index, K *output_mask, | |||
| cudaStream_t stream); | |||
| template <typename T, typename S> | |||
| void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2, | |||
| const int &d3, const int &d4, const int &d5, const int &seedc, const int &count, | |||
| @@ -0,0 +1,152 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/topk_lib.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" | |||
| // Kernel implementation starts here | |||
| #define L2_RCWM_HELPER(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, IS_DESCEND) \ | |||
| do { \ | |||
| L2Rcwm<T, S, K, NUM_WARP_Q, NUM_THREAD_Q, BLOCK, IS_DESCEND> \ | |||
| <<<1, BLOCK, 0, stream>>>(seedc, input_size, input, output_mask, output_index, k); \ | |||
| } while (0) | |||
| #define LEFT_INSERT_THREAD_QUEUE(_k, _v) \ | |||
| do { \ | |||
| if (is_descend ? Cmp<T>::gt(_k, warp_K_top) : Cmp<T>::lt(_k, warp_K_top)) { \ | |||
| { \ | |||
| _Pragma("unroll") for (int i = thread_queue - 1; i > 0; --i) { \ | |||
| threadK[i] = threadK[i - 1]; \ | |||
| threadV[i] = threadV[i - 1]; \ | |||
| } \ | |||
| } \ | |||
| threadK[0] = _k; \ | |||
| threadV[0] = _v; \ | |||
| ++num_vals; \ | |||
| } \ | |||
| } while (0) | |||
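LEFT_INSERT_THREAD_QUEUE keeps a small per-thread candidate queue: any key beating warp_K_top is shifted in at slot 0, and num_vals counts insertions until the warp-wide merge is triggered. A minimal host-side C++ sketch of that insertion; the function name, kThreadQueue value, and sample values are illustrative assumptions, not from the kernel.

// Hedged host-side sketch of the shift-down insertion the macro performs.
#include <cstdio>

constexpr int kThreadQueue = 2;

void LeftInsert(float k, int v, float thread_k[kThreadQueue], int thread_v[kThreadQueue],
                float warp_k_top, int *num_vals) {
  if (k > warp_k_top) {                            // descending selection: keep larger keys
    for (int i = kThreadQueue - 1; i > 0; --i) {   // shift the queue right by one slot
      thread_k[i] = thread_k[i - 1];
      thread_v[i] = thread_v[i - 1];
    }
    thread_k[0] = k;                               // newest candidate always enters at slot 0
    thread_v[0] = v;
    ++(*num_vals);                                 // once num_vals == thread_queue, the warp merges
  }
}

int main() {
  float tk[kThreadQueue] = {-2.0f, -2.0f};         // init_K sentinel from the kernel
  int tv[kThreadQueue] = {0, 0};
  int num_vals = 0;
  LeftInsert(0.7f, 11, tk, tv, /*warp_k_top=*/-2.0f, &num_vals);
  LeftInsert(0.9f, 42, tk, tv, -2.0f, &num_vals);
  // queue now holds (0.9, 42) then (0.7, 11)
  std::printf("queue: (%g,%d) (%g,%d), num_vals=%d\n", tk[0], tv[0], tk[1], tv[1], num_vals);
  return 0;
}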
| template <typename T, typename S, typename K, int warp_queue, int thread_queue, int threads_per_block, bool is_descend> | |||
| __global__ void L2Rcwm(int seedc, int input_size, const K *input, K *output_mask, S *output_index, int k) { | |||
| constexpr int kNumWarps = threads_per_block / kWarpSize; | |||
| constexpr T init_K = static_cast<T>(-2.0); | |||
| constexpr S init_V = static_cast<S>(0); | |||
| __shared__ T shared_K[kNumWarps * warp_queue]; | |||
| __shared__ S shared_V[kNumWarps * warp_queue]; | |||
| curandState devState; | |||
| curand_init(seedc, threadIdx.x, 0, &devState); | |||
| T threadK[thread_queue]; // NOLINT | |||
| S threadV[thread_queue]; // NOLINT | |||
| T *warp_K; | |||
| S *warp_V; | |||
| T warp_K_top = init_K; | |||
| int k_minus_1 = k - 1; | |||
| int num_vals = 0; | |||
| int limit = (input_size / kWarpSize) * kWarpSize; | |||
| int i = threadIdx.x; | |||
| // init begin | |||
| _Pragma("unroll") for (int i = 0; i < thread_queue; ++i) { | |||
| threadK[i] = init_K; | |||
| threadV[i] = init_V; | |||
| } | |||
| int laneId = GetLaneId(); | |||
| int warpId = threadIdx.x / kWarpSize; // warp index within the block | |||
| // warp shared memory start address | |||
| warp_K = shared_K + warpId * warp_queue; | |||
| warp_V = shared_V + warpId * warp_queue; | |||
| for (int i = laneId; i < warp_queue; i += kWarpSize) { | |||
| warp_K[i] = init_K; | |||
| warp_V[i] = init_V; | |||
| } | |||
| // sync till all threads init done | |||
| __syncwarp(); | |||
| // insert begin | |||
| for (; i < limit; i += threads_per_block) { | |||
| T rand_num = input[i] ? __uint2float_rn(curand(&devState)) : init_K; | |||
| LEFT_INSERT_THREAD_QUEUE(rand_num, i); | |||
| // CHECK_AND_MERGE_THREAD_QUEUE() begin | |||
| bool needSort = (num_vals == thread_queue); | |||
| needSort = __any_sync(0xffffffff, needSort); | |||
| if (!needSort) continue; | |||
| MergeWarpQueue<T, S, warp_queue, thread_queue, is_descend>(threadK, threadV, warp_K, warp_V); | |||
| num_vals = 0; | |||
| _Pragma("unroll") for (int i = 0; i < thread_queue; ++i) { | |||
| threadK[i] = init_K; | |||
| threadV[i] = init_V; | |||
| } | |||
| warp_K_top = warp_K[k_minus_1]; | |||
| __syncwarp(); | |||
| } | |||
| if (i < input_size) { | |||
| T rand_num = input[i] ? __uint2float_rn(curand(&devState)) : init_K; | |||
| LEFT_INSERT_THREAD_QUEUE(rand_num, i); | |||
| } | |||
| // reduce begin | |||
| MergeWarpQueue<T, S, warp_queue, thread_queue, is_descend>(threadK, threadV, warp_K, warp_V); | |||
| __syncthreads(); | |||
| SortBlockWide<kNumWarps, threads_per_block, T, S, warp_queue, is_descend>(shared_K, shared_V); | |||
| // ship data from shared memory to output buffer | |||
| for (int i = threadIdx.x; i < k; i += blockDim.x) { | |||
| output_mask[i] = shared_K[i] > static_cast<T>(-1.0) ? true : false; | |||
| output_index[i] = shared_V[i]; | |||
| } | |||
| } | |||
| template <typename T, typename S, typename K> | |||
| void RCWMScaleK(int seedc, int input_size, K *input, int k, S *output_index, K *output_mask, cudaStream_t stream) { | |||
| if (k <= 32) { | |||
| // num-threads-of-block, warp-queue-size, thread-queue-size | |||
| L2_RCWM_HELPER(256, 32, 2, true); | |||
| } else if (k <= 64) { | |||
| L2_RCWM_HELPER(256, 64, 3, true); | |||
| } else if (k <= 128) { | |||
| L2_RCWM_HELPER(256, 128, 3, true); | |||
| } else if (k <= 256) { | |||
| L2_RCWM_HELPER(256, 256, 4, true); | |||
| } else if (k <= 512) { | |||
| L2_RCWM_HELPER(256, 512, 8, true); | |||
| } else if (k <= 1024) { | |||
| L2_RCWM_HELPER(128, 1024, 8, true); | |||
| } else if (k <= 2048) { | |||
| L2_RCWM_HELPER(64, 2048, 8, true); | |||
| } | |||
| } | |||
| template <typename T, typename S, typename K> | |||
| void CalRandomChoiceWithMaskSmall(int input_size, int seedc, int count, K *input, S *output_index, K *output_mask, | |||
| cudaStream_t stream) { | |||
| RCWMScaleK<T, S, K>(seedc, input_size, input, count, output_index, output_mask, stream); | |||
| } | |||
| template void CalRandomChoiceWithMaskSmall<float, int, bool>(int input_size, int seedc, int count, bool *input, | |||
| int *output_index, bool *output_mask, cudaStream_t stream); | |||
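RCWMScaleK picks the block size and the per-warp/per-thread queue depths from the requested count. A hedged C++ sketch of that mapping; the struct and function names are illustrative, while the thresholds are copied from the dispatch chain above, which assumes k stays at or below 2048 (kSmallK in the GPU kernel).

// Hedged sketch of the k -> launch-configuration table used by RCWMScaleK.
#include <cstdio>

struct RcwmConfig { int block; int warp_queue; int thread_queue; };

RcwmConfig PickConfig(int k) {
  if (k <= 32)   return {256, 32, 2};
  if (k <= 64)   return {256, 64, 3};
  if (k <= 128)  return {256, 128, 3};
  if (k <= 256)  return {256, 256, 4};
  if (k <= 512)  return {256, 512, 8};
  if (k <= 1024) return {128, 1024, 8};
  return {64, 2048, 8};  // k <= 2048; larger counts take the generic kernel path instead
}

int main() {
  for (int k : {10, 100, 2000}) {
    const RcwmConfig c = PickConfig(k);
    std::printf("k=%d -> block=%d, warp_queue=%d, thread_queue=%d\n",
                k, c.block, c.warp_queue, c.thread_queue);
  }
  return 0;
}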
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,148 +15,213 @@ | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/topk_lib.cuh" | |||
| #include <limits> | |||
| #include <algorithm> | |||
| size_t RoundUpPower2(size_t v) { | |||
| v--; | |||
| v |= v >> 1; | |||
| v |= v >> 2; | |||
| v |= v >> 4; | |||
| v |= v >> 8; | |||
| v |= v >> 16; | |||
| v++; | |||
| return v; | |||
| } | |||
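RoundUpPower2 smears the highest set bit of v-1 into every lower position and then adds one, which rounds any value up to the next power of two and leaves exact powers of two unchanged. A hedged sanity check; the function is re-declared locally as RoundUpPower2Copy so the snippet builds on its own without colliding with the definition above.

// Hedged standalone check of the bit-smearing round-up.
#include <cassert>
#include <cstddef>

static std::size_t RoundUpPower2Copy(std::size_t v) {
  v--;
  v |= v >> 1; v |= v >> 2; v |= v >> 4; v |= v >> 8; v |= v >> 16;
  v++;
  return v;
}

int main() {
  assert(RoundUpPower2Copy(1) == 1);        // v-- keeps exact powers of two fixed
  assert(RoundUpPower2Copy(5) == 8);
  assert(RoundUpPower2Copy(1024) == 1024);
  assert(RoundUpPower2Copy(40960) == 65536);
  return 0;
}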
| const int kMaxQueue = 128; | |||
| template <typename T> | |||
| __inline__ __device__ void Swap(T *lhs, T *rhs) { | |||
| T tmp = lhs[0]; | |||
| lhs[0] = rhs[0]; | |||
| rhs[0] = tmp; | |||
| } | |||
| #define TOPK_HELPER(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, IS_DESCEND) \ | |||
| do { \ | |||
| TopKBlock<T, S, NUM_WARP_Q, NUM_THREAD_Q, BLOCK, IS_DESCEND> \ | |||
| <<<block_num_limit, BLOCK, 0, stream>>>(outer_size, inner_size, input, output, output_index, k_cut, init_K); \ | |||
| } while (0) | |||
| template <typename T, typename S> | |||
| __global__ void TopkKernel(const size_t outer, const size_t inner, const size_t ceil_power2, const T *input, const S *k, | |||
| T *output, S *indices, T *data_buff, S *index_buff) { | |||
| // default: sort with share memory | |||
| extern __shared__ T share_mem[]; | |||
| T *data_arr = share_mem; | |||
| S *index_arr = reinterpret_cast<S *>(data_arr + ceil_power2); | |||
| // sort with RAM | |||
| if (data_buff != nullptr && index_buff != nullptr) { | |||
| data_arr = data_buff + blockIdx.x * ceil_power2; | |||
| index_arr = index_buff + blockIdx.x * ceil_power2; | |||
| #define LEFT_INSERT_THREAD_QUEUE(_k, _v) \ | |||
| do { \ | |||
| if (is_descend ? CmpKV<T, S>::gt(_k, _v, (*ceil_K), (*ceil_V)) : CmpKV<T, S>::lt(_k, _v, (*ceil_K), (*ceil_V))) \ | |||
| break; \ | |||
| if (is_descend ? CmpKV<T, S>::gt(_k, _v, warp_K_top, warp_V_top) \ | |||
| : CmpKV<T, S>::lt(_k, _v, warp_K_top, warp_V_top)) { \ | |||
| { \ | |||
| _Pragma("unroll") for (int i = thread_queue - 1; i > 0; --i) { \ | |||
| threadK[i] = threadK[i - 1]; \ | |||
| threadV[i] = threadV[i - 1]; \ | |||
| } \ | |||
| } \ | |||
| threadK[0] = _k; \ | |||
| threadV[0] = _v; \ | |||
| ++num_vals; \ | |||
| } \ | |||
| } while (0) | |||
| template <typename T, typename S, int warp_queue, int thread_queue, int threads_per_block, bool is_descend> | |||
| inline __device__ void TopKInBuffer(T *shared_K, S *shared_V, int *watermark, T *ceil_K, S *ceil_V, int laneId) { | |||
| constexpr int kNumWarps = threads_per_block / kWarpSize; // kNumWarps is 1024/32=32 | |||
| // find last_K, the maximum over the last element of each warp queue | |||
| T last_K = shared_K[laneId * warp_queue + warp_queue - 1]; | |||
| S last_V = shared_V[laneId * warp_queue + warp_queue - 1]; | |||
| __syncwarp(); | |||
| for (int offset = kNumWarps / 2; offset > 0; offset /= 2) { | |||
| // kNumWarps is 32 if block size is 1024 | |||
| T other_K = __shfl_down_sync(0xffffffff, last_K, offset); | |||
| S other_V = __shfl_down_sync(0xffffffff, last_V, offset); | |||
| bool is_greater = CmpKV<T, S>::gt(other_K, other_V, last_K, last_V); | |||
| ConditionalAssign(is_greater, &last_K, other_K); | |||
| ConditionalAssign(is_greater, &last_V, other_V); | |||
| } | |||
| __syncwarp(); | |||
| for (size_t i = threadIdx.x; i < ceil_power2; i += blockDim.x) { | |||
| data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max(); | |||
| index_arr[i] = i; | |||
| if (laneId == 0) { | |||
| *ceil_K = last_K; | |||
| *ceil_V = last_V; | |||
| } | |||
| __syncthreads(); | |||
| __syncwarp(); | |||
| for (size_t i = 2; i <= ceil_power2; i <<= 1) { | |||
| for (size_t j = (i >> 1); j > 0; j >>= 1) { | |||
| for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) { | |||
| size_t tid_comp = tid ^ j; | |||
| if (tid_comp > tid) { | |||
| if ((tid & i) == 0) { | |||
| if (data_arr[tid] > data_arr[tid_comp]) { | |||
| Swap(&data_arr[tid], &data_arr[tid_comp]); | |||
| Swap(&index_arr[tid], &index_arr[tid_comp]); | |||
| } | |||
| } else { | |||
| if (data_arr[tid] < data_arr[tid_comp]) { | |||
| Swap(&data_arr[tid], &data_arr[tid_comp]); | |||
| Swap(&index_arr[tid], &index_arr[tid_comp]); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| __syncthreads(); | |||
| } | |||
| // calculate index cut by last_K | |||
| int L = 0; | |||
| int R = warp_queue; | |||
| while (L < R) { | |||
| int m = (L + R) / 2; | |||
| CmpKV<T, S>::gt(shared_K[laneId * warp_queue + m], shared_V[laneId * warp_queue + m], (*ceil_K), (*ceil_V)) | |||
| ? L = m + 1 | |||
| : R = m; | |||
| } | |||
| __syncwarp(); | |||
| for (size_t tid = threadIdx.x; tid < k[0]; tid += blockDim.x) { | |||
| output[blockIdx.x * k[0] + tid] = data_arr[inner - tid - 1]; | |||
| indices[blockIdx.x * k[0] + tid] = index_arr[inner - tid - 1]; | |||
| // sum the per-lane counts of entries greater than last_K across the warp | |||
| for (int offset = kNumWarps / 2; offset > 0; offset /= 2) { | |||
| R += __shfl_down_sync(0xffffffff, R, offset); | |||
| } | |||
| } | |||
| template <typename T, typename S> | |||
| void TopK(const size_t &outer, const size_t &inner, const T *input, const S *k, T *output, S *indices, T *data_buff, | |||
| S *index_buff, cudaStream_t stream) { | |||
| size_t ceil_power2 = RoundUpPower2(inner); | |||
| size_t share_mem = (data_buff == nullptr) ? ceil_power2 * (sizeof(T) + sizeof(S)) : 0; | |||
| size_t thread_num = std::min(ceil_power2, static_cast<size_t>(GET_THREADS)); | |||
| TopkKernel<<<outer, thread_num, share_mem, stream>>>(outer, inner, ceil_power2, input, k, output, indices, data_buff, | |||
| index_buff); | |||
| __syncwarp(); | |||
| if (laneId == 0) { | |||
| watermark[0] = R; | |||
| } | |||
| __syncwarp(); | |||
| } | |||
| template <typename T, typename S> | |||
| __global__ void BitonicSortByKeyKernel(const size_t outer, const size_t inner, const size_t ceil_power2, T *input, | |||
| S *indices, T *data_buff, S *index_buff) { | |||
| // default: sort with share memory | |||
| extern __shared__ T share_mem[]; | |||
| T *data_arr = share_mem; | |||
| S *index_arr = reinterpret_cast<S *>(data_arr + ceil_power2); | |||
| // sort with RAM | |||
| if (data_buff != nullptr && index_buff != nullptr) { | |||
| data_arr = data_buff + blockIdx.x * ceil_power2; | |||
| index_arr = index_buff + blockIdx.x * ceil_power2; | |||
| template <typename T, typename S, int warp_queue, int thread_queue, int threads_per_block, bool is_descend> | |||
| inline __device__ void TopKStep(const int &outer_size, const int &inner_size, const T *input, T *output, | |||
| S *output_index, S k_cut, const T &init_K, const int &outer_id, T *shared_K, | |||
| S *shared_V, int *watermark, T *threadK, S *threadV, T *ceil_K, S *ceil_V, S *k_prime) { | |||
| constexpr int kNumWarps = threads_per_block / kWarpSize; | |||
| constexpr S init_V = static_cast<S>(-1); | |||
| T *warp_K; | |||
| S *warp_V; | |||
| T warp_K_top = init_K; | |||
| S warp_V_top = init_V; | |||
| int k_minus_1 = (k_cut <= kMaxQueue ? k_cut - 1 : kMaxQueue - 1); | |||
| int num_vals = 0; | |||
| int limit = (inner_size / kWarpSize) * kWarpSize; | |||
| _Pragma("unroll") for (int i = 0; i < thread_queue; ++i) { | |||
| threadK[i] = init_K; | |||
| threadV[i] = init_V; | |||
| } | |||
| for (size_t i = threadIdx.x; i < ceil_power2; i += blockDim.x) { | |||
| data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max(); | |||
| index_arr[i] = (i < inner) ? indices[blockIdx.x * inner + i] : std::numeric_limits<S>::max(); | |||
| int laneId = GetLaneId(); | |||
| int warpId = threadIdx.x / kWarpSize; // warp index within the block | |||
| warp_K = shared_K + warpId * warp_queue; | |||
| warp_V = shared_V + warpId * warp_queue; | |||
| for (int i = laneId; i < warp_queue; i += kWarpSize) { | |||
| warp_K[i] = init_K; | |||
| warp_V[i] = init_V; | |||
| } | |||
| __syncthreads(); | |||
| for (size_t i = 2; i <= ceil_power2; i <<= 1) { | |||
| for (size_t j = (i >> 1); j > 0; j >>= 1) { | |||
| for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) { | |||
| size_t tid_comp = tid ^ j; | |||
| if (tid_comp > tid) { | |||
| if ((tid & i) == 0) { | |||
| if (index_arr[tid] > index_arr[tid_comp]) { | |||
| Swap(&data_arr[tid], &data_arr[tid_comp]); | |||
| Swap(&index_arr[tid], &index_arr[tid_comp]); | |||
| } | |||
| } else { | |||
| if (index_arr[tid] < index_arr[tid_comp]) { | |||
| Swap(&data_arr[tid], &data_arr[tid_comp]); | |||
| Swap(&index_arr[tid], &index_arr[tid_comp]); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| __syncthreads(); | |||
| __syncwarp(); | |||
| int i = threadIdx.x; | |||
| for (; i < limit; i += threads_per_block) { | |||
| LEFT_INSERT_THREAD_QUEUE((input[outer_id * inner_size + i]), (outer_id * inner_size + i)); | |||
| bool needSort = (num_vals == thread_queue); | |||
| needSort = __any_sync(0xffffffff, needSort); | |||
| if (!needSort) continue; | |||
| MergeWarpQueue<T, S, warp_queue, thread_queue, is_descend>(threadK, threadV, warp_K, warp_V); | |||
| num_vals = 0; | |||
| _Pragma("unroll") for (int i = 0; i < thread_queue; ++i) { | |||
| threadK[i] = init_K; | |||
| threadV[i] = init_V; | |||
| } | |||
| warp_K_top = warp_K[k_minus_1]; | |||
| warp_V_top = warp_V[k_minus_1]; | |||
| __syncwarp(); | |||
| } | |||
| if (i < inner_size) { | |||
| LEFT_INSERT_THREAD_QUEUE((input[outer_id * inner_size + i]), (outer_id * inner_size + i)); | |||
| } | |||
| MergeWarpQueue<T, S, warp_queue, thread_queue, is_descend>(threadK, threadV, warp_K, warp_V); | |||
| __syncthreads(); | |||
| if (k_cut > kMaxQueue && warpId == 0) { | |||
| TopKInBuffer<T, S, warp_queue, thread_queue, threads_per_block, is_descend>(shared_K, shared_V, watermark, ceil_K, | |||
| ceil_V, laneId); | |||
| } | |||
| __syncthreads(); | |||
| SortBlockWide<kNumWarps, threads_per_block, T, S, warp_queue, is_descend>(shared_K, shared_V); | |||
| S k_step = (*k_prime) + watermark[0] <= k_cut ? watermark[0] : k_cut - (*k_prime); | |||
| for (int i = threadIdx.x; i < k_step; i += blockDim.x) { | |||
| output[outer_id * k_cut + (*k_prime) + i] = shared_K[i]; | |||
| output_index[outer_id * k_cut + (*k_prime) + i] = shared_V[i] % inner_size; | |||
| } | |||
| *k_prime += k_step; | |||
| __syncthreads(); | |||
| } | |||
| template <typename T, typename S, int warp_queue, int thread_queue, int threads_per_block, bool is_descend> | |||
| __global__ void TopKBlock(int outer_size, int inner_size, const T *input, T *output, S *output_index, S k_cut, | |||
| const T init_K) { | |||
| constexpr int kNumWarps = threads_per_block / kWarpSize; | |||
| __shared__ T shared_K[kNumWarps * warp_queue]; | |||
| __shared__ S shared_V[kNumWarps * warp_queue]; | |||
| __shared__ int watermark[1]; | |||
| __shared__ T ceil_K; | |||
| __shared__ S ceil_V; | |||
| for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) { | |||
| input[blockIdx.x * inner + tid] = data_arr[tid]; | |||
| indices[blockIdx.x * inner + tid] = index_arr[tid]; | |||
| T threadK[thread_queue]; // NOLINT | |||
| S threadV[thread_queue]; // NOLINT | |||
| for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < blockDim.x * outer_size; | |||
| t_idx += blockDim.x * gridDim.x) { | |||
| S k_prime = 0; | |||
| int outer_id = t_idx / blockDim.x; | |||
| ceil_K = -init_K; | |||
| ceil_V = -1; | |||
| watermark[0] = k_cut; | |||
| do { | |||
| TopKStep<T, S, warp_queue, thread_queue, threads_per_block, is_descend>( | |||
| outer_size, inner_size, input, output, output_index, k_cut, init_K, outer_id, shared_K, shared_V, watermark, | |||
| threadK, threadV, &ceil_K, &ceil_V, &k_prime); | |||
| } while (k_prime < k_cut); | |||
| } | |||
| } | |||
| template <typename T, typename S> | |||
| void BitonicSortByKey(const size_t &outer, const size_t &inner, T *input, S *indices, T *data_buff, S *index_buff, | |||
| cudaStream_t stream) { | |||
| size_t ceil_power2 = RoundUpPower2(inner); | |||
| size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S)); | |||
| if (share_mem > SHARED_MEM_PER_BLOCK) { | |||
| share_mem = 0; | |||
| void FastTopK(const int outer_size, const int inner_size, const T *input, const S *k, T *output, S *output_index, | |||
| const T init_K, cudaStream_t stream) { | |||
| int block_num_limit = outer_size < 128 ? outer_size : 128; | |||
| S k_cut = 0; | |||
| cudaMemcpy(&k_cut, k, sizeof(S), cudaMemcpyDeviceToHost); | |||
| if (k_cut > inner_size) k_cut = inner_size; | |||
| if (k_cut <= 32) { | |||
| // num-threads-of-block, warp-queue-size, thread-queue-size | |||
| TOPK_HELPER(256, 32, 2, true); | |||
| } else if (k_cut <= 64) { | |||
| TOPK_HELPER(256, 64, 3, true); | |||
| } else if (k_cut <= 128) { | |||
| TOPK_HELPER(256, 128, 3, true); | |||
| } else { | |||
| data_buff = nullptr; | |||
| index_buff = nullptr; | |||
| TOPK_HELPER(1024, 128, 3, true); | |||
| } | |||
| size_t thread_num = std::min(ceil_power2, static_cast<size_t>(GET_THREADS)); | |||
| BitonicSortByKeyKernel<<<outer, thread_num, share_mem, stream>>>(outer, inner, ceil_power2, input, indices, data_buff, | |||
| index_buff); | |||
| } | |||
| template void TopK(const size_t &outer, const size_t &inner, const float *input_addr, const int *k, float *output, | |||
| int *indices, float *data_buff, int *index_buff, cudaStream_t stream); | |||
| template void BitonicSortByKey(const size_t &outer, const size_t &inner, float *input, int *indices, float *data_buff, | |||
| int *index_buff, cudaStream_t stream); | |||
| template void FastTopK(const int outer_size, const int inner_size, const float *input, const int *k, float *output, | |||
| int *output_index, const float init_K, cudaStream_t stream); | |||
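FastTopK copies k to the host, clamps it to inner_size, and when k_cut exceeds kMaxQueue (128) it launches the 1024-thread configuration, where TopKBlock loops: each TopKStep pass writes at most `watermark` results and advances k_prime until k_cut values have been emitted. A minimal host-side C++ model of that accumulation; the fixed watermark and the pass count are illustrative assumptions, since in the kernel the watermark is recomputed each pass by TopKInBuffer.

// Hedged model of the multi-pass emission loop in TopKBlock/TopKStep.
#include <cstdio>

int main() {
  const int k_cut = 512;
  const int watermark = 128;  // one kMaxQueue-sized chunk per pass (illustrative value)
  int k_prime = 0, passes = 0;
  do {
    // same update rule as "S k_step = (*k_prime) + watermark[0] <= k_cut ? ..." in TopKStep
    const int k_step = (k_prime + watermark <= k_cut) ? watermark : k_cut - k_prime;
    k_prime += k_step;
    ++passes;
  } while (k_prime < k_cut);
  std::printf("k_cut=%d emitted in %d passes\n", k_cut, passes);  // 4 passes of 128
  return 0;
}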
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,19 +14,14 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TOPK_IMPL_CUH_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TOPK_IMPL_CUH_ | |||
| #include <cuda_runtime.h> | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T, typename S> | |||
| void TopK(const size_t &outer, const size_t &inner, const T *input_addr, const S *k, T *output, S *indices, | |||
| T *data_buff, S *index_buff, cudaStream_t stream); | |||
| void FastTopK(const int outer, const int inner, const T *input_addr, const S *k, T *output, S *indices, const T initK, | |||
| cudaStream_t stream); | |||
| template <typename T, typename S> | |||
| void BitonicSortByKey(const size_t &outer, const size_t &inner, T *input, S *indices, T *data_buff, S *index_buff, | |||
| cudaStream_t stream); | |||
| size_t RoundUpPower2(size_t v); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_ | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TOPK_IMPL_CUH_ | |||
| @@ -0,0 +1,479 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #pragma once | |||
| constexpr int kWarpSize = 32; | |||
| constexpr __host__ __device__ int Log2(int n, int p = 0) { return (n <= 1) ? p : Log2(n / 2, p + 1); } | |||
| constexpr __host__ __device__ bool IsPow2(int v) { return (v && !(v & (v - 1))); } | |||
| constexpr __host__ __device__ int NextPow2(int v) { return (IsPow2(v) ? 2 * v : (1 << static_cast<int>(Log2(v) + 1))); } | |||
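A few compile-time checks of the constexpr helpers above, as a hedged sketch; the functions are re-declared locally without the __host__ __device__ qualifiers so the snippet builds as plain C++, and the behaviour shown matches the definitions in this header.

// Hedged compile-time checks of Log2 / IsPow2 / NextPow2.
constexpr int Log2(int n, int p = 0) { return (n <= 1) ? p : Log2(n / 2, p + 1); }
constexpr bool IsPow2(int v) { return (v && !(v & (v - 1))); }
constexpr int NextPow2(int v) { return (IsPow2(v) ? 2 * v : (1 << (Log2(v) + 1))); }

static_assert(Log2(32) == 5, "floor(log2(32)) == 5");
static_assert(Log2(5) == 2, "Log2 truncates: floor(log2(5)) == 2");
static_assert(!IsPow2(0) && IsPow2(1) && IsPow2(64) && !IsPow2(96), "power-of-two test");
static_assert(NextPow2(5) == 8, "non-powers round up");
static_assert(NextPow2(8) == 16, "NextPow2 is strictly greater, even for exact powers");

int main() { return 0; }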
| __device__ __forceinline__ int GetLaneId() { | |||
| int laneId; | |||
| asm("mov.u32 %0, %%laneid;" : "=r"(laneId)); | |||
| return laneId; | |||
| } | |||
| template <typename T, typename S> | |||
| struct CmpKV { | |||
| __device__ static inline bool gt(T k1, S v1, T k2, S v2) { return k1 > k2 || (k1 == k2 && v1 < v2); } | |||
| __device__ static inline bool lt(T k1, S v1, T k2, S v2) { return k1 < k2 || (k1 == k2 && v1 > v2); } | |||
| }; | |||
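CmpKV breaks key ties by index: with equal keys, the pair carrying the smaller index is treated as "greater", so the descending selection used by these kernels prefers earlier elements on ties. A hedged host-side check; CmpKVHost is an illustrative plain-C++ copy of the device comparator.

// Hedged check of the tie-break behaviour of CmpKV.
#include <cassert>

template <typename T, typename S>
struct CmpKVHost {
  static bool gt(T k1, S v1, T k2, S v2) { return k1 > k2 || (k1 == k2 && v1 < v2); }
  static bool lt(T k1, S v1, T k2, S v2) { return k1 < k2 || (k1 == k2 && v1 > v2); }
};

int main() {
  // distinct keys: ordered by key alone
  assert((CmpKVHost<float, int>::gt(0.9f, 7, 0.5f, 1)));
  // equal keys: index 3 beats index 8 in the descending ("gt") ordering
  assert((CmpKVHost<float, int>::gt(0.5f, 3, 0.5f, 8)));
  assert(!(CmpKVHost<float, int>::gt(0.5f, 8, 0.5f, 3)));
  return 0;
}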
| template <typename T> | |||
| struct Cmp { | |||
| __device__ static inline bool lt(T a, T b) { return a < b; } | |||
| __device__ static inline bool gt(T a, T b) { return a > b; } | |||
| }; | |||
| template <typename T> | |||
| inline __device__ T shfl_xor(const T val, int laneMask, int width = kWarpSize) { | |||
| return __shfl_xor_sync(0xffffffff, val, laneMask, width); | |||
| } | |||
| template <typename T, typename S, bool is_descend> | |||
| inline __device__ void L2CompareAndSwap(T *a, S *b, int i_1, int i_2) { | |||
| bool swap = | |||
| is_descend ? CmpKV<T, S>::gt(a[i_1], b[i_1], a[i_2], b[i_2]) : CmpKV<T, S>::lt(a[i_1], b[i_1], a[i_2], b[i_2]); | |||
| if (!swap) return; | |||
| T a_tmp = a[i_1]; | |||
| a[i_1] = a[i_2]; | |||
| a[i_2] = a_tmp; | |||
| T b_tmp = b[i_1]; | |||
| b[i_1] = b[i_2]; | |||
| b[i_2] = b_tmp; | |||
| } | |||
| template <typename T> | |||
| inline __device__ void ConditionalAssign(bool is_assign, T *x, const T &y) { | |||
| (*x) = is_assign ? y : (*x); | |||
| } | |||
| // Merge pairs of lists smaller than threads-per-block | |||
| // NumThreads is 128 | |||
| // N is 2, 1 etc | |||
| // L is 32, 64 etc | |||
| template <int NumThreads, typename T, typename S, int N, int L, bool AllThreads, bool is_descend, bool FullMerge> | |||
| inline __device__ void BlockSortSmallK(T *list_k, S *list_v) { | |||
| int mergeId = threadIdx.x / L; | |||
| int tid = threadIdx.x % L; | |||
| list_k += 2 * L * mergeId; | |||
| list_v += 2 * L * mergeId; | |||
| int pos = L - 1 - tid; | |||
| int stride = 2 * tid + 1; | |||
| if (AllThreads || (static_cast<int>(threadIdx.x) < N * L)) { | |||
| L2CompareAndSwap<T, S, is_descend>(list_k, list_v, pos, pos + stride); | |||
| } | |||
| __syncthreads(); | |||
| _Pragma("unroll") for (int stride = L / 2; stride > 0; stride /= 2) { | |||
| int pos = 2 * tid - (tid & (stride - 1)); | |||
| if (AllThreads || (static_cast<int>(threadIdx.x) < N * L)) { | |||
| L2CompareAndSwap<T, S, is_descend>(list_k, list_v, pos, pos + stride); | |||
| } | |||
| __syncthreads(); | |||
| } | |||
| } | |||
| // Merge pairs of lists larger than threads-per-block | |||
| template <int NumThreads, typename T, typename S, int L, bool is_descend, bool FullMerge> | |||
| inline __device__ void BlockSortBigK(T *list_k, S *list_v) { | |||
| constexpr int kLoopPerThread = L / NumThreads; | |||
| _Pragma("unroll") for (int loop = 0; loop < kLoopPerThread; ++loop) { | |||
| int tid = loop * NumThreads + threadIdx.x; | |||
| int pos = L - 1 - tid; | |||
| int stride = 2 * tid + 1; | |||
| L2CompareAndSwap<T, S, is_descend>(list_k, list_v, pos, pos + stride); | |||
| } | |||
| __syncthreads(); | |||
| constexpr int kSecondLoopPerThread = FullMerge ? kLoopPerThread : kLoopPerThread / 2; | |||
| _Pragma("unroll") for (int stride = L / 2; stride > 0; stride /= 2) { | |||
| _Pragma("unroll") for (int loop = 0; loop < kSecondLoopPerThread; ++loop) { | |||
| int tid = loop * NumThreads + threadIdx.x; | |||
| int pos = 2 * tid - (tid & (stride - 1)); | |||
| L2CompareAndSwap<T, S, is_descend>(list_k, list_v, pos, pos + stride); | |||
| } | |||
| __syncthreads(); | |||
| } | |||
| } | |||
| /// Merging lists smaller than threads-per-block | |||
| template <int NumThreads, typename T, typename S, int N, int L, bool is_descend, bool FullMerge = true> | |||
| inline __device__ void SortBlockStep(T *list_k, S *list_v) { | |||
| if (L <= NumThreads) { | |||
| int kNumParallelMerges = NumThreads / L; | |||
| int kNumIterations = N / kNumParallelMerges; | |||
| if (N < kNumParallelMerges) { | |||
| BlockSortSmallK<NumThreads, T, S, N, L, false, is_descend, FullMerge>(list_k, list_v); | |||
| } else { | |||
| _Pragma("unroll") for (int i = 0; i < kNumIterations; ++i) { | |||
| int start = i * kNumParallelMerges * 2 * L; | |||
| BlockSortSmallK<NumThreads, T, S, N, L, true, is_descend, FullMerge>(list_k + start, list_v + start); | |||
| } | |||
| } | |||
| } else { | |||
| _Pragma("unroll") for (int i = 0; i < N; ++i) { | |||
| int start = i * 2 * L; | |||
| BlockSortBigK<NumThreads, T, S, L, is_descend, FullMerge>(list_k + start, list_v + start); | |||
| } | |||
| } | |||
| } | |||
| // Block-wide merge | |||
| template <int NumWarps, int NumThreads, typename T, typename S, int warp_queue, bool is_descend> | |||
| inline __device__ void SortBlockWide(T *shared_K, S *shared_V) { | |||
| if (NumWarps == 2) { | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 2), warp_queue, !is_descend, false>(shared_K, shared_V); | |||
| } else if (NumWarps == 4) { | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 2), warp_queue, !is_descend>(shared_K, shared_V); | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 4), warp_queue * 2, !is_descend, false>(shared_K, | |||
| shared_V); | |||
| } else if (NumWarps == 8) { | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 2), warp_queue, !is_descend>(shared_K, shared_V); | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 4), warp_queue * 2, !is_descend>(shared_K, shared_V); | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 8), warp_queue * 4, !is_descend, false>(shared_K, | |||
| shared_V); | |||
| } else if (NumWarps == 16) { | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 2), warp_queue, !is_descend>(shared_K, shared_V); | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 4), warp_queue * 2, !is_descend>(shared_K, shared_V); | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 8), warp_queue * 4, !is_descend>(shared_K, shared_V); | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 16), warp_queue * 8, !is_descend>(shared_K, shared_V); | |||
| } else if (NumWarps == 32) { | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 2), warp_queue, !is_descend>(shared_K, shared_V); | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 4), warp_queue * 2, !is_descend>(shared_K, shared_V); | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 8), warp_queue * 4, !is_descend>(shared_K, shared_V); | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 16), warp_queue * 8, !is_descend>(shared_K, shared_V); | |||
| SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 32), warp_queue * 16, !is_descend>(shared_K, shared_V); | |||
| } | |||
| } | |||
| template <typename T, typename S, int L, bool is_descend, bool IsBitonic> | |||
| inline __device__ void BitonicSortWarpLE16(T *k, S *v) { | |||
| int laneId = GetLaneId(); | |||
| if (!IsBitonic) { | |||
| // Reverse the first comparison stage. head-tail swap. | |||
| T other_K = shfl_xor((*k), 2 * L - 1); | |||
| S other_V = shfl_xor((*v), 2 * L - 1); | |||
| bool small = !(laneId & L); | |||
| bool small_compare = small ? CmpKV<T, S>::gt((*k), (*v), other_K, other_V) : | |||
| CmpKV<T, S>::lt((*k), (*v), other_K, other_V); | |||
| bool small_compare_descend = is_descend ? small_compare : !small_compare; | |||
| ConditionalAssign(small_compare_descend, k, other_K); | |||
| ConditionalAssign(small_compare_descend, v, other_V); | |||
| } | |||
| _Pragma("unroll") for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) { | |||
| T other_K = shfl_xor((*k), stride); | |||
| S other_V = shfl_xor((*v), stride); | |||
| bool small = !(laneId & stride); | |||
| bool small_compare = small ? CmpKV<T, S>::gt((*k), (*v), other_K, other_V) : | |||
| CmpKV<T, S>::lt((*k), (*v), other_K, other_V); | |||
| bool small_compare_descend = is_descend ? small_compare : !small_compare; | |||
| ConditionalAssign(small_compare_descend, k, other_K); | |||
| ConditionalAssign(small_compare_descend, v, other_V); | |||
| } | |||
| } | |||
| template <typename T, typename S, int N, bool is_descend, bool Low, bool Pow2> | |||
| struct MergeWarpStepBitonic {}; | |||
| // All merges call this | |||
| template <typename T, typename S, bool is_descend, bool Low> | |||
| struct MergeWarpStepBitonic<T, S, 1, is_descend, Low, true> { | |||
| static inline __device__ void merge(T k[1], S v[1]) { BitonicSortWarpLE16<T, S, 16, is_descend, true>(&k[0], &v[0]); } | |||
| }; | |||
| template <typename T, typename S, int N, bool is_descend, bool Low> | |||
| struct MergeWarpStepBitonic<T, S, N, is_descend, Low, true> { | |||
| static inline __device__ void merge(T k[N], S v[N]) { | |||
| _Pragma("unroll") for (int i = 0; i < N / 2; ++i) { L2CompareAndSwap<T, S, is_descend>(k, v, i, i + N / 2); } | |||
| { | |||
| T newK[N / 2]; | |||
| S newV[N / 2]; | |||
| _Pragma("unroll") for (int i = 0; i < N / 2; ++i) { | |||
| newK[i] = k[i]; | |||
| newV[i] = v[i]; | |||
| } | |||
| MergeWarpStepBitonic<T, S, N / 2, is_descend, true, true>::merge(newK, newV); | |||
| _Pragma("unroll") for (int i = 0; i < N / 2; ++i) { | |||
| k[i] = newK[i]; | |||
| v[i] = newV[i]; | |||
| } | |||
| } | |||
| { | |||
| T newK[N / 2]; | |||
| S newV[N / 2]; | |||
| _Pragma("unroll") for (int i = 0; i < N / 2; ++i) { | |||
| newK[i] = k[i + N / 2]; | |||
| newV[i] = v[i + N / 2]; | |||
| } | |||
| MergeWarpStepBitonic<T, S, N / 2, is_descend, false, true>::merge(newK, newV); | |||
| _Pragma("unroll") for (int i = 0; i < N / 2; ++i) { | |||
| k[i + N / 2] = newK[i]; | |||
| v[i + N / 2] = newV[i]; | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| // Low recursion | |||
| template <typename T, typename S, int N, bool is_descend> | |||
| struct MergeWarpStepBitonic<T, S, N, is_descend, true, false> { | |||
| static inline __device__ void merge(T k[N], S v[N]) { | |||
| constexpr int kNextHighestPowerOf2 = NextPow2(N); | |||
| _Pragma("unroll") for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { | |||
| L2CompareAndSwap<T, S, is_descend>(k, v, i, i + kNextHighestPowerOf2 / 2); | |||
| } | |||
| constexpr int kLowSize = N - kNextHighestPowerOf2 / 2; | |||
| constexpr int kHighSize = kNextHighestPowerOf2 / 2; | |||
| { | |||
| T newK[kLowSize]; | |||
| S newV[kLowSize]; | |||
| _Pragma("unroll") for (int i = 0; i < kLowSize; ++i) { | |||
| newK[i] = k[i]; | |||
| newV[i] = v[i]; | |||
| } | |||
| constexpr bool kLowIsPowerOf2 = IsPow2(N - kNextHighestPowerOf2 / 2); | |||
| MergeWarpStepBitonic<T, S, kLowSize, is_descend, true, kLowIsPowerOf2>::merge(newK, newV); | |||
| _Pragma("unroll") for (int i = 0; i < kLowSize; ++i) { | |||
| k[i] = newK[i]; | |||
| v[i] = newV[i]; | |||
| } | |||
| } | |||
| { | |||
| T newK[kHighSize]; | |||
| S newV[kHighSize]; | |||
| _Pragma("unroll") for (int i = 0; i < kHighSize; ++i) { | |||
| newK[i] = k[i + kLowSize]; | |||
| newV[i] = v[i + kLowSize]; | |||
| } | |||
| constexpr bool kHighIsPowerOf2 = IsPow2(kNextHighestPowerOf2 / 2); | |||
| MergeWarpStepBitonic<T, S, kHighSize, is_descend, false, kHighIsPowerOf2>::merge(newK, newV); | |||
| _Pragma("unroll") for (int i = 0; i < kHighSize; ++i) { | |||
| k[i + kLowSize] = newK[i]; | |||
| v[i + kLowSize] = newV[i]; | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| // High recursion | |||
| template <typename T, typename S, int N, bool is_descend> | |||
| struct MergeWarpStepBitonic<T, S, N, is_descend, false, false> { | |||
| static inline __device__ void merge(T k[N], S v[N]) { | |||
| constexpr int kNextHighestPowerOf2 = NextPow2(N); | |||
| _Pragma("unroll") for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { | |||
| L2CompareAndSwap<T, S, is_descend>(k, v, i, i + kNextHighestPowerOf2 / 2); | |||
| } | |||
| constexpr int kLowSize = kNextHighestPowerOf2 / 2; | |||
| constexpr int kHighSize = N - kNextHighestPowerOf2 / 2; | |||
| { | |||
| T newK[kLowSize]; | |||
| S newV[kLowSize]; | |||
| _Pragma("unroll") for (int i = 0; i < kLowSize; ++i) { | |||
| newK[i] = k[i]; | |||
| newV[i] = v[i]; | |||
| } | |||
| constexpr bool kLowIsPowerOf2 = IsPow2(kNextHighestPowerOf2 / 2); | |||
| MergeWarpStepBitonic<T, S, kLowSize, is_descend, true, kLowIsPowerOf2>::merge(newK, newV); | |||
| _Pragma("unroll") for (int i = 0; i < kLowSize; ++i) { | |||
| k[i] = newK[i]; | |||
| v[i] = newV[i]; | |||
| } | |||
| } | |||
| { | |||
| T newK[kHighSize]; | |||
| S newV[kHighSize]; | |||
| _Pragma("unroll") for (int i = 0; i < kHighSize; ++i) { | |||
| newK[i] = k[i + kLowSize]; | |||
| newV[i] = v[i + kLowSize]; | |||
| } | |||
| constexpr bool kHighIsPowerOf2 = IsPow2(N - kNextHighestPowerOf2 / 2); | |||
| MergeWarpStepBitonic<T, S, kHighSize, is_descend, false, kHighIsPowerOf2>::merge(newK, newV); | |||
| _Pragma("unroll") for (int i = 0; i < kHighSize; ++i) { | |||
| k[i + kLowSize] = newK[i]; | |||
| v[i + kLowSize] = newV[i]; | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| /// Merges two sets of registers across the warp, for lists of any size | |||
| template <typename T, typename S, int N1, int N2, bool is_descend, bool FullMerge = true> | |||
| inline __device__ void MergeWarpByRegister(T k1[N1], S v1[N1], T k2[N2], S v2[N2]) { | |||
| constexpr int kSmallestN = N1 < N2 ? N1 : N2; | |||
| _Pragma("unroll") for (int i = 0; i < kSmallestN; ++i) { | |||
| T &ka = k1[N1 - 1 - i]; | |||
| S &va = v1[N1 - 1 - i]; | |||
| T &kb = k2[i]; | |||
| S &vb = v2[i]; | |||
| T other_Ka; | |||
| S other_Va; | |||
| if (FullMerge) { | |||
| other_Ka = shfl_xor(ka, kWarpSize - 1); | |||
| other_Va = shfl_xor(va, kWarpSize - 1); | |||
| } | |||
| T other_Kb = shfl_xor(kb, kWarpSize - 1); | |||
| S other_Vb = shfl_xor(vb, kWarpSize - 1); | |||
| bool swapa = is_descend ? CmpKV<T, S>::gt(ka, va, other_Kb, other_Vb) : CmpKV<T, S>::lt(ka, va, other_Kb, other_Vb); | |||
| ConditionalAssign(swapa, &ka, other_Kb); | |||
| ConditionalAssign(swapa, &va, other_Vb); | |||
| if (FullMerge) { | |||
| bool swapb = is_descend ? CmpKV<T, S>::lt(kb, vb, other_Ka, other_Va) : | |||
| CmpKV<T, S>::gt(kb, vb, other_Ka, other_Va); | |||
| ConditionalAssign(swapb, &kb, other_Ka); | |||
| ConditionalAssign(swapb, &vb, other_Va); | |||
| } | |||
| } | |||
| MergeWarpStepBitonic<T, S, N1, is_descend, true, IsPow2(N1)>::merge(k1, v1); | |||
| if (FullMerge) { | |||
| MergeWarpStepBitonic<T, S, N2, is_descend, false, IsPow2(N2)>::merge(k2, v2); | |||
| } | |||
| } | |||
| // Recursive template that uses the above bitonic merge | |||
| template <typename T, typename S, int N, bool is_descend> | |||
| struct SortWarpStepBitonic { | |||
| static inline __device__ void Sort(T k[N], S v[N]) { | |||
| constexpr int kSizeA = N / 2; | |||
| constexpr int kSizeB = N - kSizeA; | |||
| T aK[kSizeA]; | |||
| S aV[kSizeA]; | |||
| _Pragma("unroll") for (int i = 0; i < kSizeA; ++i) { | |||
| aK[i] = k[i]; | |||
| aV[i] = v[i]; | |||
| } | |||
| // Recursive sort | |||
| SortWarpStepBitonic<T, S, kSizeA, is_descend>::Sort(aK, aV); | |||
| T bK[kSizeB]; | |||
| S bV[kSizeB]; | |||
| _Pragma("unroll") for (int i = 0; i < kSizeB; ++i) { | |||
| bK[i] = k[i + kSizeA]; | |||
| bV[i] = v[i + kSizeA]; | |||
| } | |||
| SortWarpStepBitonic<T, S, kSizeB, is_descend>::Sort(bK, bV); | |||
| // Merge halves | |||
| MergeWarpByRegister<T, S, kSizeA, kSizeB, is_descend>(aK, aV, bK, bV); | |||
| _Pragma("unroll") for (int i = 0; i < kSizeA; ++i) { | |||
| k[i] = aK[i]; | |||
| v[i] = aV[i]; | |||
| } | |||
| _Pragma("unroll") for (int i = 0; i < kSizeB; ++i) { | |||
| k[i + kSizeA] = bK[i]; | |||
| v[i + kSizeA] = bV[i]; | |||
| } | |||
| } | |||
| }; | |||
| template <typename T, typename S, bool is_descend> | |||
| struct SortWarpStepBitonic<T, S, 1, is_descend> { | |||
| static inline __device__ void Sort(T k[1], S v[1]) { | |||
| // up to warp-size/2 | |||
| BitonicSortWarpLE16<T, S, 1, is_descend, false>(&k[0], &v[0]); | |||
| BitonicSortWarpLE16<T, S, 2, is_descend, false>(&k[0], &v[0]); | |||
| BitonicSortWarpLE16<T, S, 4, is_descend, false>(&k[0], &v[0]); | |||
| BitonicSortWarpLE16<T, S, 8, is_descend, false>(&k[0], &v[0]); | |||
| BitonicSortWarpLE16<T, S, 16, is_descend, false>(&k[0], &v[0]); | |||
| } | |||
| }; | |||
| template <typename T, typename S, int N, bool is_descend> | |||
| inline __device__ void SortWarpByRegister(T k[N], S v[N]) { | |||
| SortWarpStepBitonic<T, S, N, is_descend>::Sort(k, v); | |||
| } | |||
| template <typename T, typename S, int warp_queue, int thread_queue, bool is_descend> | |||
| inline __device__ void MergeWarpQueue(T *threadK, S *threadV, T *warp_K, S *warp_V) { | |||
| int laneId = GetLaneId(); | |||
| SortWarpByRegister<T, S, thread_queue, !is_descend>(threadK, threadV); | |||
| constexpr int kWarpQueueRegisters = warp_queue / kWarpSize; | |||
| T warp_KRegisters[kWarpQueueRegisters]; | |||
| S warp_VRegisters[kWarpQueueRegisters]; | |||
| _Pragma("unroll") for (int i = 0; i < kWarpQueueRegisters; ++i) { | |||
| warp_KRegisters[i] = warp_K[i * kWarpSize + laneId]; | |||
| warp_VRegisters[i] = warp_V[i * kWarpSize + laneId]; | |||
| } | |||
| __syncwarp(); | |||
| MergeWarpByRegister<T, S, kWarpQueueRegisters, thread_queue, !is_descend, false>(warp_KRegisters, warp_VRegisters, | |||
| threadK, threadV); | |||
| _Pragma("unroll") for (int i = 0; i < kWarpQueueRegisters; ++i) { | |||
| warp_K[i * kWarpSize + laneId] = warp_KRegisters[i]; | |||
| warp_V[i * kWarpSize + laneId] = warp_VRegisters[i]; | |||
| } | |||
| __syncwarp(); | |||
| } | |||
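MergeWarpQueue views each warp's slice of shared memory as warp_queue / kWarpSize registers per lane, laid out lane-strided: register i of lane `lane` aliases shared slot i * kWarpSize + lane. A hedged sketch of that index mapping; SharedSlot and the warp_queue value chosen are illustrative.

// Hedged sketch of the lane-strided warp-queue layout assumed by MergeWarpQueue.
#include <cassert>

constexpr int kWarpSize = 32;

int SharedSlot(int reg, int lane) { return reg * kWarpSize + lane; }

int main() {
  constexpr int warp_queue = 128;                        // one of the configs used above
  constexpr int regs_per_lane = warp_queue / kWarpSize;  // kWarpQueueRegisters == 4
  assert(regs_per_lane == 4);
  assert(SharedSlot(0, 0) == 0);
  assert(SharedSlot(3, 31) == warp_queue - 1);           // last register of last lane
  // every shared slot is owned by exactly one (register, lane) pair
  bool seen[warp_queue] = {false};
  for (int r = 0; r < regs_per_lane; ++r)
    for (int lane = 0; lane < kWarpSize; ++lane) seen[SharedSlot(r, lane)] = true;
  for (int s = 0; s < warp_queue; ++s) assert(seen[s]);
  return 0;
}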
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -39,17 +39,22 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| S *output_index = GetDeviceAddress<S>(outputs, 0); | |||
| T *output_mask = GetDeviceAddress<T>(outputs, 1); | |||
| S *index_buff = GetDeviceAddress<S>(workspaces, 0); | |||
| S *mask_buff = GetDeviceAddress<S>(workspaces, 1); | |||
| S *rank_buff = GetDeviceAddress<S>(workspaces, 2); | |||
| S *Tnum_buff = GetDeviceAddress<S>(workspaces, 3); | |||
| S *tmp_buff = GetDeviceAddress<S>(workspaces, 4); | |||
| void *States = GetDeviceAddress<void *>(workspaces, 5); | |||
| curandState *devStates = reinterpret_cast<curandState *>(States); | |||
| CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], input_shape_5D_[2], | |||
| input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, output_index, output_mask, | |||
| index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, devStates, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| if (count_ > kSmallK || input_shape_size_ > 1) { | |||
| S *index_buff = GetDeviceAddress<S>(workspaces, 0); | |||
| S *mask_buff = GetDeviceAddress<S>(workspaces, 1); | |||
| S *rank_buff = GetDeviceAddress<S>(workspaces, 2); | |||
| S *Tnum_buff = GetDeviceAddress<S>(workspaces, 3); | |||
| S *tmp_buff = GetDeviceAddress<S>(workspaces, 4); | |||
| void *States = GetDeviceAddress<void *>(workspaces, 5); | |||
| curandState *devStates = reinterpret_cast<curandState *>(States); | |||
| CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], | |||
| input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, | |||
| output_index, output_mask, index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, | |||
| devStates, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc_, count_, input, output_index, output_mask, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
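The branch above routes to the new single-block warp-sort path only for 1-D inputs with count at most kSmallK (2048), which matches the largest L2_RCWM_HELPER configuration; everything else keeps the original multi-kernel implementation. A hedged restatement of that predicate; UsesGenericPath is an illustrative helper name, not part of the kernel.

// Hedged sketch of the small-path vs. generic-path dispatch condition.
#include <cassert>

constexpr int kSmallK = 2048;

bool UsesGenericPath(int count, int input_shape_size) {
  return count > kSmallK || input_shape_size > 1;
}

int main() {
  assert(!UsesGenericPath(10, 1));   // 1-D input, small count: fast warp-sort path
  assert(UsesGenericPath(10, 2));    // multi-dimensional input: generic path
  assert(UsesGenericPath(4096, 1));  // count above kSmallK: generic path
  return 0;
}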
| @@ -94,7 +99,9 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||
| } | |||
| count_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "count")); | |||
| // round the input size up to the next power of two for ceil_power2_ | |||
| ceil_power2_ = RcwmRoundUpPower2(input_size_); | |||
| if (count_ > kSmallK || input_shape_size_ > 1) { | |||
| ceil_power2_ = RcwmRoundUpPower2(input_size_); | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| @@ -104,16 +111,19 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| output_size_list_.push_back(count_ * input_shape_size_ * sizeof(S)); | |||
| output_size_list_.push_back(count_ * sizeof(T)); | |||
| workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S)); | |||
| workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); | |||
| workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); | |||
| int blocknum = std::ceil(static_cast<float>(ceil_power2_) / BLOCKSIZE); | |||
| workspace_size_list_.push_back(blocknum * sizeof(S)); | |||
| workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); | |||
| workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState)); | |||
| if (count_ > kSmallK || input_shape_size_ > 1) { | |||
| workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S)); | |||
| workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); | |||
| workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); | |||
| int blocknum = std::ceil(static_cast<float>(ceil_power2_) / BLOCKSIZE); | |||
| workspace_size_list_.push_back(blocknum * sizeof(S)); | |||
| workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); | |||
| workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState)); | |||
| } | |||
| } | |||
| private: | |||
| const int kSmallK = 2048; | |||
| int input_shape_size_; | |||
| int seedc_; | |||
| int input_size_; | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -21,6 +21,7 @@ import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| class RCWM_count_in(nn.Cell): | |||
| def __init__(self): | |||
| super(RCWM_count_in, self).__init__() | |||
| @@ -29,6 +30,7 @@ class RCWM_count_in(nn.Cell): | |||
| def construct(self, x): | |||
| return self.RCWM_count_in(x) | |||
| class RCWM_count_out(nn.Cell): | |||
| def __init__(self): | |||
| super(RCWM_count_out, self).__init__() | |||
| @@ -37,6 +39,7 @@ class RCWM_count_out(nn.Cell): | |||
| def construct(self, x): | |||
| return self.RCWM_count_out(x) | |||
| class RCWM_3D(nn.Cell): | |||
| def __init__(self): | |||
| super(RCWM_3D, self).__init__() | |||
| @@ -45,6 +48,16 @@ class RCWM_3D(nn.Cell): | |||
| def construct(self, x): | |||
| return self.RCWM_3D(x) | |||
| class RCWM_1D(nn.Cell): | |||
| def __init__(self): | |||
| super(RCWM_1D, self).__init__() | |||
| self.RCWM_1D = P.RandomChoiceWithMask(count=10, seed=9) | |||
| def construct(self, x): | |||
| return self.RCWM_1D(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @@ -58,12 +71,14 @@ def test_RCWM_3D(): | |||
| assert output1.shape == expect1 | |||
| assert output2.shape == expect2 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_RCWM_count_out(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) | |||
| input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], | |||
| [0, 0, 0, 1]]).astype(np.bool)) | |||
| expect1 = (10, 2) | |||
| expect2 = (10,) | |||
| rcwm = RCWM_count_out() | |||
| @@ -71,15 +86,36 @@ def test_RCWM_count_out(): | |||
| assert output1.shape == expect1 | |||
| assert output2.shape == expect2 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_RCWM_count_in(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) | |||
| input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], | |||
| [0, 0, 0, 1]]).astype(np.bool)) | |||
| expect1 = (4, 2) | |||
| expect2 = (4,) | |||
| rcwm = RCWM_count_in() | |||
| output1, output2 = rcwm(input_tensor) | |||
| assert output1.shape == expect1 | |||
| assert output2.shape == expect2 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_RCWM_1D(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| input_tensor = Tensor( | |||
| np.array([1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]).astype(np.bool)) | |||
| expect_index = np.array([[11], [9], [2], [15], [10], [7], | |||
| [8], [0], [0], [0]]).astype(np.int32) | |||
| expect_mask = np.array( | |||
| [True, True, True, True, True, True, True, True, False, False]) | |||
| rcwm = RCWM_1D() | |||
| output1, output2 = rcwm(input_tensor) | |||
| print(output1.asnumpy()) | |||
| print(output2) | |||
| assert np.array_equal(output1.asnumpy(), expect_index) | |||
| assert np.array_equal(output2.asnumpy(), expect_mask) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -24,7 +24,7 @@ from mindspore.ops import operations as P | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_topk(): | |||
| def test_topk_small_2d(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x_np = np.random.rand(3, 4).astype(np.float32) | |||
| @@ -36,7 +36,20 @@ def test_topk(): | |||
| x_np = np.random.rand(3, 4).astype(np.float32) | |||
| k = 4 | |||
| ms_output = P.TopK(False)(Tensor(x_np), k) | |||
| assert np.allclose(ms_output[0].asnumpy(), x_np) | |||
| np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_topk_3d(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x_np = np.random.rand(2, 256, 128).astype(np.float32) | |||
| k = 4 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| x_np = np.random.rand(2, 3, 4).astype(np.float32) | |||
| k = 2 | |||
| @@ -44,6 +57,12 @@ def test_topk(): | |||
| np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_topk_big_2d(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x_np = np.random.rand(512, 1024).astype(np.float32) | |||
| k = 512 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| @@ -51,32 +70,69 @@ def test_topk(): | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| # number of sorted elements exceeds max threads per block | |||
| x_np = np.random.rand(512, 2048).astype(np.float32) | |||
| x_np = np.random.rand(128, 2048).astype(np.float32) | |||
| k = 1 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| x_np = np.random.rand(512, 2048).astype(np.float32) | |||
| x_np = np.random.rand(32, 2048).astype(np.float32) | |||
| k = 2048 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| # number of sorted elements exceeds max shared memory per block | |||
| x_np = np.random.rand(512, 40960).astype(np.float32) | |||
| x_np = np.random.rand(16, 40960).astype(np.float32) | |||
| k = 1 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| x_np = np.random.rand(512, 40960).astype(np.float32) | |||
| k = 40960 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_topk_big_k(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x_np = np.random.rand(8, 40960).astype(np.float32) | |||
| k = 4096 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| x_np = np.random.rand(512, 40960).astype(np.float32) | |||
| k = 40960 | |||
| ms_output = P.TopK(False)(Tensor(x_np), k) | |||
| assert np.allclose(ms_output[0].asnumpy(), x_np) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_topk_1d(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x_np = np.random.rand(12).astype(np.float32) | |||
| k = 4 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| np_output = np.sort(x_np)[::-1][0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| x_np = np.random.rand(1200).astype(np.float32) | |||
| k = 256 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| np_output = np.sort(x_np)[::-1][0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| x_np = np.random.rand(250000).astype(np.float32) | |||
| k = 2000 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| np_output = np.sort(x_np)[::-1][0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| x_np = np.random.rand(10240).astype(np.float32) | |||
| k = 4096 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| np_output = np.sort(x_np)[::-1][0:k] | |||
| assert np.allclose(ms_output[0].asnumpy(), np_output) | |||
| x_np = np.random.rand(720).astype(np.float32) | |||
| k = 720 | |||
| ms_output = P.TopK(True)(Tensor(x_np), k) | |||
| np_output = np.sort(x_np)[::-1][0:k] | |||
| assert np.allclose(ms_output[0].asnumpy()[:k], np_output) | |||