Merge pull request !3045 from chenweifeng/sort (tag: v0.6.0-beta)
@@ -44,7 +44,7 @@ if(ENABLE_GPU)
         "backend/kernel_compiler/akg/akg_kernel_attrs_process.cc"
         )
-    list(APPEND CUDA_NVCC_FLAGS -arch=sm_53)
+    list(APPEND CUDA_NVCC_FLAGS -arch=sm_53 --expt-relaxed-constexpr)
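+    # --expt-relaxed-constexpr lets device code call constexpr host functions,
+    # e.g. std::numeric_limits<T>::max() in topk_impl.cu below.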
     list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/blocking_queue.cc" "runtime/device/gpu/gpu_buffer_mgr.cc")
     list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/mpi/mpi_initializer.cc"
                                   "runtime/device/gpu/distribution/collective_wrapper.cc"
@@ -0,0 +1,29 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(TopK,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeInt32),
                      TopKGpuKernel, float, int)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,110 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_

#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class TopKGpuKernel : public GpuKernel {
 public:
  TopKGpuKernel() : sorted_(false), outer_size_(1), inner_size_(1), k_(1), use_share_mem_(true), ceil_power2_(0) {}
  ~TopKGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input_addr = GetDeviceAddress<T>(inputs, 0);
    S *k = GetDeviceAddress<S>(inputs, 1);
    T *output_addr = GetDeviceAddress<T>(outputs, 0);
    S *indices = GetDeviceAddress<S>(outputs, 1);
    T *data_buff = nullptr;
    S *index_buff = nullptr;
    if (!use_share_mem_) {
      data_buff = GetDeviceAddress<T>(workspaces, 0);
      index_buff = GetDeviceAddress<S>(workspaces, 1);
    }

    TopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, data_buff, index_buff,
         reinterpret_cast<cudaStream_t>(stream_ptr));
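    // TopK emits the selected values in descending order; when sorted is false,
    // re-sort the selected (value, index) pairs by index to restore their
    // original input order.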
    if (!sorted_) {
      BitonicSortByKey(outer_size_, k_, output_addr, indices, data_buff, index_buff,
                       reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    for (size_t i = 0; i < input_shapes.size() - 1; i++) {
      outer_size_ *= input_shapes[i];
    }
    inner_size_ = input_shapes[input_shapes.size() - 1];
    k_ = output_shapes[output_shapes.size() - 1];
    sorted_ = GetAttr<bool>(kernel_node, "sorted");
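    // Pad each row to the next power of two so the bitonic network in
    // topk_impl.cu is well-formed. If one padded row of values plus indices
    // does not fit in a block's shared memory, fall back to global-memory
    // workspaces.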
    ceil_power2_ = RoundUpPower2(inner_size_);
    size_t buffer_size = ceil_power2_ * (sizeof(T) + sizeof(S));
    if (buffer_size > SHARED_MEM_PER_BLOCK) {
      use_share_mem_ = false;
      MS_LOG(WARNING) << "CUDA shared memory is insufficient, sorting in global memory instead";
    }
    InitSizeLists();
    return true;
  }
 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(outer_size_ * inner_size_ * sizeof(T));
    input_size_list_.push_back(sizeof(S));
    output_size_list_.push_back(outer_size_ * k_ * sizeof(T));
    output_size_list_.push_back(outer_size_ * k_ * sizeof(S));
    if (!use_share_mem_) {
      workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(T));
      workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(S));
    }
  }

 private:
  bool sorted_;
  int outer_size_;
  int inner_size_;
  int k_;
  bool use_share_mem_;
  int ceil_power2_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_
@@ -0,0 +1,162 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
#include <limits>
#include <algorithm>
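// Round v up to the next power of two: after v--, the OR cascade smears the
// highest set bit into every lower position, and v++ carries into a single
// bit. Assumes v >= 1 and fits in 32 bits.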
int RoundUpPower2(int v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v++;
  return v;
}

template <typename T>
__inline__ __device__ void Swap(T *lhs, T *rhs) {
  T tmp = lhs[0];
  lhs[0] = rhs[0];
  rhs[0] = tmp;
}
template <typename T, typename S>
__global__ void TopkKernel(const int outer, const int inner, const int ceil_power2, const T *input, const S *k,
                           T *output, S *indices, T *data_buff, S *index_buff) {
  // default: sort in shared memory
  extern __shared__ T share_mem[];
  T *data_arr = share_mem;
  S *index_arr = reinterpret_cast<S *>(data_arr + ceil_power2);
  // sort in global memory
  if (data_buff != nullptr && index_buff != nullptr) {
    data_arr = data_buff + blockIdx.x * ceil_power2;
    index_arr = index_buff + blockIdx.x * ceil_power2;
  }

  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
    data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max();
    index_arr[i] = i;
  }
  __syncthreads();
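  // Bitonic sort, ascending: stage i doubles the length of the sorted runs,
  // substage j compare-exchanges elements whose ids differ in bit j, and
  // (tid & i) picks the direction so adjacent runs merge into bitonic sequences.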
  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
    for (size_t j = (i >> 1); j > 0; j >>= 1) {
      for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
        size_t tid_comp = tid ^ j;
        if (tid_comp > tid) {
          if ((tid & i) == 0) {
            if (data_arr[tid] > data_arr[tid_comp]) {
              Swap(&data_arr[tid], &data_arr[tid_comp]);
              Swap(&index_arr[tid], &index_arr[tid_comp]);
            }
          } else {
            if (data_arr[tid] < data_arr[tid_comp]) {
              Swap(&data_arr[tid], &data_arr[tid_comp]);
              Swap(&index_arr[tid], &index_arr[tid_comp]);
            }
          }
        }
      }
      __syncthreads();
    }
  }
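  // After the ascending sort the max() padding occupies [inner, ceil_power2),
  // so the k largest real elements sit at indices inner-1 .. inner-k;
  // emit them in descending order.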
  for (size_t tid = threadIdx.x; tid < k[0]; tid += blockDim.x) {
    output[blockIdx.x * k[0] + tid] = data_arr[inner - tid - 1];
    indices[blockIdx.x * k[0] + tid] = index_arr[inner - tid - 1];
  }
}
template <typename T, typename S>
void TopK(const int &outer, const int &inner, const T *input, const S *k, T *output, S *indices, T *data_buff,
          S *index_buff, cudaStream_t stream) {
  int ceil_power2 = RoundUpPower2(inner);
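  // A null data_buff means the caller sized everything for shared memory, so
  // request dynamic shared memory for one padded row of values plus indices.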
  int share_mem = (data_buff == nullptr) ? ceil_power2 * (sizeof(T) + sizeof(S)) : 0;
  int thread = std::min(ceil_power2, GET_THREADS);
  TopkKernel<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, input, k, output, indices, data_buff,
                                                   index_buff);
}
template <typename T, typename S>
__global__ void BitonicSortByKeyKernel(const int outer, const int inner, const int ceil_power2, T *input,
                                       S *indices, T *data_buff, S *index_buff) {
  // default: sort in shared memory
  extern __shared__ T share_mem[];
  T *data_arr = share_mem;
  S *index_arr = reinterpret_cast<S *>(data_arr + ceil_power2);
  // sort in global memory
  if (data_buff != nullptr && index_buff != nullptr) {
    data_arr = data_buff + blockIdx.x * ceil_power2;
    index_arr = index_buff + blockIdx.x * ceil_power2;
  }

  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
    data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max();
    index_arr[i] = (i < inner) ? indices[blockIdx.x * inner + i] : std::numeric_limits<S>::max();
  }
  __syncthreads();
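  // Same bitonic network as TopkKernel, but keyed on the indices so the
  // selected pairs return to their original input order.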
  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
    for (size_t j = (i >> 1); j > 0; j >>= 1) {
      for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
        size_t tid_comp = tid ^ j;
        if (tid_comp > tid) {
          if ((tid & i) == 0) {
            if (index_arr[tid] > index_arr[tid_comp]) {
              Swap(&data_arr[tid], &data_arr[tid_comp]);
              Swap(&index_arr[tid], &index_arr[tid_comp]);
            }
          } else {
            if (index_arr[tid] < index_arr[tid_comp]) {
              Swap(&data_arr[tid], &data_arr[tid_comp]);
              Swap(&index_arr[tid], &index_arr[tid_comp]);
            }
          }
        }
      }
      __syncthreads();
    }
  }

  for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) {
    input[blockIdx.x * inner + tid] = data_arr[tid];
    indices[blockIdx.x * inner + tid] = index_arr[tid];
  }
}
template <typename T, typename S>
void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices, T *data_buff, S *index_buff,
                      cudaStream_t stream) {
  int ceil_power2 = RoundUpPower2(inner);
  size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
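  // Prefer shared memory when one padded row fits; otherwise keep the caller's
  // global buffers and launch without dynamic shared memory.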
  if (share_mem > SHARED_MEM_PER_BLOCK) {
    share_mem = 0;
  } else {
    data_buff = nullptr;
    index_buff = nullptr;
  }
  int thread = std::min(ceil_power2, GET_THREADS);
  BitonicSortByKeyKernel<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, input, indices, data_buff,
                                                               index_buff);
}

template void TopK(const int &outer, const int &inner, const float *input_addr, const int *k, float *output,
                   int *indices, float *data_buff, int *index_buff, cudaStream_t stream);
template void BitonicSortByKey(const int &outer, const int &inner, float *input, int *indices, float *data_buff,
                               int *index_buff, cudaStream_t stream);
@@ -0,0 +1,32 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_

#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"

template <typename T, typename S>
void TopK(const int &outer, const int &inner, const T *input_addr, const S *k, T *output, S *indices, T *data_buff,
          S *index_buff, cudaStream_t stream);

template <typename T, typename S>
void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices, T *data_buff, S *index_buff,
                      cudaStream_t stream);

int RoundUpPower2(int v);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
@@ -30,6 +30,7 @@ class CudaCommon {
   inline int blocks_num(const int total_threads) const {
     return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_);
   }
+  size_t share_memory_size() const { return max_share_memory_; }
   static CudaCommon &GetInstance() {
     static CudaCommon instance;
@@ -44,6 +45,7 @@ class CudaCommon {
     threads_per_block_ = prop.maxThreadsPerBlock;
     max_blocks_ = prop.multiProcessorCount;
     major_sm_ = prop.major;
+    max_share_memory_ = prop.sharedMemPerBlock;
   }
   ~CudaCommon() = default;
   CudaCommon(const CudaCommon &) = delete;
@@ -52,10 +54,12 @@ class CudaCommon {
   int max_blocks_;
   int threads_per_block_;
   int major_sm_;
+  size_t max_share_memory_;
 };
 #define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads)
 #define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num()
 #define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm()
+#define SHARED_MEM_PER_BLOCK mindspore::device::gpu::CudaCommon::GetInstance().share_memory_size()
 #define MINIUM_SM 6
 #define RECOMMEND_SM 7
 } // namespace gpu
@@ -0,0 +1,82 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_topk():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    x_np = np.random.rand(3, 4).astype(np.float32)
    k = 4
    ms_output = P.TopK(True)(Tensor(x_np), k)
    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
    assert np.allclose(ms_output[0].asnumpy(), np_output)
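
    # with sorted=False and k equal to the row length, the kernel restores the
    # original input order, so the output should equal the input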
    x_np = np.random.rand(3, 4).astype(np.float32)
    k = 4
    ms_output = P.TopK(False)(Tensor(x_np), k)
    assert np.allclose(ms_output[0].asnumpy(), x_np)

    x_np = np.random.rand(2, 3, 4).astype(np.float32)
    k = 2
    ms_output = P.TopK(True)(Tensor(x_np), k)
    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
    assert np.allclose(ms_output[0].asnumpy(), np_output)

    x_np = np.random.rand(512, 1024).astype(np.float32)
    k = 512
    ms_output = P.TopK(True)(Tensor(x_np), k)
    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
    assert np.allclose(ms_output[0].asnumpy(), np_output)

    # row length greater than the max threads per block
    x_np = np.random.rand(512, 2048).astype(np.float32)
    k = 1
    ms_output = P.TopK(True)(Tensor(x_np), k)
    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
    assert np.allclose(ms_output[0].asnumpy(), np_output)

    x_np = np.random.rand(512, 2048).astype(np.float32)
    k = 2048
    ms_output = P.TopK(True)(Tensor(x_np), k)
    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
    assert np.allclose(ms_output[0].asnumpy(), np_output)

    # padded row buffer larger than the shared memory per block
    x_np = np.random.rand(512, 40960).astype(np.float32)
    k = 1
    ms_output = P.TopK(True)(Tensor(x_np), k)
    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
    assert np.allclose(ms_output[0].asnumpy(), np_output)

    x_np = np.random.rand(512, 40960).astype(np.float32)
    k = 40960
    ms_output = P.TopK(True)(Tensor(x_np), k)
    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
    assert np.allclose(ms_output[0].asnumpy(), np_output)

    x_np = np.random.rand(512, 40960).astype(np.float32)
    k = 40960
    ms_output = P.TopK(False)(Tensor(x_np), k)
    assert np.allclose(ms_output[0].asnumpy(), x_np)