Merge pull request !945 from chenweifeng/unarytags/v0.3.0-alpha
| @@ -60,6 +60,34 @@ __global__ void SquareKernel(T *input, T *output, size_t count) { | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void SqrtKernel(T *input, T *output, size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| output[i] = sqrt(input[i]); | |||
| } | |||
| return; | |||
| } | |||
| template <> | |||
| __global__ void SqrtKernel(half *input, half *output, size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| output[i] = hsqrt(input[i]); | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void RsqrtKernel(T *input, T *output, size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| output[i] = rsqrt(input[i]); | |||
| } | |||
| return; | |||
| } | |||
| template <> | |||
| __global__ void RsqrtKernel(half *input, half *output, size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| output[i] = hrsqrt(input[i]); | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void ZeroslikeKernel(T *output, size_t count) { | |||
| T zero = 0.0; | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| @@ -93,6 +121,21 @@ void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Pow(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | |||
| PowKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | |||
| SqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | |||
| RsqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) { | |||
| ZeroslikeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(output, count); | |||
| return; | |||
| @@ -103,10 +146,14 @@ template void Logarithm<float>(float *input, float *output, size_t count, cudaSt | |||
| template void Negative<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Reciprocal<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Square<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Sqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Rsqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Zeroslike<float>(float *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Exponential<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Logarithm<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Negative<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Reciprocal<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Square<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Sqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Rsqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | |||
| template void Zeroslike<half>(half *output, size_t count, cudaStream_t cuda_stream); | |||
| @@ -29,6 +29,10 @@ void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ | |||
| @@ -42,5 +42,9 @@ MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddO | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -34,6 +34,8 @@ enum UnaryOptype { | |||
| UNARY_OP_RECIPROCAL, | |||
| UNARY_OP_ZEROSLIKE, | |||
| UNARY_OP_SQUARE, | |||
| UNARY_OP_SQRT, | |||
| UNARY_OP_RSQRT, | |||
| UNARY_OP_INVALID_TYPE = 255 | |||
| }; | |||
| static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, | |||
| @@ -41,7 +43,9 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY | |||
| {"Neg", UNARY_OP_NEG}, | |||
| {"Reciprocal", UNARY_OP_RECIPROCAL}, | |||
| {"ZerosLike", UNARY_OP_ZEROSLIKE}, | |||
| {"Square", UNARY_OP_SQUARE}}; | |||
| {"Square", UNARY_OP_SQUARE}, | |||
| {"Sqrt", UNARY_OP_SQRT}, | |||
| {"Rsqrt", UNARY_OP_RSQRT}}; | |||
| template <typename T> | |||
| class UnaryOpGpuKernel : public GpuKernel { | |||
| public: | |||
| @@ -80,6 +84,14 @@ class UnaryOpGpuKernel : public GpuKernel { | |||
| Square(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_SQRT: { | |||
| Sqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_RSQRT: { | |||
| Rsqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_ZEROSLIKE: { | |||
| Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sqrt(): | |||
| x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| output_ms = P.Sqrt()(Tensor(x_np)) | |||
| output_np = np.sqrt(x_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np) | |||
| output_ms = P.Rsqrt()(Tensor(x_np)) | |||
| output_np = 1 / np.sqrt(x_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np) | |||