| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -26,6 +26,7 @@ __global__ void SqrtGradKernel(const T *input, const T *dout, T *output, const s | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| @@ -37,6 +38,7 @@ __global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AsinGradKernel(const T *input, const T *dout, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| @@ -46,6 +48,7 @@ __global__ void AsinGradKernel(const T *input, const T *dout, T *output, const s | |||
| } | |||
| return; | |||
| } | |||
| template <> | |||
| __global__ void AsinGradKernel(const half *input, const half *dout, half *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| @@ -55,6 +58,7 @@ __global__ void AsinGradKernel(const half *input, const half *dout, half *output | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void ACosGradKernel(const T *input, const T *dout, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| @@ -65,6 +69,7 @@ __global__ void ACosGradKernel(const T *input, const T *dout, T *output, const s | |||
| } | |||
| return; | |||
| } | |||
| template <> | |||
| __global__ void ACosGradKernel(const half *input, const half *dout, half *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| @@ -75,6 +80,7 @@ __global__ void ACosGradKernel(const half *input, const half *dout, half *output | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AtanGradKernel(const T *input, const T *dout, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| @@ -84,6 +90,7 @@ __global__ void AtanGradKernel(const T *input, const T *dout, T *output, const s | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| @@ -93,6 +100,7 @@ __global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| @@ -102,11 +110,24 @@ __global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void ReciprocalGradKernel(const T *input, const T *dout, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { | |||
| float inputf = static_cast<float>(input[i]); | |||
| float doutf = static_cast<float>(dout[i]); | |||
| float res = -1 * doutf * inputf * inputf; | |||
| output[i] = static_cast<T>(res); | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| RsqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count); | |||
| @@ -143,20 +164,28 @@ void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cud | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| ReciprocalGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count); | |||
| return; | |||
| } | |||
| template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void AsinGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| cudaStream_t cuda_stream); | |||
| template void ACosGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| cudaStream_t cuda_stream); | |||
| template void AtanGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void AsinhGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| cudaStream_t cuda_stream); | |||
| template void AcoshGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void ReciprocalGrad<float>(const float *input, const float *dout, float *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| @@ -164,10 +193,12 @@ template void RsqrtGrad<half>(const half *input, const half *dout, half *output, | |||
| template void AsinGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void ACosGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| cudaStream_t cuda_stream); | |||
| template void AtanGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void AsinhGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| cudaStream_t cuda_stream); | |||
| template void AcoshGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| template void ReciprocalGrad<half>(const half *input, const half *dout, half *output, const size_t count, | |||
| cudaStream_t cuda_stream); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -32,6 +32,7 @@ template <typename T> | |||
| void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_ | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -74,5 +74,13 @@ MS_REG_GPU_KERNEL_ONE( | |||
| AcoshGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryGradOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| ReciprocalGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryGradOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| ReciprocalGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryGradOpGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -35,12 +35,14 @@ enum UnaryGradOptype { | |||
| UNARY_OP_ATAN_GRAD = 4, | |||
| UNARY_OP_ASINH_GRAD = 5, | |||
| UNARY_OP_ACOSH_GRAD = 6, | |||
| UNARY_OP_RECIPROCAL_GRAD = 7, | |||
| UNARY_OP_GRAD_INVALID_TYPE = 255 | |||
| }; | |||
| static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = { | |||
| {"SqrtGrad", UNARY_OP_SQRT_GRAD}, {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}, {"AsinGrad", UNARY_OP_ASIN_GRAD}, | |||
| {"ACosGrad", UNARY_OP_ACOS_GRAD}, {"AtanGrad", UNARY_OP_ATAN_GRAD}, {"AsinhGrad", UNARY_OP_ASINH_GRAD}, | |||
| {"AcoshGrad", UNARY_OP_ACOSH_GRAD}}; | |||
| {"SqrtGrad", UNARY_OP_SQRT_GRAD}, {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}, | |||
| {"AsinGrad", UNARY_OP_ASIN_GRAD}, {"ACosGrad", UNARY_OP_ACOS_GRAD}, | |||
| {"AtanGrad", UNARY_OP_ATAN_GRAD}, {"AsinhGrad", UNARY_OP_ASINH_GRAD}, | |||
| {"AcoshGrad", UNARY_OP_ACOSH_GRAD}, {"ReciprocalGrad", UNARY_OP_RECIPROCAL_GRAD}}; | |||
| template <typename T> | |||
| class UnaryGradOpGpuKernel : public GpuKernel { | |||
| @@ -101,6 +103,11 @@ class UnaryGradOpGpuKernel : public GpuKernel { | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_RECIPROCAL_GRAD: { | |||
| ReciprocalGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| default: { | |||
| MS_LOG(EXCEPTION) << "Unary grad operation " << unary_grad_op_type_ << " is not supported."; | |||
| } | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -448,22 +448,11 @@ def get_bprop_rsqrt(self): | |||
| @bprop_getters.register(P.Reciprocal) | |||
| def get_bprop_reciprocal(self): | |||
| """Grad definition for `Reciprocal` operation.""" | |||
| if self.target == "GPU": | |||
| neg = P.Neg() | |||
| mul = P.Mul() | |||
| square = P.Square() | |||
| reciprocal = P.Reciprocal() | |||
| def bprop(x, out, dout): | |||
| g = neg(reciprocal(square(x))) | |||
| dx = mul(dout, g) | |||
| return (dx,) | |||
| else: | |||
| reciprocal_grad = G.ReciprocalGrad() | |||
| reciprocal_grad = G.ReciprocalGrad() | |||
| def bprop(x, out, dout): | |||
| dx = reciprocal_grad(out, dout) | |||
| return (dx,) | |||
| def bprop(x, out, dout): | |||
| dx = reciprocal_grad(out, dout) | |||
| return (dx,) | |||
| return bprop | |||
| @@ -0,0 +1,91 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| class NetReciprocalGrad(nn.Cell): | |||
| def __init__(self): | |||
| super(NetReciprocalGrad, self).__init__() | |||
| self.grad = G.ReciprocalGrad() | |||
| def construct(self, y, dy): | |||
| return self.grad(y, dy) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_reciprocal_grad_float32(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| y = Tensor(np.array([[[[-1, 1, 12], | |||
| [5, 34, 6], | |||
| [10, 2, -1]]]]).astype(np.float32)) | |||
| dy = Tensor(np.array([[[[29, 1, 55], | |||
| [2.2, 63, 2], | |||
| [3, 3, 12]]]]).astype(np.float32)) | |||
| expect = np.array([[[[-29, -1, -7920], | |||
| [-55, -72828, -72], | |||
| [-300, -12, -12]]]]).astype(np.float32) | |||
| net = NetReciprocalGrad() | |||
| output = net(y, dy) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expect) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| y = Tensor(np.array([[[[-1, 1, 12], | |||
| [5, 34, 6], | |||
| [10, 2, -1]]]]).astype(np.float32)) | |||
| dy = Tensor(np.array([[[[29, 1, 55], | |||
| [2.2, 63, 2], | |||
| [3, 3, 12]]]]).astype(np.float32)) | |||
| expect = np.array([[[[-29, -1, -7920], | |||
| [-55, -72828, -72], | |||
| [-300, -12, -12]]]]).astype(np.float32) | |||
| net = NetReciprocalGrad() | |||
| output = net(y, dy) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expect) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_reciprocal_grad_float16(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| y = Tensor(np.array([[0.01, 0.2, 0.22], | |||
| [10.002, 2, -1]]).astype(np.float16)) | |||
| dy = Tensor(np.array([[34, 1, 55], | |||
| [3, 3, 63]]).astype(np.float16)) | |||
| expect = np.array([[-0.0034, -0.03998, -2.662], | |||
| [-300, -12, -63]]).astype(np.float16) | |||
| net = NetReciprocalGrad() | |||
| output = net(y, dy) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expect) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| y = Tensor(np.array([[0.01, 0.2, 0.22], | |||
| [10.002, 2, -1]]).astype(np.float16)) | |||
| dy = Tensor(np.array([[34, 1, 55], | |||
| [3, 3, 63]]).astype(np.float16)) | |||
| expect = np.array([[-0.0034, -0.03998, -2.662], | |||
| [-300, -12, -63]]).astype(np.float16) | |||
| net = NetReciprocalGrad() | |||
| output = net(y, dy) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expect) | |||