Merge pull request !7950 from linqingke/gpu_opstags/v1.1.0
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| __global__ void SoftplusKernel(const size_t size, const T *input_addr, T *output_addr) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { | |||
| float x = input_addr[pos]; | |||
| output_addr[pos] = logf(1. + exp(x)); | |||
| } | |||
| } | |||
| template <> | |||
| __global__ void SoftplusKernel(const size_t size, const half *input_addr, half *output_addr) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { | |||
| float x = __half2float(input_addr[pos]); | |||
| output_addr[pos] = __float2half(logf(1. + exp(x))); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void Softplus(const size_t size, const T *input_addr, T *output_addr, cudaStream_t cuda_stream) { | |||
| SoftplusKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr); | |||
| return; | |||
| } | |||
| template <> | |||
| void Softplus(const size_t size, const half *input_addr, half *output_addr, cudaStream_t cuda_stream) { | |||
| SoftplusKernel<half><<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void SoftplusGradKernel(const size_t size, const T *dy_addr, const T *x_addr, T *dx_addr) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| T exp_x = exp(x_addr[pos]); | |||
| dx_addr[pos] = dy_addr[pos] * exp_x / (1. + exp_x); | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void SoftplusGradKernel(const size_t size, const half *dy_addr, const half *x_addr, half *dx_addr) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| float x = __half2float(x_addr[pos]); | |||
| float dy = __half2float(dy_addr[pos]); | |||
| float exp_x = exp(x); | |||
| dx_addr[pos] = __float2half(dy * exp_x / (1. + exp_x)); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void SoftplusGrad(const size_t size, const T *dy_addr, const T *x_addr, T *dx_addr, cudaStream_t cuda_stream) { | |||
| SoftplusGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy_addr, x_addr, dx_addr); | |||
| return; | |||
| } | |||
| template <> | |||
| void SoftplusGrad(const size_t size, const half *dy_addr, const half *x_addr, half *dx_addr, cudaStream_t cuda_stream) { | |||
| SoftplusGradKernel<half><<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy_addr, x_addr, dx_addr); | |||
| return; | |||
| } | |||
| template void Softplus(const size_t size, const float *input_addr, float *output_addr, cudaStream_t cuda_stream); | |||
| template void Softplus(const size_t size, const half *input_addr, half *output_addr, cudaStream_t cuda_stream); | |||
| template void SoftplusGrad(const size_t size, const float *dy_addr, const float *x_addr, float *dx_addr, | |||
| cudaStream_t cuda_stream); | |||
| template void SoftplusGrad(const size_t size, const half *dy_addr, const half *x_addr, half *dx_addr, | |||
| cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,27 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SOFTPLUS_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SOFTPLUS_H_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template<typename T> | |||
| void Softplus(const size_t input_size, const T* input_addr, T* output_addr, cudaStream_t cuda_stream); | |||
| template<typename T> | |||
| void SoftplusGrad(const size_t size, const T* dy_addr, const T* x_addr, T* dx_addr, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SOFTPLUS_H_ | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/softplus_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(Softplus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SoftplusGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Softplus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SoftplusGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,72 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class SoftplusGpuKernel : public GpuKernel { | |||
| public: | |||
| SoftplusGpuKernel() : input_size_(0) {} | |||
| ~SoftplusGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| Softplus(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| InitResource(); | |||
| input_size_ = sizeof(T); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (auto dim : input_shape) { | |||
| input_size_ *= dim; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_); | |||
| output_size_list_.push_back(input_size_); | |||
| } | |||
| private: | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| size_t input_size_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/softplus_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| SoftplusGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SoftplusGpuGradKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| SoftplusGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SoftplusGpuGradKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,74 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GRAD_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GRAD_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/softplus_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class SoftplusGpuGradKernel : public GpuKernel { | |||
| public: | |||
| SoftplusGpuGradKernel() : input_size_(0) {} | |||
| ~SoftplusGpuGradKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *dy_addr = GetDeviceAddress<T>(inputs, 0); | |||
| T *x_addr = GetDeviceAddress<T>(inputs, 1); | |||
| T *dx_addr = GetDeviceAddress<T>(outputs, 0); | |||
| SoftplusGrad(input_size_ / sizeof(T), dy_addr, x_addr, dx_addr, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| InitResource(); | |||
| input_size_ = sizeof(T); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (auto dim : input_shape) { | |||
| input_size_ *= dim; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_); | |||
| input_size_list_.push_back(input_size_); | |||
| output_size_list_.push_back(input_size_); | |||
| } | |||
| private: | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| size_t input_size_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GRAD_KERNEL_H_ | |||
| @@ -0,0 +1,77 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| class SoftplusNet(nn.Cell): | |||
| def __init__(self): | |||
| super(SoftplusNet, self).__init__() | |||
| self.softplus = P.Softplus() | |||
| def construct(self, x): | |||
| return self.softplus(x) | |||
| class Grad(nn.Cell): | |||
| def __init__(self, network): | |||
| super(Grad, self).__init__() | |||
| self.grad = C.GradOperation(get_all=True, sens_param=True) | |||
| self.network = network | |||
| def construct(self, input_data, sens): | |||
| gout = self.grad(self.network)(input_data, sens) | |||
| return gout | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_softplusgrad(): | |||
| x = np.array([0.58401114, 0.68800163, 0.9760397, 0.14702141, 0.46563736, 0.9607501, | |||
| 0.14567593, 0.12261796, 0.37054458, 0.46421242]).astype(np.float32) | |||
| dy = np.array([0.5559598, 0.96994054, 0.24770357, 0.34646875, 0.2984393, 0.03287048, | |||
| 0.55681044, 0.966908, 0.06015943, 0.6099489]).astype(np.float32) | |||
| x_ms = Tensor(x) | |||
| dy_ms = Tensor(dy) | |||
| net = SoftplusNet() | |||
| grad = Grad(net) | |||
| output = grad(x_ms, dy_ms) | |||
| expect = dy * np.exp(x) / (1 + np.exp(x)) | |||
| assert np.allclose(output[0].asnumpy(), expect, rtol=1e-3) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_softplusgrad_fp16(): | |||
| np.random.seed(42) | |||
| x_np = np.random.randn(5, 3, 6).astype(np.float16) | |||
| dy_np = np.random.randn(5, 3, 6).astype(np.float16) | |||
| net = SoftplusNet() | |||
| grad = Grad(net) | |||
| output = grad(Tensor(x_np), Tensor(dy_np)) | |||
| expect = dy_np * np.exp(x_np) / (1 + np.exp(x_np)) | |||
| assert np.allclose(output[0].asnumpy(), expect, rtol=1e-2) | |||
| @@ -0,0 +1,106 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| class SoftplusNet(nn.Cell): | |||
| def __init__(self): | |||
| super(SoftplusNet, self).__init__() | |||
| self.softplus = P.Softplus() | |||
| def construct(self, x): | |||
| return self.softplus(x) | |||
| def SoftplusCompute(x): | |||
| return np.log(1 + np.exp(x)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_softplus_1d(): | |||
| x_np = np.random.random((50,)).astype(np.float32) | |||
| y_np = SoftplusCompute(x_np) | |||
| x_ms = Tensor(x_np) | |||
| net = SoftplusNet() | |||
| y_ms = net(x_ms) | |||
| assert np.allclose(y_np, y_ms.asnumpy()) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_softplus_2d(): | |||
| x_np = np.random.random((50, 40)).astype(np.float32) | |||
| y_np = SoftplusCompute(x_np) | |||
| x_ms = Tensor(x_np) | |||
| net = SoftplusNet() | |||
| y_ms = net(x_ms) | |||
| assert np.allclose(y_np, y_ms.asnumpy()) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_softplus_4d(): | |||
| x_np = np.random.random((32, 3, 224, 224)).astype(np.float32) | |||
| y_np = SoftplusCompute(x_np) | |||
| x_ms = Tensor(x_np) | |||
| net = SoftplusNet() | |||
| y_ms = net(x_ms) | |||
| assert np.allclose(y_np, y_ms.asnumpy()) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_softplus_neg(): | |||
| x_np = np.random.random((32, 3, 224, 224)).astype(np.float32) * -1 | |||
| y_np = SoftplusCompute(x_np) | |||
| x_ms = Tensor(x_np) | |||
| net = SoftplusNet() | |||
| y_ms = net(x_ms) | |||
| assert np.allclose(y_np, y_ms.asnumpy()) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_softplus_4d_fp16(): | |||
| x_np = np.random.random((32, 3, 224, 224)).astype(np.float16) | |||
| y_np = SoftplusCompute(x_np) | |||
| x_ms = Tensor(x_np) | |||
| net = SoftplusNet() | |||
| y_ms = net(x_ms) | |||
| assert np.allclose(y_np, y_ms.asnumpy(), rtol=5e-3) | |||