From: @tom__chen Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/arrays/unique_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Unique, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), | |||
| UniqueGpuKernel, float, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Unique, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), | |||
| UniqueGpuKernel, half, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Unique, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| UniqueGpuKernel, int, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,104 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNIQUEGPUKERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNIQUEGPUKERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/unique_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T, typename S> | |||
| class UniqueGpuKernel : public GpuKernel { | |||
| public: | |||
| UniqueGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), num_elements_(1), post_output_size_(0) {} | |||
| ~UniqueGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| S *input_index = GetDeviceAddress<S>(workspace, 0); | |||
| S *sorted_index = GetDeviceAddress<S>(workspace, 1); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| S *index = GetDeviceAddress<S>(outputs, 1); | |||
| stream_ptr_ = stream_ptr; | |||
| post_output_size_ = CalUnique(input, num_elements_, input_index, sorted_index, output, index, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (auto x : shape) { | |||
| num_elements_ *= x; | |||
| } | |||
| input_size_ = num_elements_ * sizeof(T); | |||
| output_size_ = input_size_; | |||
| workspace_size_ = num_elements_ * sizeof(S); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void PostExecute() override { | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)), | |||
| "cudaStreamSynchronized failed"); | |||
| std::vector<TypeId> type_ids; | |||
| std::vector<std::vector<size_t>> shapes; | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node_); | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(kernel_node_, i); | |||
| if (i == 0) { | |||
| shape[0] = post_output_size_; | |||
| } | |||
| TypeId type_id = AnfAlgo::GetOutputInferDataType(kernel_node_, i); | |||
| type_ids.emplace_back(type_id); | |||
| shapes.emplace_back(shape); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, kernel_node_.get()); | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| output_size_list_.push_back(num_elements_ * sizeof(S)); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| } | |||
| private: | |||
| void *stream_ptr_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| int num_elements_; | |||
| int post_output_size_; | |||
| CNodePtr kernel_node_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNIQUEGPUKERNEL_H_ | |||
| @@ -0,0 +1,74 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <thrust/adjacent_difference.h> | |||
| #include <thrust/copy.h> | |||
| #include <thrust/device_ptr.h> | |||
| #include <thrust/execution_policy.h> | |||
| #include <thrust/sequence.h> | |||
| #include <thrust/sort.h> | |||
| #include <thrust/unique.h> | |||
| #include <algorithm> | |||
| #include "unique_impl.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| #include "include/cuda_fp16.h" | |||
| template <typename T, typename S> | |||
| int CalUnique(const T *input, int num_elements, S *input_index, S *sorted_index, T *output, S *index, | |||
| cudaStream_t cuda_stream) { | |||
| auto policy = thrust::cuda::par.on(cuda_stream); | |||
| thrust::sequence(policy, | |||
| thrust::device_pointer_cast(sorted_index), | |||
| thrust::device_pointer_cast(sorted_index) + num_elements); | |||
| thrust::copy(thrust::device_pointer_cast(input), | |||
| thrust::device_pointer_cast(input) + num_elements, | |||
| thrust::device_pointer_cast(output)); | |||
| thrust::stable_sort_by_key(policy, | |||
| thrust::device_pointer_cast(output), | |||
| thrust::device_pointer_cast(output) + num_elements, | |||
| thrust::device_pointer_cast(sorted_index)); | |||
| thrust::adjacent_difference(policy, | |||
| thrust::device_pointer_cast(output), | |||
| thrust::device_pointer_cast(output) + num_elements, | |||
| thrust::device_pointer_cast(input_index), | |||
| thrust::not_equal_to<T>()); | |||
| thrust::fill(policy, | |||
| thrust::device_pointer_cast(input_index), | |||
| thrust::device_pointer_cast(input_index) + 1, | |||
| 0); | |||
| thrust::inclusive_scan(policy, | |||
| thrust::device_pointer_cast(input_index), | |||
| thrust::device_pointer_cast(input_index) + num_elements, | |||
| thrust::device_pointer_cast(input_index)); | |||
| thrust::scatter(policy, | |||
| thrust::device_pointer_cast(input_index), | |||
| thrust::device_pointer_cast(input_index) + num_elements, | |||
| thrust::device_pointer_cast(sorted_index), | |||
| thrust::device_pointer_cast(index)); | |||
| thrust::device_ptr<T> output_end; | |||
| output_end = thrust::unique(policy, | |||
| thrust::device_pointer_cast(output), | |||
| thrust::device_pointer_cast(output) + num_elements); | |||
| int output_size = thrust::distance(thrust::device_pointer_cast(output), output_end); | |||
| return output_size; | |||
| } | |||
| template int CalUnique<float, int>(const float *input, int num_elements, int *input_index, int *sorted_index, | |||
| float *output, int *index, cudaStream_t cuda_stream); | |||
| template int CalUnique<half, int>(const half *input, int num_elements, int *input_index, int *sorted_index, | |||
| half *output, int *index, cudaStream_t cuda_stream); | |||
| template int CalUnique<int, int>(const int *input, int num_elements, int *input_index, int *sorted_index, | |||
| int *output, int *index, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,22 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_UNIQUE_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_UNIQUE_H_ | |||
| template <typename T, typename S> | |||
| int CalUnique(const T *input, int num_elements, S *input_index, S *sorted_index, T *output, S *index, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_UNIQUE_H_ | |||
| @@ -0,0 +1,226 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| class NetUnique(nn.Cell): | |||
| def __init__(self): | |||
| super(NetUnique, self).__init__() | |||
| self.unique = P.Unique() | |||
| def construct(self, x): | |||
| x_unique, x_idx = self.unique(x) | |||
| return x_unique, x_idx | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_1d(): | |||
| x = Tensor(np.array([4, 5, 1, 2, 3, 3, 4, 5]).astype(np.float32)) | |||
| exp_output = np.array([1, 2, 3, 4, 5]).astype(np.float32) | |||
| exp_idx = np.array([3, 4, 0, 1, 2, 2, 3, 4]).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_1d_float(): | |||
| x = Tensor(np.array([0.4, 0.5, 1.23, 2.2, 12.43, 12.43, 0.4, 0.5]).astype(np.float32)) | |||
| exp_output = np.array([0.4, 0.5, 1.23, 2.2, 12.43]).astype(np.float32) | |||
| exp_idx = np.array([0, 1, 2, 3, 4, 4, 0, 1]).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_1d_sorted(): | |||
| x = Tensor(np.array([1, 1, 2, 4, 4, 4, 7, 8, 8]).astype(np.float32)) | |||
| exp_output = np.array([1, 2, 4, 7, 8]).astype(np.float32) | |||
| exp_idx = np.array([0, 0, 1, 2, 2, 2, 3, 4, 4]).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_zeros(): | |||
| x = Tensor(np.zeros(1000).astype(np.float32)) | |||
| exp_output = np.zeros(1).astype(np.float32) | |||
| exp_idx = np.zeros(1000).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_large(): | |||
| x_np1 = np.arange(100) | |||
| x_np2 = np.arange(100, 200) | |||
| x_np3 = np.arange(200, 300) | |||
| x_np = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3)) | |||
| x = Tensor(x_np.astype(np.float32)) | |||
| exp_output = np.arange(300).astype(np.float32) | |||
| exp_idx = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3)).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_1d_half(): | |||
| x = Tensor(np.array([0.4, 0.5, 1.23, 2.2, 12.43, 12.43, 0.4, 0.5]).astype(np.float16)) | |||
| exp_output = np.array([0.4, 0.5, 1.23, 2.2, 12.43]).astype(np.float16) | |||
| exp_idx = np.array([0, 1, 2, 3, 4, 4, 0, 1]).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_1d_sorted_half(): | |||
| x = Tensor(np.array([1, 1, 2, 4, 4, 4, 7, 8, 8]).astype(np.float16)) | |||
| exp_output = np.array([1, 2, 4, 7, 8]).astype(np.float16) | |||
| exp_idx = np.array([0, 0, 1, 2, 2, 2, 3, 4, 4]).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_zeros_half(): | |||
| x = Tensor(np.zeros(1000).astype(np.float16)) | |||
| exp_output = np.zeros(1).astype(np.float16) | |||
| exp_idx = np.zeros(1000).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_large_half(): | |||
| x_np1 = np.arange(100) | |||
| x_np2 = np.arange(100, 200) | |||
| x_np3 = np.arange(200, 300) | |||
| x_np = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3)) | |||
| x = Tensor(x_np.astype(np.float16)) | |||
| exp_output = np.arange(300).astype(np.float16) | |||
| exp_idx = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3)).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_1d_int32(): | |||
| x = Tensor(np.array([4, 5, 1, 2, 3, 3, 4, 5]).astype(np.int32)) | |||
| exp_output = np.array([1, 2, 3, 4, 5]).astype(np.int32) | |||
| exp_idx = np.array([3, 4, 0, 1, 2, 2, 3, 4]).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_1d_sorted_int32(): | |||
| x = Tensor(np.array([1, 1, 2, 4, 4, 4, 7, 8, 8]).astype(np.int32)) | |||
| exp_output = np.array([1, 2, 4, 7, 8]).astype(np.int32) | |||
| exp_idx = np.array([0, 0, 1, 2, 2, 2, 3, 4, 4]).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_zeros_int32(): | |||
| x = Tensor(np.zeros(1000).astype(np.int32)) | |||
| exp_output = np.zeros(1).astype(np.int32) | |||
| exp_idx = np.zeros(1000).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_unique_large_int32(): | |||
| x_np1 = np.arange(100) | |||
| x_np2 = np.arange(100, 200) | |||
| x_np3 = np.arange(200, 300) | |||
| x_np = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3)) | |||
| x = Tensor(x_np.astype(np.int32)) | |||
| exp_output = np.arange(300).astype(np.int32) | |||
| exp_idx = np.concatenate((x_np1, x_np2, x_np3, x_np1, x_np2, x_np3, x_np1, x_np2, x_np3)).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = NetUnique() | |||
| x_unique, x_idx = net(x) | |||
| assert (x_unique.asnumpy() == exp_output).all() | |||
| assert (x_idx.asnumpy() == exp_idx).all() | |||