| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/gpu/arrays/select_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(Select, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeBool) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| SelectGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(Select, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeBool) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| SelectGpuKernel, half) | |||||
| MS_REG_GPU_KERNEL_ONE(Select, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeBool) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt32), | |||||
| SelectGpuKernel, int) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,95 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H | |||||
| #include <vector> | |||||
| #include "kernel/gpu/gpu_kernel.h" | |||||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||||
| #include "kernel/gpu/cuda_impl/select_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class SelectGpuKernel : public GpuKernel { | |||||
| public: | |||||
| SelectGpuKernel() : input_size_(0), output_size_(0) {} | |||||
| ~SelectGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||||
| bool *input_cond = GetDeviceAddress<bool>(inputs, 0); | |||||
| T *input_x = GetDeviceAddress<T>(inputs, 1); | |||||
| T *input_y = GetDeviceAddress<T>(inputs, 2); | |||||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||||
| CalSelect(output_size_ / sizeof(T), input_cond, input_x, input_y, output, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| if (!CheckParam(kernel_node)) { | |||||
| return false; | |||||
| } | |||||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| input_size_ = sizeof(bool); | |||||
| output_size_ = sizeof(T); | |||||
| for (size_t x : shape) { | |||||
| input_size_ = input_size_ * x; | |||||
| output_size_ = output_size_ * x; | |||||
| } | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(input_size_); | |||||
| input_size_list_.push_back(output_size_); | |||||
| input_size_list_.push_back(output_size_); | |||||
| output_size_list_.push_back(output_size_); | |||||
| } | |||||
| private: | |||||
| bool CheckParam(const CNodePtr &kernel_node) { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 3) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but SelectGpuKernel needs 3 output."; | |||||
| return false; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 1) { | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but SelectGpuKernel needs 1 output."; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| size_t input_size_; | |||||
| size_t output_size_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <stdio.h> | |||||
| #include <stdint.h> | |||||
| #include <include/cuda_runtime.h> | |||||
| #include "kernel/gpu/cuda_impl/select_impl.cuh" | |||||
| template <typename T> | |||||
| __global__ void Select(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output) { | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||||
| output[pos] = cond[pos] ? input_x[pos] : input_y[pos]; | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, | |||||
| cudaStream_t cuda_stream) { | |||||
| Select<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, cond, input_x, input_y, output); | |||||
| return; | |||||
| } | |||||
| template void CalSelect<float>(const size_t size, const bool* cond, const float* input_X, const float* input_y, | |||||
| float* output, cudaStream_t cuda_stream); | |||||
| template void CalSelect<int>(const size_t size, const bool* cond, const int* input_X, const int* input_y, int* output, | |||||
| cudaStream_t cuda_stream); | |||||
| template void CalSelect<half>(const size_t size, const bool* cond, const half* input_X, const half* input_y, | |||||
| half* output, cudaStream_t cuda_stream); | |||||
| @@ -0,0 +1,25 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ | |||||
| #include "device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, | |||||
| cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ | |||||
| @@ -0,0 +1,47 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import pytest | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.select = P.Select() | |||||
| def construct(self, cond, x, y): | |||||
| return self.select(cond, x, y) | |||||
| cond = np.array([[True, False], [True, False]]).astype(np.bool) | |||||
| x = np.array([[1.2, 1], [1, 0]]).astype(np.float32) | |||||
| y = np.array([[1, 2], [3, 4.0]]).astype(np.float32) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_select(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| select = Net() | |||||
| output = select(Tensor(cond), Tensor(x), Tensor(y)) | |||||
| expect = [[1.2, 2], [1, 4.0]] | |||||
| error = np.ones(shape=[2, 2]) * 1.0e-6 | |||||
| diff = output.asnumpy() - expect | |||||
| assert np.all(diff < error) | |||||
| assert np.all(-diff < error) | |||||