
!4251 adding type support for gpu kernels for EfficientNet

Merge pull request !4251 from Peilin/efficientnet
tags/v0.7.0-beta
mindspore-ci-bot committed 5 years ago
commit c7b50bcdd2
5 changed files with 64 additions and 19 deletions
  1. mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc (+2, -0)
  2. mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu (+12, -4)
  3. mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu (+5, -1)
  4. mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc (+16, -0)
  5. tests/st/ops/gpu/test_transpose_op.py (+29, -14)

mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc (+2, -0)

@@ -21,5 +21,7 @@ MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).A
                       TransposeGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       TransposeGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      TransposeGpuFwdKernel, int)
 }  // namespace kernel
 }  // namespace mindspore
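
Each MS_REG_GPU_KERNEL_ONE line above binds the op name and a dtype signature to one template instantiation of the kernel class, so adding int32 support is purely additive. As a rough illustration of how such a one-type registration macro can work (a hypothetical sketch; the real macro and kernel base class in MindSpore are not shown in this diff and may differ):

#include <functional>
#include <map>
#include <memory>
#include <string>

struct GpuKernel {  // stand-in for the real GPU kernel base class
  virtual ~GpuKernel() = default;
};

// Registry mapping "op/dtype-signature" keys to kernel factories.
inline std::map<std::string, std::function<std::unique_ptr<GpuKernel>()>> &Registry() {
  static std::map<std::string, std::function<std::unique_ptr<GpuKernel>()>> registry;
  return registry;
}

struct Registrar {
  Registrar(const std::string &key, std::function<std::unique_ptr<GpuKernel>()> creator) {
    Registry().emplace(key, std::move(creator));
  }
};

// REG_ONE(op, sig, KernelClass, T): register KernelClass<T> under "op/sig".
#define REG_ONE(op, sig, KernelClass, T)                                   \
  static Registrar g_##op##_##T##_reg(std::string(#op) + "/" + (sig), [] { \
    return std::unique_ptr<GpuKernel>(new KernelClass<T>());               \
  });

template <typename T>
struct TransposeGpuFwdKernel : GpuKernel {};  // stand-in for the real kernel

// One extra line per supported dtype, mirroring the registrations above:
REG_ONE(Transpose, "int32->int32", TransposeGpuFwdKernel, int)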

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu (+12, -4)

@@ -15,6 +15,7 @@
 */
 
 #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 
 template <typename T>
@@ -22,9 +23,9 @@ struct MinimumGradFunc {
   __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                              const T &dy, T *dx1, T *dx2) {
     if (grad_x1 && x1 < x2) {
-      atomicAdd(dx1, dy);
+      ms_atomic_add(dx1, dy);
     } else if (grad_x2 && x1 >= x2) {
-      atomicAdd(dx2, dy);
+      ms_atomic_add(dx2, dy);
     }
   }
 };
@@ -34,9 +35,9 @@ struct MaximumGradFunc {
   __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                              const T &dy, T *dx1, T *dx2) {
     if (grad_x1 && x1 > x2) {
-      atomicAdd(dx1, dy);
+      ms_atomic_add(dx1, dy);
     } else if (grad_x2 && x1 <= x2) {
-      atomicAdd(dx2, dy);
+      ms_atomic_add(dx2, dy);
     }
   }
 };
@@ -117,6 +118,9 @@ template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &
                               cudaStream_t stream);
 template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
                               const int *x1, const int *x2, const int *dy, int *dx1, int *dx2, cudaStream_t stream);
+template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
+                              const half *x1, const half *x2, const half *dy, half *dx1, half *dx2,
+                              cudaStream_t stream);
 template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                             const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1,
@@ -125,3 +129,7 @@ template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const i
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                             const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const int *x1,
                             const int *x2, const int *dy, int *dx1, int *dx2, cudaStream_t stream);
+template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
+                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
+                            const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const half *x1,
+                            const half *x2, const half *dy, half *dx1, half *dx2, cudaStream_t stream);
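
The atomicAdd-to-ms_atomic_add change is what makes the half instantiations above possible: CUDA exposes no universal atomicAdd overload for half (native hardware support only begins at compute capability 7.0), so a type-dispatching wrapper is the standard workaround. A minimal sketch of such a wrapper follows; the actual ms_atomic_add lives in util.cuh (included above), and this implementation is an assumption, not taken from the diff:

#include <cuda_fp16.h>

// float and int can use the native hardware atomic directly.
__device__ inline float ms_atomic_add_sketch(float *address, float val) {
  return atomicAdd(address, val);
}

__device__ inline int ms_atomic_add_sketch(int *address, int val) {
  return atomicAdd(address, val);
}

// For half, emulate a 16-bit atomic add with a 32-bit compare-and-swap
// on the aligned word containing the target value.
__device__ inline half ms_atomic_add_sketch(half *address, half val) {
  unsigned int *aligned = reinterpret_cast<unsigned int *>(
      reinterpret_cast<size_t>(address) & ~static_cast<size_t>(2));
  bool is_high = (reinterpret_cast<size_t>(address) & 2) != 0;
  unsigned int old = *aligned;
  unsigned int assumed;
  do {
    assumed = old;
    // Pull the 16 bits we own out of the 32-bit word.
    unsigned short bits = is_high ? static_cast<unsigned short>(assumed >> 16)
                                  : static_cast<unsigned short>(assumed & 0xffffu);
    // Do the add in float so the sketch does not need sm_53+ half intrinsics.
    half sum = __float2half(__half2float(__ushort_as_half(bits)) + __half2float(val));
    unsigned short sum_bits = __half_as_ushort(sum);
    unsigned int replacement =
        is_high ? ((assumed & 0x0000ffffu) | (static_cast<unsigned int>(sum_bits) << 16))
                : ((assumed & 0xffff0000u) | sum_bits);
    // Retry until no other thread changed the word between read and swap.
    old = atomicCAS(aligned, assumed, replacement);
  } while (assumed != old);
  unsigned short old_bits = is_high ? static_cast<unsigned short>(old >> 16)
                                    : static_cast<unsigned short>(old & 0xffffu);
  return __ushort_as_half(old_bits);
}

On sm_70 and newer, the half overload could instead forward to the native atomicAdd(half*, half); the CAS loop keeps the kernels usable on older GPUs.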

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu (+5, -1)

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,8 +15,10 @@
 */
 #include <cuda_runtime.h>
 #include "transpose_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
 __global__ void Transpose(const int size, const T* input, const int* input_shape, const int* input_axis,
                           const int shape_size, T* output) {
@@ -63,3 +65,5 @@ template void CalTranspose<float>(const int size, const float* input, const int*
                                   const int shape_size, float* output, cudaStream_t cuda_stream);
 template void CalTranspose<half>(const int size, const half* input, const int* input_shape, const int* input_axis,
                                  const int shape_size, half* output, cudaStream_t cuda_stream);
+template void CalTranspose<int>(const int size, const int* input, const int* input_shape, const int* input_axis,
+                                const int shape_size, int* output, cudaStream_t cuda_stream);
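
The new CalTranspose<int> lines are an explicit template instantiation, and they are required rather than cosmetic: the template's definition lives in this .cu file compiled by nvcc, while host-side callers such as transpose_gpu_kernel.cc only see a declaration in the header, so every element type a registration uses must be instantiated here or the final link fails with an unresolved CalTranspose<int> symbol. A schematic of that split (the header's exact contents are assumed for illustration):

// transpose_impl.cuh (assumed): declaration only, visible to host code.
template <typename T>
void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis,
                  const int shape_size, T* output, cudaStream_t cuda_stream);

// transpose_impl.cu: definition plus one explicit instantiation per dtype.
// Without the line below, TransposeGpuFwdKernel<int> compiles but does not link.
template void CalTranspose<int>(const int size, const int* input, const int* input_shape, const int* input_axis,
                                const int shape_size, int* output, cudaStream_t cuda_stream);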

mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc (+16, -0)

@@ -34,6 +34,22 @@ MS_REG_GPU_KERNEL_ONE(MaximumGrad,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       BroadcastOpGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(MinimumGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      BroadcastOpGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(MaximumGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      BroadcastOpGradGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(MinimumGrad,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeInt32)
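
The three AddInputAttr and two AddOutputAttr calls in each registration line up one-to-one with the data arguments of the grad launchers instantiated for half in broadcast_grad_impl.cu: both forward inputs and the incoming gradient flow in, and one gradient per forward input flows out. Annotated with the roles implied by the parameter names (the role comments are inferences, not taken from the source):

template <typename T>
void NoBroadcastGrad(const int &nums,               // element count
                     const bool &grad_x1,           // compute dx1?
                     const bool &grad_x2,           // compute dx2?
                     enum BroadcastGradOpType op,   // MinimumGrad or MaximumGrad
                     const T *x1,                   // input 0: first forward input
                     const T *x2,                   // input 1: second forward input
                     const T *dy,                   // input 2: incoming gradient
                     T *dx1,                        // output 0: gradient w.r.t. x1
                     T *dx2,                        // output 1: gradient w.r.t. x2
                     cudaStream_t stream);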


tests/st/ops/gpu/test_transpose_op.py (+29, -14)

@@ -28,25 +28,25 @@ context.set_context(device_target='GPU')
 
 
 class Transpose(nn.Cell):
-    def __init__(self):
+    def __init__(self, nptype):
         super(Transpose, self).__init__()
         self.transpose = P.Transpose()
 
-        self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.float32)), [5, 6]),
+        self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(nptype)), [5, 6]),
                               name='x_2D')
         self.perm_2D = (1, 0)
 
-        self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.float32)), [2, 2, 4]),
+        self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(nptype)), [2, 2, 4]),
                               name='x_3D')
         self.perm_3D = (1, 0, 2)
 
         self.x_4D = Parameter(
-            initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(np.float32)), [2, 3, 4, 5]),
+            initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(nptype)), [2, 3, 4, 5]),
             name='x_4D')
         self.perm_4D = (0, 1, 2, 3)
 
         self.x_5D = Parameter(
-            initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.float32)),
+            initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(nptype)),
                         [1, 2, 3, 4, 5]), name='x_5D')
         self.perm_5D = (1, 0, 3, 4, 2)
 
@@ -56,11 +56,8 @@ class Transpose(nn.Cell):
                 self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D))
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_transpose():
-    transpose = Transpose()
+def transpose1(nptype):
+    transpose = Transpose(nptype)
     output = transpose()
 
     expect0 = np.array([[[0, 6, 12, 18, 24],
@@ -68,11 +65,11 @@ def test_transpose():
                          [2, 8, 14, 20, 26],
                          [3, 9, 15, 21, 27],
                          [4, 10, 16, 22, 28],
-                         [5, 11, 17, 23, 29]]]).astype(np.float32)
+                         [5, 11, 17, 23, 29]]]).astype(nptype)
     expect1 = np.array([[[[0, 1, 2, 3],
                           [8, 9, 10, 11]],
                          [[4, 5, 6, 7],
-                          [12, 13, 14, 15]]]]).astype(np.float32)
+                          [12, 13, 14, 15]]]]).astype(nptype)
     expect2 = np.array([[[[[0, 1, 2, 3, 4],
                            [5, 6, 7, 8, 9],
                            [10, 11, 12, 13, 14],
@@ -97,7 +94,7 @@ def test_transpose():
                           [[100, 101, 102, 103, 104],
                            [105, 106, 107, 108, 109],
                            [110, 111, 112, 113, 114],
-                           [115, 116, 117, 118, 119]]]]]).astype(np.float32)
+                           [115, 116, 117, 118, 119]]]]]).astype(nptype)
     expect3 = np.array([[[[[[0, 20, 40],
                             [1, 21, 41],
                             [2, 22, 42],
@@ -138,8 +135,26 @@ def test_transpose():
                             [76, 96, 116],
                             [77, 97, 117],
                             [78, 98, 118],
-                            [79, 99, 119]]]]]]).astype(np.float32)
+                            [79, 99, 119]]]]]]).astype(nptype)
     assert (output[0].asnumpy() == expect0).all()
     assert (output[1].asnumpy() == expect1).all()
     assert (output[2].asnumpy() == expect2).all()
     assert (output[3].asnumpy() == expect3).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_float32():
+    transpose1(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_float16():
+    transpose1(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_int32():
+    transpose1(np.int32)
