From: @david-he91 Reviewed-by: @linqingke,@liangchenghui,@liangchenghui Signed-off-by: @liangchenghui,@liangchenghuitags/v1.2.0-rc1
| @@ -76,6 +76,33 @@ __global__ void ACosGradKernel(const half *input, const half *dout, half *output | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AtanGradKernel(const T *input, const T *dout, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| T one = 1; | |||
| T divisor = one + input[i] * input[i]; | |||
| output[i] = dout[i] / divisor; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| float inputf = static_cast<float>(input[i]); | |||
| T coshy = static_cast<T>(coshf(inputf)); | |||
| output[i] = dout[i] / coshy; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| float inputf = static_cast<float>(input[i]); | |||
| T sinhy = static_cast<T>(sinhf(inputf)); | |||
| output[i] = dout[i] / sinhy; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count); | |||
| return; | |||
| @@ -98,6 +125,24 @@ void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cuda | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void AtanGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| AtanGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| AsinhGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| AcoshGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count); | |||
| return; | |||
| } | |||
| template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| @@ -106,6 +151,12 @@ template void AsinGrad<float>(const float *input, const float *dout, float *outp | |||
| cudaStream_t cuda_stream); | |||
| template void ACosGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void AtanGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void AsinhGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void AcoshGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| @@ -114,3 +165,9 @@ template void AsinGrad<half>(const half *input, const half *dout, half *output, | |||
| cudaStream_t cuda_stream); | |||
| template void ACosGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void AtanGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void AsinhGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void AcoshGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| @@ -26,5 +26,12 @@ template <typename T> | |||
| void AsinGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void AtanGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_ | |||
| @@ -146,6 +146,15 @@ __global__ void AsinKernel(const T *input, T *output, const size_t count) { | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AsinhKernel(const T *input, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| float inputf = static_cast<float>(input[i]); | |||
| T res = static_cast<T>(asinhf(inputf)); | |||
| output[i] = res; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void CosKernel(const T *input, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| output[i] = cos(input[i]); | |||
| @@ -169,6 +178,24 @@ __global__ void ACosKernel(const T *input, T *output, const size_t count) { | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AcoshKernel(const T *input, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| float inputf = static_cast<float>(input[i]); | |||
| T res = static_cast<T>(acoshf(inputf)); | |||
| output[i] = res; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AtanKernel(const T *input, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| float inputf = static_cast<float>(input[i]); | |||
| T res = static_cast<T>(atanf(inputf)); | |||
| output[i] = res; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void ZeroslikeKernel(T *output, const size_t count) { | |||
| T zero = 0.0; | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| @@ -281,6 +308,21 @@ void ACos(const T *input, T *output, const size_t count, cudaStream_t cuda_strea | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Atan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| AtanKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| AsinhKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| AcoshKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Rsqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| RsqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| @@ -315,6 +357,9 @@ template void Sin<float>(const float *input, float *output, const size_t count, | |||
| template void Cos<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Asin<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void ACos<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Atan<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Asinh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Acosh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Zeroslike<float>(float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -333,6 +378,9 @@ template void Sin<half>(const half *input, half *output, const size_t count, cud | |||
| template void Cos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Asin<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void ACos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Atan<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Asinh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Acosh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Zeroslike<half>(half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -49,6 +49,12 @@ void Asin(const T *input, T *output, const size_t count, cudaStream_t cuda_strea | |||
| template <typename T> | |||
| void ACos(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Atan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -74,6 +74,10 @@ MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| @@ -82,6 +86,14 @@ MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| @@ -44,6 +44,9 @@ enum UnaryOptype { | |||
| UNARY_OP_COS, | |||
| UNARY_OP_ASIN, | |||
| UNARY_OP_ACOS, | |||
| UNARY_OP_ATAN, | |||
| UNARY_OP_ASINH, | |||
| UNARY_OP_ACOSH, | |||
| UNARY_OP_ABS, | |||
| UNARY_OP_FLOOR, | |||
| UNARY_OP_INVALID_TYPE = 255 | |||
| @@ -64,6 +67,9 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY | |||
| {"Cos", UNARY_OP_COS}, | |||
| {"Asin", UNARY_OP_ASIN}, | |||
| {"ACos", UNARY_OP_ACOS}, | |||
| {"Atan", UNARY_OP_ATAN}, | |||
| {"Asinh", UNARY_OP_ASINH}, | |||
| {"Acosh", UNARY_OP_ACOSH}, | |||
| {"Abs", UNARY_OP_ABS}, | |||
| {"Floor", UNARY_OP_FLOOR}}; | |||
| template <typename T> | |||
| @@ -142,6 +148,18 @@ class UnaryOpGpuKernel : public GpuKernel { | |||
| ACos(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_ATAN: { | |||
| Atan(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_ASINH: { | |||
| Asinh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_ACOSH: { | |||
| Acosh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_ZEROSLIKE: { | |||
| Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| @@ -50,5 +50,29 @@ MS_REG_GPU_KERNEL_ONE( | |||
| ACosGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryGradOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| AtanGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryGradOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| AtanGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryGradOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| AsinhGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryGradOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| AsinhGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryGradOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| AcoshGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryGradOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| AcoshGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryGradOpGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -32,12 +32,16 @@ enum UnaryGradOptype { | |||
| UNARY_OP_RSQRT_GRAD = 1, | |||
| UNARY_OP_ASIN_GRAD = 2, | |||
| UNARY_OP_ACOS_GRAD = 3, | |||
| UNARY_OP_ATAN_GRAD = 4, | |||
| UNARY_OP_ASINH_GRAD = 5, | |||
| UNARY_OP_ACOSH_GRAD = 6, | |||
| UNARY_OP_GRAD_INVALID_TYPE = 255 | |||
| }; | |||
| static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {{"SqrtGrad", UNARY_OP_SQRT_GRAD}, | |||
| {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}, | |||
| {"AsinGrad", UNARY_OP_ASIN_GRAD}, | |||
| {"ACosGrad", UNARY_OP_ACOS_GRAD}}; | |||
| static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = { | |||
| {"SqrtGrad", UNARY_OP_SQRT_GRAD}, {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}, {"AsinGrad", UNARY_OP_ASIN_GRAD}, | |||
| {"ACosGrad", UNARY_OP_ACOS_GRAD}, {"AtanGrad", UNARY_OP_ATAN_GRAD}, {"AsinhGrad", UNARY_OP_ASINH_GRAD}, | |||
| {"AcoshGrad", UNARY_OP_ACOSH_GRAD}}; | |||
| template <typename T> | |||
| class UnaryGradOpGpuKernel : public GpuKernel { | |||
| public: | |||
| @@ -77,6 +81,21 @@ class UnaryGradOpGpuKernel : public GpuKernel { | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_ATAN_GRAD: { | |||
| AtanGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_ASINH_GRAD: { | |||
| AsinhGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_ACOSH_GRAD: { | |||
| AcoshGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_RSQRT_GRAD: { | |||
| RsqrtGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| @@ -0,0 +1,43 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| import mindspore.ops.operations._grad_ops as P | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| np.random.seed(1) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_acoshgrad_fp32(): | |||
| y_np = np.random.rand(4, 2).astype(np.float32) * 10 | |||
| dout_np = np.random.rand(4, 2).astype(np.float32) * 10 | |||
| output_ms = P.AcoshGrad()(Tensor(y_np), Tensor(dout_np)) | |||
| output_np = dout_np / np.sinh(y_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_acoshgrad_fp16(): | |||
| y_np = np.random.rand(4, 2).astype(np.float16) * 10 | |||
| dout_np = np.random.rand(4, 2).astype(np.float16) * 10 | |||
| output_ms = P.AcoshGrad()(Tensor(y_np), Tensor(dout_np)) | |||
| output_np = dout_np.astype(np.float32) / np.sinh(y_np).astype(np.float32) | |||
| assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3) | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| np.random.seed(1) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_acosh_fp32(): | |||
| x_np = np.random.rand(4, 2).astype(np.float32) * 10 + 1 | |||
| output_ms = P.Acosh()(Tensor(x_np)) | |||
| output_np = np.arccosh(x_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_acosh_fp16(): | |||
| x_np = np.random.rand(4, 2).astype(np.float16) * 10 + 1 | |||
| output_ms = P.Acosh()(Tensor(x_np)) | |||
| output_np = np.arccosh(x_np.astype(np.float32)).astype(np.float16) | |||
| assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3) | |||
| @@ -0,0 +1,43 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| import mindspore.ops.operations._grad_ops as P | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| np.random.seed(1) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_asinhgrad_fp32(): | |||
| y_np = np.random.rand(4, 2).astype(np.float32) * 10 | |||
| dout_np = np.random.rand(4, 2).astype(np.float32) * 10 | |||
| output_ms = P.AsinhGrad()(Tensor(y_np), Tensor(dout_np)) | |||
| output_np = dout_np / np.cosh(y_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_asinhgrad_fp16(): | |||
| y_np = np.random.rand(4, 2).astype(np.float16) * 10 | |||
| dout_np = np.random.rand(4, 2).astype(np.float16) * 10 | |||
| output_ms = P.AsinhGrad()(Tensor(y_np), Tensor(dout_np)) | |||
| output_np = dout_np.astype(np.float32) / np.cosh(y_np).astype(np.float32) | |||
| assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3) | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| np.random.seed(1) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_asinh_fp32(): | |||
| x_np = np.random.rand(4, 2).astype(np.float32) * 10 | |||
| output_ms = P.Asinh()(Tensor(x_np)) | |||
| output_np = np.arcsinh(x_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_asinh_fp16(): | |||
| x_np = np.random.rand(4, 2).astype(np.float16) * 10 | |||
| output_ms = P.Asinh()(Tensor(x_np)) | |||
| output_np = np.arcsinh(x_np.astype(np.float32)).astype(np.float16) | |||
| assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3) | |||
| @@ -0,0 +1,43 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| import mindspore.ops.operations._grad_ops as P | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| np.random.seed(1) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_atangrad_fp32(): | |||
| x_np = np.random.rand(4, 2).astype(np.float32) * 10 | |||
| dout_np = np.random.rand(4, 2).astype(np.float32) * 10 | |||
| output_ms = P.AtanGrad()(Tensor(x_np), Tensor(dout_np)) | |||
| output_np = dout_np / (1 + np.square(x_np)) | |||
| assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_atangrad_fp16(): | |||
| x_np = np.random.rand(4, 2).astype(np.float16) * 10 | |||
| dout_np = np.random.rand(4, 2).astype(np.float16) * 10 | |||
| output_ms = P.AtanGrad()(Tensor(x_np), Tensor(dout_np)) | |||
| output_np = dout_np.astype(np.float32) / (1 + np.square(x_np.astype(np.float32))) | |||
| assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3) | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| np.random.seed(1) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_atan_fp32(): | |||
| x_np = np.random.rand(4, 2).astype(np.float32) * 10 | |||
| output_ms = P.Atan()(Tensor(x_np)) | |||
| output_np = np.arctan(x_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_atan_fp16(): | |||
| x_np = np.random.rand(4, 2).astype(np.float16) * 10 | |||
| output_ms = P.Atan()(Tensor(x_np)) | |||
| output_np = np.arctan(x_np.astype(np.float32)).astype(np.float16) | |||
| assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3) | |||