
add new GPU kernel ReciprocalGrad

tags/v1.2.0-rc1
TFBunny 4 years ago
parent commit 88b5458f78
6 changed files with 156 additions and 29 deletions
  1. +37 -6   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
  2. +3  -2   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
  3. +9  -1   mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
  4. +11 -4   mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
  5. +5  -16  mindspore/ops/_grad/grad_math_ops.py
  6. +91 -0   tests/st/ops/gpu/test_reciprocal_grad_op.py
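
For orientation before the diffs: ReciprocalGrad is the backward op of Reciprocal. With y = 1/x, the derivative is dy/dx = -1/x^2 = -y^2, so the kernel needs only the forward output y and the incoming gradient dout to produce dx = -dout * y * y. A minimal NumPy sketch of that identity (reciprocal_grad_reference is an illustrative name, not part of this commit):

import numpy as np

def reciprocal_grad_reference(y, dout):
    # With y = 1/x, dy/dx = -1/x**2 = -y**2, so dx = -dout * y * y.
    return -dout * y * y

# Worked example: x = 4 gives y = 0.25; with upstream gradient dout = 8.0,
# dx = -8.0 * 0.25 * 0.25 = -0.5.
assert np.isclose(reciprocal_grad_reference(0.25, 8.0), -0.5)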

+37 -6  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu

@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ __global__ void SqrtGradKernel(const T *input, const T *dout, T *output, const s
}
return;
}

template <typename T>
__global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -37,6 +38,7 @@ __global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const
}
return;
}

template <typename T>
__global__ void AsinGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -46,6 +48,7 @@ __global__ void AsinGradKernel(const T *input, const T *dout, T *output, const s
}
return;
}

template <>
__global__ void AsinGradKernel(const half *input, const half *dout, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -55,6 +58,7 @@ __global__ void AsinGradKernel(const half *input, const half *dout, half *output
}
return;
}

template <typename T>
__global__ void ACosGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -65,6 +69,7 @@ __global__ void ACosGradKernel(const T *input, const T *dout, T *output, const s
}
return;
}

template <>
__global__ void ACosGradKernel(const half *input, const half *dout, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -75,6 +80,7 @@ __global__ void ACosGradKernel(const half *input, const half *dout, half *output
}
return;
}

template <typename T>
__global__ void AtanGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -84,6 +90,7 @@ __global__ void AtanGradKernel(const T *input, const T *dout, T *output, const s
}
return;
}

template <typename T>
__global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -93,6 +100,7 @@ __global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const
}
return;
}

template <typename T>
__global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -102,11 +110,24 @@ __global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const
}
return;
}

template <typename T>
__global__ void ReciprocalGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
float inputf = static_cast<float>(input[i]);
float doutf = static_cast<float>(dout[i]);
float res = -1 * doutf * inputf * inputf;
output[i] = static_cast<T>(res);
}
return;
}

template <typename T>
void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
return;
}

template <typename T>
void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
RsqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
@@ -143,20 +164,28 @@ void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cud
return;
}

template <typename T>
void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
ReciprocalGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
return;
}

template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void AsinGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void ACosGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void AtanGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void AsinhGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void AcoshGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void ReciprocalGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
@@ -164,10 +193,12 @@ template void RsqrtGrad<half>(const half *input, const half *dout, half *output,
template void AsinGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void ACosGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void AtanGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void AsinhGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void AcoshGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void ReciprocalGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
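
A detail worth noting in the kernel above: ReciprocalGradKernel upcasts each element with static_cast<float>, multiplies in fp32, and only rounds back to T on the store, so the half instantiation keeps fp32 precision in the intermediate products. A hedged NumPy emulation of that rounding path (reciprocal_grad_half is an illustrative name):

import numpy as np

def reciprocal_grad_half(y, dout):
    # Mirror the kernel's half path: upcast to fp32, compute -dout * y * y,
    # then round the result back to fp16 on the store.
    y32 = y.astype(np.float32)
    dout32 = dout.astype(np.float32)
    return (-dout32 * y32 * y32).astype(np.float16)

y = np.array([0.2], dtype=np.float16)    # actually stored as ~0.19995 in fp16
dout = np.array([1.0], dtype=np.float16)
print(reciprocal_grad_half(y, dout))     # [-0.03998], the value the fp16 test below expects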

+3 -2  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh

@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@ template <typename T>
void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);

template <typename T>
void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);

#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_

+9 -1  mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc

@@ -1,5 +1,5 @@
/**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -74,5 +74,13 @@ MS_REG_GPU_KERNEL_ONE(
AcoshGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryGradOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
ReciprocalGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryGradOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
ReciprocalGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryGradOpGpuKernel, half)
} // namespace kernel
} // namespace mindspore

+11 -4  mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h

@@ -1,5 +1,5 @@
/**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -35,12 +35,14 @@ enum UnaryGradOptype {
UNARY_OP_ATAN_GRAD = 4,
UNARY_OP_ASINH_GRAD = 5,
UNARY_OP_ACOSH_GRAD = 6,
UNARY_OP_RECIPROCAL_GRAD = 7,
UNARY_OP_GRAD_INVALID_TYPE = 255
};
static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {
{"SqrtGrad", UNARY_OP_SQRT_GRAD}, {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}, {"AsinGrad", UNARY_OP_ASIN_GRAD},
{"ACosGrad", UNARY_OP_ACOS_GRAD}, {"AtanGrad", UNARY_OP_ATAN_GRAD}, {"AsinhGrad", UNARY_OP_ASINH_GRAD},
{"AcoshGrad", UNARY_OP_ACOSH_GRAD}};
{"SqrtGrad", UNARY_OP_SQRT_GRAD}, {"RsqrtGrad", UNARY_OP_RSQRT_GRAD},
{"AsinGrad", UNARY_OP_ASIN_GRAD}, {"ACosGrad", UNARY_OP_ACOS_GRAD},
{"AtanGrad", UNARY_OP_ATAN_GRAD}, {"AsinhGrad", UNARY_OP_ASINH_GRAD},
{"AcoshGrad", UNARY_OP_ACOSH_GRAD}, {"ReciprocalGrad", UNARY_OP_RECIPROCAL_GRAD}};

template <typename T>
class UnaryGradOpGpuKernel : public GpuKernel {
@@ -101,6 +103,11 @@ class UnaryGradOpGpuKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_RECIPROCAL_GRAD: {
ReciprocalGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
default: {
MS_LOG(EXCEPTION) << "Unary grad operation " << unary_grad_op_type_ << " is not supported.";
}


+5 -16  mindspore/ops/_grad/grad_math_ops.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -448,22 +448,11 @@ def get_bprop_rsqrt(self):
@bprop_getters.register(P.Reciprocal)
def get_bprop_reciprocal(self):
"""Grad definition for `Reciprocal` operation."""
-    if self.target == "GPU":
-        neg = P.Neg()
-        mul = P.Mul()
-        square = P.Square()
-        reciprocal = P.Reciprocal()
-
-        def bprop(x, out, dout):
-            g = neg(reciprocal(square(x)))
-            dx = mul(dout, g)
-            return (dx,)
-    else:
-        reciprocal_grad = G.ReciprocalGrad()
-
-        def bprop(x, out, dout):
-            dx = reciprocal_grad(out, dout)
-            return (dx,)
+    reciprocal_grad = G.ReciprocalGrad()
+
+    def bprop(x, out, dout):
+        dx = reciprocal_grad(out, dout)
+        return (dx,)

return bprop
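
The deleted GPU-specific branch and the unified bprop are algebraically the same: neg(reciprocal(square(x))) equals -out**2 once out = 1/x, and G.ReciprocalGrad consumes the forward output out rather than the input x. A quick NumPy check of the equivalence (a sketch, not MindSpore code):

import numpy as np

x = np.array([2.0, -4.0, 0.5], dtype=np.float32)
out = 1.0 / x                          # forward output of Reciprocal
dout = np.ones_like(x)

old_dx = dout * -(1.0 / np.square(x))  # old branch: mul(dout, neg(reciprocal(square(x))))
new_dx = -dout * out * out             # what reciprocal_grad(out, dout) computes
np.testing.assert_allclose(old_dx, new_dx)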



+91 -0  tests/st/ops/gpu/test_reciprocal_grad_op.py

@@ -0,0 +1,91 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G


class NetReciprocalGrad(nn.Cell):
def __init__(self):
super(NetReciprocalGrad, self).__init__()
self.grad = G.ReciprocalGrad()

def construct(self, y, dy):
return self.grad(y, dy)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reciprocal_grad_float32():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
y = Tensor(np.array([[[[-1, 1, 12],
[5, 34, 6],
[10, 2, -1]]]]).astype(np.float32))
dy = Tensor(np.array([[[[29, 1, 55],
[2.2, 63, 2],
[3, 3, 12]]]]).astype(np.float32))
expect = np.array([[[[-29, -1, -7920],
[-55, -72828, -72],
[-300, -12, -12]]]]).astype(np.float32)
net = NetReciprocalGrad()
output = net(y, dy)
np.testing.assert_array_almost_equal(output.asnumpy(), expect)

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
y = Tensor(np.array([[[[-1, 1, 12],
[5, 34, 6],
[10, 2, -1]]]]).astype(np.float32))
dy = Tensor(np.array([[[[29, 1, 55],
[2.2, 63, 2],
[3, 3, 12]]]]).astype(np.float32))
expect = np.array([[[[-29, -1, -7920],
[-55, -72828, -72],
[-300, -12, -12]]]]).astype(np.float32)
net = NetReciprocalGrad()
output = net(y, dy)
np.testing.assert_array_almost_equal(output.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reciprocal_grad_float16():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
y = Tensor(np.array([[0.01, 0.2, 0.22],
[10.002, 2, -1]]).astype(np.float16))
dy = Tensor(np.array([[34, 1, 55],
[3, 3, 63]]).astype(np.float16))
expect = np.array([[-0.0034, -0.03998, -2.662],
[-300, -12, -63]]).astype(np.float16)
net = NetReciprocalGrad()
output = net(y, dy)
np.testing.assert_array_almost_equal(output.asnumpy(), expect)

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
y = Tensor(np.array([[0.01, 0.2, 0.22],
[10.002, 2, -1]]).astype(np.float16))
dy = Tensor(np.array([[34, 1, 55],
[3, 3, 63]]).astype(np.float16))
expect = np.array([[-0.0034, -0.03998, -2.662],
[-300, -12, -63]]).astype(np.float16)
net = NetReciprocalGrad()
output = net(y, dy)
np.testing.assert_array_almost_equal(output.asnumpy(), expect)
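
For reference, the hard-coded expectations follow directly from dx = -dy * y**2; a NumPy cross-check of the fp32 case (not part of the test file):

import numpy as np

y = np.array([[-1, 1, 12], [5, 34, 6], [10, 2, -1]], dtype=np.float32)
dy = np.array([[29, 1, 55], [2.2, 63, 2], [3, 3, 12]], dtype=np.float32)
expect = np.array([[-29, -1, -7920], [-55, -72828, -72], [-300, -12, -12]], dtype=np.float32)
np.testing.assert_array_almost_equal(-dy * y * y, expect)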
