Merge pull request !4251 from Peilin/efficientnet (tags/v0.7.0-beta)
@@ -21,5 +21,7 @@ MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).A
                       TransposeGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       TransposeGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      TransposeGpuFwdKernel, int)
 }  // namespace kernel
 }  // namespace mindspore
@@ -15,6 +15,7 @@
  */
 #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
@@ -22,9 +23,9 @@ struct MinimumGradFunc {
   __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                              const T &dy, T *dx1, T *dx2) {
     if (grad_x1 && x1 < x2) {
-      atomicAdd(dx1, dy);
+      ms_atomic_add(dx1, dy);
     } else if (grad_x2 && x1 >= x2) {
-      atomicAdd(dx2, dy);
+      ms_atomic_add(dx2, dy);
     }
   }
 };
@@ -34,9 +35,9 @@ struct MaximumGradFunc {
   __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                              const T &dy, T *dx1, T *dx2) {
     if (grad_x1 && x1 > x2) {
-      atomicAdd(dx1, dy);
+      ms_atomic_add(dx1, dy);
     } else if (grad_x2 && x1 <= x2) {
-      atomicAdd(dx2, dy);
+      ms_atomic_add(dx2, dy);
     }
   }
 };
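Note: replacing atomicAdd with ms_atomic_add in the two functors above is what allows the same code to compile for float, int, and the newly added half type once util.cuh is included. The snippet below is only a hedged sketch of the kind of overload set a header like util.cuh could provide for this; it is not the MindSpore source, and the half overload shown relies on the native half atomicAdd that exists from compute capability 7.0 onward.

// Hedged sketch (assumption, not the MindSpore implementation): plausible ms_atomic_add overloads.
#include <cuda_fp16.h>

__device__ __forceinline__ float ms_atomic_add(float *address, float val) {
  return atomicAdd(address, val);  // native float atomicAdd
}

__device__ __forceinline__ int ms_atomic_add(int *address, int val) {
  return atomicAdd(address, val);  // native int atomicAdd
}

__device__ __forceinline__ half ms_atomic_add(half *address, half val) {
  // Native half atomicAdd is available on sm_70 and newer (with cuda_fp16.h);
  // older architectures would need a CAS-based emulation, omitted here.
  return atomicAdd(address, val);
}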
@@ -117,6 +118,9 @@ template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &
                               cudaStream_t stream);
 template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
                               const int *x1, const int *x2, const int *dy, int *dx1, int *dx2, cudaStream_t stream);
+template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
+                              const half *x1, const half *x2, const half *dy, half *dx1, half *dx2,
+                              cudaStream_t stream);
 template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                             const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1,
@@ -125,3 +129,7 @@ template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const i
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                             const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const int *x1,
                             const int *x2, const int *dy, int *dx1, int *dx2, cudaStream_t stream);
+template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
+                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
+                            const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const half *x1,
+                            const half *x2, const half *dy, half *dx1, half *dx2, cudaStream_t stream);
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,8 +15,10 @@
  */
 #include <cuda_runtime.h>
 #include "transpose_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
 template <typename T>
 __global__ void Transpose(const int size, const T* input, const int* input_shape, const int* input_axis,
                           const int shape_size, T* output) {
@@ -63,3 +65,5 @@ template void CalTranspose<float>(const int size, const float* input, const int*
                                   const int shape_size, float* output, cudaStream_t cuda_stream);
 template void CalTranspose<half>(const int size, const half* input, const int* input_shape, const int* input_axis,
                                  const int shape_size, half* output, cudaStream_t cuda_stream);
+template void CalTranspose<int>(const int size, const int* input, const int* input_shape, const int* input_axis,
+                                const int shape_size, int* output, cudaStream_t cuda_stream);
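For context, a minimal host-side sketch of driving the new CalTranspose<int> instantiation is shown below, assuming transpose_impl.cuh declares CalTranspose with the signature used in the instantiations above. It is an illustration only, not the MindSpore TransposeGpuFwdKernel wrapper, and the expected output assumes the kernel performs a standard axis permutation.

// Hedged sketch: transpose a 2x3 int matrix with perm (1, 0) via CalTranspose<int>.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>
#include "transpose_impl.cuh"  // assumed to declare CalTranspose<T>(...)

int main() {
  const int size = 6, shape_size = 2;
  std::vector<int> h_in = {0, 1, 2, 3, 4, 5};    // [[0, 1, 2], [3, 4, 5]]
  std::vector<int> shape = {2, 3}, axis = {1, 0};

  int *d_in, *d_out, *d_shape, *d_axis;
  cudaMalloc(&d_in, size * sizeof(int));
  cudaMalloc(&d_out, size * sizeof(int));
  cudaMalloc(&d_shape, shape_size * sizeof(int));
  cudaMalloc(&d_axis, shape_size * sizeof(int));
  cudaMemcpy(d_in, h_in.data(), size * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_shape, shape.data(), shape_size * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_axis, axis.data(), shape_size * sizeof(int), cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  CalTranspose<int>(size, d_in, d_shape, d_axis, shape_size, d_out, stream);
  cudaStreamSynchronize(stream);

  std::vector<int> h_out(size);
  cudaMemcpy(h_out.data(), d_out, size * sizeof(int), cudaMemcpyDeviceToHost);
  for (int v : h_out) printf("%d ", v);  // expected: 0 3 1 4 2 5 (the 3x2 transpose)
  printf("\n");

  cudaFree(d_in); cudaFree(d_out); cudaFree(d_shape); cudaFree(d_axis);
  cudaStreamDestroy(stream);
  return 0;
}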
@@ -34,6 +34,22 @@ MS_REG_GPU_KERNEL_ONE(MaximumGrad,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       BroadcastOpGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(MinimumGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      BroadcastOpGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(MaximumGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      BroadcastOpGradGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(MinimumGrad,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeInt32)
@@ -28,25 +28,25 @@ context.set_context(device_target='GPU')
 class Transpose(nn.Cell):
-    def __init__(self):
+    def __init__(self, nptype):
         super(Transpose, self).__init__()
         self.transpose = P.Transpose()

-        self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.float32)), [5, 6]),
+        self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(nptype)), [5, 6]),
                               name='x_2D')
         self.perm_2D = (1, 0)

-        self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.float32)), [2, 2, 4]),
+        self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(nptype)), [2, 2, 4]),
                               name='x_3D')
         self.perm_3D = (1, 0, 2)

         self.x_4D = Parameter(
-            initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(np.float32)), [2, 3, 4, 5]),
+            initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(nptype)), [2, 3, 4, 5]),
             name='x_4D')
         self.perm_4D = (0, 1, 2, 3)

         self.x_5D = Parameter(
-            initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.float32)),
+            initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(nptype)),
                         [1, 2, 3, 4, 5]), name='x_5D')
         self.perm_5D = (1, 0, 3, 4, 2)
@@ -56,11 +56,8 @@ class Transpose(nn.Cell):
                 self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D))


-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_transpose():
-    transpose = Transpose()
+def transpose1(nptype):
+    transpose = Transpose(nptype)
     output = transpose()

     expect0 = np.array([[[0, 6, 12, 18, 24],
@@ -68,11 +65,11 @@ def test_transpose():
                          [2, 8, 14, 20, 26],
                          [3, 9, 15, 21, 27],
                          [4, 10, 16, 22, 28],
-                         [5, 11, 17, 23, 29]]]).astype(np.float32)
+                         [5, 11, 17, 23, 29]]]).astype(nptype)
     expect1 = np.array([[[[0, 1, 2, 3],
                           [8, 9, 10, 11]],
                          [[4, 5, 6, 7],
-                          [12, 13, 14, 15]]]]).astype(np.float32)
+                          [12, 13, 14, 15]]]]).astype(nptype)
     expect2 = np.array([[[[[0, 1, 2, 3, 4],
                            [5, 6, 7, 8, 9],
                            [10, 11, 12, 13, 14],
@@ -97,7 +94,7 @@ def test_transpose():
                           [[100, 101, 102, 103, 104],
                            [105, 106, 107, 108, 109],
                            [110, 111, 112, 113, 114],
-                           [115, 116, 117, 118, 119]]]]]).astype(np.float32)
+                           [115, 116, 117, 118, 119]]]]]).astype(nptype)
     expect3 = np.array([[[[[[0, 20, 40],
                             [1, 21, 41],
                             [2, 22, 42],
@@ -138,8 +135,26 @@ def test_transpose():
                             [76, 96, 116],
                             [77, 97, 117],
                             [78, 98, 118],
-                            [79, 99, 119]]]]]]).astype(np.float32)
+                            [79, 99, 119]]]]]]).astype(nptype)
     assert (output[0].asnumpy() == expect0).all()
     assert (output[1].asnumpy() == expect1).all()
     assert (output[2].asnumpy() == expect2).all()
     assert (output[3].asnumpy() == expect3).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_float32():
+    transpose1(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_float16():
+    transpose1(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_int32():
+    transpose1(np.int32)