| @@ -95,6 +95,34 @@ __global__ void RsqrtKernel(half *input, half *output, size_t count) { | |||||
| return; | return; | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| __global__ void SinKernel(T *input, T *output, size_t count) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||||
| output[i] = sin(input[i]); | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <> | |||||
| __global__ void SinKernel(half *input, half *output, size_t count) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||||
| output[i] = hsin(input[i]); | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void CosKernel(T *input, T *output, size_t count) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||||
| output[i] = cos(input[i]); | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <> | |||||
| __global__ void CosKernel(half *input, half *output, size_t count) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||||
| output[i] = hcos(input[i]); | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void ZeroslikeKernel(T *output, size_t count) { | __global__ void ZeroslikeKernel(T *output, size_t count) { | ||||
| T zero = 0.0; | T zero = 0.0; | ||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | ||||
| @@ -167,6 +195,16 @@ void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | |||||
| return; | return; | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void Sin(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | |||||
| SinKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| void Cos(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | |||||
| CosKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | ||||
| RsqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | RsqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | ||||
| return; | return; | ||||
| @@ -193,6 +231,8 @@ template void Negative<float>(float *input, float *output, size_t count, cudaStr | |||||
| template void Reciprocal<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | template void Reciprocal<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Square<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | template void Square<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Sqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | template void Sqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Sin<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | |||||
| template void Cos<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | |||||
| template void Rsqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | template void Rsqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Zeroslike<float>(float *output, size_t count, cudaStream_t cuda_stream); | template void Zeroslike<float>(float *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Abs<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | template void Abs<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | ||||
| @@ -203,6 +243,8 @@ template void Negative<half>(half *input, half *output, size_t count, cudaStream | |||||
| template void Reciprocal<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | template void Reciprocal<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Square<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | template void Square<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Sqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | template void Sqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Sin<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | |||||
| template void Cos<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | |||||
| template void Rsqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | template void Rsqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Zeroslike<half>(half *output, size_t count, cudaStream_t cuda_stream); | template void Zeroslike<half>(half *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Abs<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | template void Abs<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | ||||
| @@ -33,6 +33,10 @@ void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); | |||||
| template <typename T> | template <typename T> | ||||
| void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); | void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); | ||||
| template <typename T> | template <typename T> | ||||
| void Sin(T *input, T *output, size_t count, cudaStream_t cuda_stream); | |||||
| template <typename T> | |||||
| void Cos(T *input, T *output, size_t count, cudaStream_t cuda_stream); | |||||
| template <typename T> | |||||
| void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); | void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); | ||||
| template <typename T> | template <typename T> | ||||
| void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream); | void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream); | ||||
| @@ -46,6 +46,14 @@ MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut | |||||
| UnaryOpGpuKernel, float) | UnaryOpGpuKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| UnaryOpGpuKernel, float) | UnaryOpGpuKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| UnaryOpGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| UnaryOpGpuKernel, half) | |||||
| MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| UnaryOpGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| UnaryOpGpuKernel, half) | |||||
| MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| UnaryOpGpuKernel, float) | UnaryOpGpuKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | ||||
| @@ -36,6 +36,8 @@ enum UnaryOptype { | |||||
| UNARY_OP_SQUARE, | UNARY_OP_SQUARE, | ||||
| UNARY_OP_SQRT, | UNARY_OP_SQRT, | ||||
| UNARY_OP_RSQRT, | UNARY_OP_RSQRT, | ||||
| UNARY_OP_SIN, | |||||
| UNARY_OP_COS, | |||||
| UNARY_OP_ABS, | UNARY_OP_ABS, | ||||
| UNARY_OP_FLOOR, | UNARY_OP_FLOOR, | ||||
| UNARY_OP_INVALID_TYPE = 255 | UNARY_OP_INVALID_TYPE = 255 | ||||
| @@ -48,6 +50,8 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY | |||||
| {"Square", UNARY_OP_SQUARE}, | {"Square", UNARY_OP_SQUARE}, | ||||
| {"Sqrt", UNARY_OP_SQRT}, | {"Sqrt", UNARY_OP_SQRT}, | ||||
| {"Rsqrt", UNARY_OP_RSQRT}, | {"Rsqrt", UNARY_OP_RSQRT}, | ||||
| {"Sin", UNARY_OP_SIN}, | |||||
| {"Cos", UNARY_OP_COS}, | |||||
| {"Abs", UNARY_OP_ABS}, | {"Abs", UNARY_OP_ABS}, | ||||
| {"Floor", UNARY_OP_FLOOR}}; | {"Floor", UNARY_OP_FLOOR}}; | ||||
| template <typename T> | template <typename T> | ||||
| @@ -100,6 +104,14 @@ class UnaryOpGpuKernel : public GpuKernel { | |||||
| Rsqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | Rsqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| break; | break; | ||||
| } | } | ||||
| case UNARY_OP_SIN: { | |||||
| Sin(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| break; | |||||
| } | |||||
| case UNARY_OP_COS: { | |||||
| Cos(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| break; | |||||
| } | |||||
| case UNARY_OP_ZEROSLIKE: { | case UNARY_OP_ZEROSLIKE: { | ||||
| Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| return true; | return true; | ||||
| @@ -0,0 +1,33 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_cos(): | |||||
| x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| output_ms = P.Cos()(Tensor(x_np)) | |||||
| output_np = np.cos(x_np) | |||||
| assert np.allclose(output_ms.asnumpy(), output_np) | |||||
| @@ -0,0 +1,33 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_sin(): | |||||
| x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| output_ms = P.Sin()(Tensor(x_np)) | |||||
| output_np = np.sin(x_np) | |||||
| assert np.allclose(output_ms.asnumpy(), output_np) | |||||