Browse Source

!10090 [MS][GPU][CI Alarm] Fixed CI Alarm for Pad GPU kernel + code clean up

From: @danishnxt
Reviewed-by: @robingrosman
Signed-off-by: @robingrosman
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
4b7e9975a0
4 changed files with 75 additions and 26 deletions
  1. +7
    -8
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh
  3. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h
  4. +66
    -16
      tests/st/ops/gpu/test_pad.py

+ 7
- 8
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu View File

@@ -22,7 +22,7 @@
template <typename T> template <typename T>
__global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height, __global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height,
const int old_width, const int padded_height, const int padded_width, const int pad_top, const int old_width, const int padded_height, const int padded_width, const int pad_top,
const int pad_left, float pad_value, T* output) {
const int pad_left, const float pad_value, T* output) {
T pad_value_ = static_cast<T>(pad_value); T pad_value_ = static_cast<T>(pad_value);
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
int block_num = pos / padded_width / padded_height; int block_num = pos / padded_width / padded_height;
@@ -64,8 +64,7 @@ template <typename T>
__global__ void PadGeneral(const size_t size, const T *input, const int num, const int channels_orig, __global__ void PadGeneral(const size_t size, const T *input, const int num, const int channels_orig,
const int pad_channel_before, const int pad_channel_after, const int old_height, const int pad_channel_before, const int pad_channel_after, const int old_height,
const int old_width, const int padded_height, const int padded_width, const int pad_top, const int old_width, const int padded_height, const int padded_width, const int pad_top,
const int pad_left, float pad_value, T *output) {
T pad_value_template = static_cast<T>(pad_value);
const int pad_left, const T pad_value, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
int block_num = (pos / padded_width) / padded_height; // total blocks = (batch * channels) int block_num = (pos / padded_width) / padded_height; // total blocks = (batch * channels)
const int padded_w = pos % padded_width; // x coordinate refered to by cur 'pos' const int padded_w = pos % padded_width; // x coordinate refered to by cur 'pos'
@@ -79,7 +78,7 @@ __global__ void PadGeneral(const size_t size, const T *input, const int num, con
if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height || if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height ||
padded_w - pad_left >= old_width || channel_num <= pad_channel_before - 1 || padded_w - pad_left >= old_width || channel_num <= pad_channel_before - 1 ||
channel_num > channels_orig + pad_channel_before - 1) { channel_num > channels_orig + pad_channel_before - 1) {
output[pos] = pad_value_template;
output[pos] = pad_value;
} else { } else {
// on a block/x,y positon that isn't padding, copy data from the correct block/x,y pos the input // on a block/x,y positon that isn't padding, copy data from the correct block/x,y pos the input
// calculate from number of blocks of padding (due to channel padding) inserted prior // calculate from number of blocks of padding (due to channel padding) inserted prior
@@ -139,7 +138,7 @@ template <typename T>
void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig, void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig,
const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width,
const int padded_height, const int padded_width, const int pad_top, const int pad_left, const int padded_height, const int padded_width, const int pad_top, const int pad_left,
float pad_value, T *output, cudaStream_t cuda_stream) {
const T pad_value, T *output, cudaStream_t cuda_stream) {
PadGeneral<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, channels_orig, pad_channel_before, PadGeneral<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, channels_orig, pad_channel_before,
pad_channel_after, old_height, old_width, padded_height, pad_channel_after, old_height, old_width, padded_height,
padded_width, pad_top, pad_left, pad_value, output); padded_width, pad_top, pad_left, pad_value, output);
@@ -199,15 +198,15 @@ template void CalPadGradNHWC<half>(const size_t size, const half* dy, const int
template void CalPadGeneral<float>(const size_t size, const float *input, const int num, const int channels_orig, template void CalPadGeneral<float>(const size_t size, const float *input, const int num, const int channels_orig,
const int pad_channel_before, const int pad_channel_after, const int old_height, const int pad_channel_before, const int pad_channel_after, const int old_height,
const int old_width, const int padded_height, const int padded_width, const int old_width, const int padded_height, const int padded_width,
const int pad_top, const int pad_left, float pad_value, float *output,
const int pad_top, const int pad_left, const float pad_value, float *output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void CalPadGeneral<half>(const size_t size, const half *input, const int num, const int channels_orig, template void CalPadGeneral<half>(const size_t size, const half *input, const int num, const int channels_orig,
const int pad_channel_before, const int pad_channel_after, const int old_height, const int pad_channel_before, const int pad_channel_after, const int old_height,
const int old_width, const int padded_height, const int padded_width, const int old_width, const int padded_height, const int padded_width,
const int pad_top, const int pad_left, float pad_value, half *output,
const int pad_top, const int pad_left, const half pad_value, half *output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void CalPadGeneral<int>(const size_t size, const int *input, const int num, const int channels_orig, template void CalPadGeneral<int>(const size_t size, const int *input, const int num, const int channels_orig,
const int pad_channel_before, const int pad_channel_after, const int old_height, const int pad_channel_before, const int pad_channel_after, const int old_height,
const int old_width, const int padded_height, const int padded_width, const int old_width, const int padded_height, const int padded_width,
const int pad_top, const int pad_left, float pad_value, int *output,
const int pad_top, const int pad_left, const int pad_value, int *output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh View File

@@ -39,5 +39,5 @@ template <typename T>
void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig, void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig,
const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width,
const int padded_height, const int padded_width, const int pad_top, const int pad_left, const int padded_height, const int padded_width, const int pad_top, const int pad_left,
float pad_value, T *output, cudaStream_t cuda_stream);
const T pad_value, T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h View File

@@ -45,7 +45,7 @@ class PadGpuFwdKernel : public GpuKernel {
int pad_top = paddings[2][0]; int pad_top = paddings[2][0];
int pad_channel_before = paddings[1][0]; int pad_channel_before = paddings[1][0];
int pad_channel_after = paddings[1][1]; int pad_channel_after = paddings[1][1];
T pad_value = 0.0;
const T pad_value = 0.0;
CalPadGeneral(size, input, input_shape_[0], input_shape_[1], pad_channel_before, pad_channel_after, input_shape_[2], CalPadGeneral(size, input, input_shape_[0], input_shape_[1], pad_channel_before, pad_channel_after, input_shape_[2],
input_shape_[3], output_shape_[2], output_shape_[3], pad_top, pad_left, pad_value, output, input_shape_[3], output_shape_[2], output_shape_[3], pad_top, pad_left, pad_value, output,
reinterpret_cast<cudaStream_t>(stream_ptr)); reinterpret_cast<cudaStream_t>(stream_ptr));


+ 66
- 16
tests/st/ops/gpu/test_pad.py View File

@@ -28,17 +28,27 @@ from mindspore.ops.composite import GradOperation
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_pad_basic(): def test_pad_basic():
# confirm array is being padded with 0's
"""
Test array is being padded with 0's
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


# float32
test_arr = np.array([[1, 2], [3, 4]]).astype(np.float32) test_arr = np.array([[1, 2], [3, 4]]).astype(np.float32)
test_arr_expected = np.array( test_arr_expected = np.array(
[[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]).astype(np.float32) [[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]).astype(np.float32)
x_test = Tensor(test_arr, dtype=mindspore.float32) x_test = Tensor(test_arr, dtype=mindspore.float32)

pad_op = nn.Pad(mode='CONSTANT', paddings=((1, 1), (1, 1))) pad_op = nn.Pad(mode='CONSTANT', paddings=((1, 1), (1, 1)))
y_test = pad_op(x_test).asnumpy() y_test = pad_op(x_test).asnumpy()
np.testing.assert_array_equal(y_test, test_arr_expected)


# float16
test_arr = np.array([[1, 2], [3, 4]]).astype(np.float16)
test_arr_expected = np.array(
[[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]).astype(np.float16)
x_test = Tensor(test_arr, dtype=mindspore.float16)
pad_op = nn.Pad(mode='CONSTANT', paddings=((1, 1), (1, 1)))
y_test = pad_op(x_test).asnumpy()
np.testing.assert_array_equal(y_test, test_arr_expected) np.testing.assert_array_equal(y_test, test_arr_expected)




@@ -46,12 +56,13 @@ def test_pad_basic():
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_pad_row(): def test_pad_row():
# Confirm correct row padding
"""
Test correct row padding
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")


test_arr_1 = np.random.rand(40, 40).astype(np.float32) test_arr_1 = np.random.rand(40, 40).astype(np.float32)
test_paddings_1 = ((2, 3), (0, 0)) test_paddings_1 = ((2, 3), (0, 0))

test_arr_2 = np.random.randn(3, 10, 30, 30).astype(np.float32) test_arr_2 = np.random.randn(3, 10, 30, 30).astype(np.float32)
test_paddings_2 = ((0, 0), (0, 0), (3, 0), (0, 0)) test_paddings_2 = ((0, 0), (0, 0), (3, 0), (0, 0))


@@ -60,7 +71,6 @@ def test_pad_row():


x_test_1 = Tensor(np.array(test_arr_1), dtype=mindspore.float32) x_test_1 = Tensor(np.array(test_arr_1), dtype=mindspore.float32)
x_test_2 = Tensor(np.array(test_arr_2), dtype=mindspore.float32) x_test_2 = Tensor(np.array(test_arr_2), dtype=mindspore.float32)

y_test_1 = pad_op_row_1(x_test_1).asnumpy() y_test_1 = pad_op_row_1(x_test_1).asnumpy()
y_test_2 = pad_op_row_2(x_test_2).asnumpy() y_test_2 = pad_op_row_2(x_test_2).asnumpy()


@@ -77,12 +87,13 @@ def test_pad_row():
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_pad_column(): def test_pad_column():
# Confirm correct column padding
"""
Test correct column padding
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


test_arr_1 = np.random.randn(40, 40).astype(np.float32) test_arr_1 = np.random.randn(40, 40).astype(np.float32)
test_paddings_1 = ((0, 0), (3, 3)) test_paddings_1 = ((0, 0), (3, 3))

test_arr_2 = np.random.randn(3, 10, 30, 30).astype(np.float32) test_arr_2 = np.random.randn(3, 10, 30, 30).astype(np.float32)
test_paddings_2 = ((0, 0), (0, 0), (0, 0), (6, 1)) test_paddings_2 = ((0, 0), (0, 0), (0, 0), (6, 1))


@@ -91,7 +102,6 @@ def test_pad_column():


x_test_1 = Tensor(np.array(test_arr_1), dtype=mindspore.float32) x_test_1 = Tensor(np.array(test_arr_1), dtype=mindspore.float32)
x_test_2 = Tensor(np.array(test_arr_2), dtype=mindspore.float32) x_test_2 = Tensor(np.array(test_arr_2), dtype=mindspore.float32)

y_test_1 = pad_op_col_1(x_test_1).asnumpy() y_test_1 = pad_op_col_1(x_test_1).asnumpy()
y_test_2 = pad_op_col_2(x_test_2).asnumpy() y_test_2 = pad_op_col_2(x_test_2).asnumpy()


@@ -108,15 +118,34 @@ def test_pad_column():
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_pad_3d_pad(): def test_pad_3d_pad():
# Confirm correct 3d padding - row, column, channel
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
"""
Test full 3d padding, with all 3 input types
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


# float32
test_arr = np.random.randn(5, 3, 30, 30).astype(np.float32) test_arr = np.random.randn(5, 3, 30, 30).astype(np.float32)
test_paddings = ((0, 0), (2, 1), (0, 1), (0, 2)) # padding 3 dims now test_paddings = ((0, 0), (2, 1), (0, 1), (0, 2)) # padding 3 dims now

pad_op_3d = nn.Pad(mode='CONSTANT', paddings=test_paddings) pad_op_3d = nn.Pad(mode='CONSTANT', paddings=test_paddings)
x_test = Tensor(np.array(test_arr), dtype=mindspore.float32) x_test = Tensor(np.array(test_arr), dtype=mindspore.float32)
y_test = pad_op_3d(x_test).asnumpy()
assert y_test.shape == (5, 6, 31, 32)
np.testing.assert_equal(test_arr, y_test[:, 2:-1, :-1, :-2])


# float16
test_arr = np.random.randn(5, 3, 30, 30).astype(np.float16)
test_paddings = ((0, 0), (2, 1), (0, 1), (0, 2))
pad_op_3d = nn.Pad(mode='CONSTANT', paddings=test_paddings)
x_test = Tensor(np.array(test_arr), dtype=mindspore.float16)
y_test = pad_op_3d(x_test).asnumpy()
assert y_test.shape == (5, 6, 31, 32)
np.testing.assert_equal(test_arr, y_test[:, 2:-1, :-1, :-2])

# int32
test_arr = np.random.randint(1, 3000, (5, 3, 30, 30)).astype(np.int32)
test_paddings = ((0, 0), (2, 1), (0, 1), (0, 2))
pad_op_3d = nn.Pad(mode='CONSTANT', paddings=test_paddings)
x_test = Tensor(np.array(test_arr), dtype=mindspore.int32)
y_test = pad_op_3d(x_test).asnumpy() y_test = pad_op_3d(x_test).asnumpy()
assert y_test.shape == (5, 6, 31, 32) assert y_test.shape == (5, 6, 31, 32)
np.testing.assert_equal(test_arr, y_test[:, 2:-1, :-1, :-2]) np.testing.assert_equal(test_arr, y_test[:, 2:-1, :-1, :-2])
@@ -147,17 +176,36 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_pad_3d_backprop(): def test_pad_3d_backprop():
# Confirm correct 3d padding backprop
"""
Confirm correct 3d padding backprop
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = Grad(Net())
padded_shape = (5, 10, 32, 32)


# float32
test_arr = np.random.randn(5, 3, 30, 30).astype(np.float32) test_arr = np.random.randn(5, 3, 30, 30).astype(np.float32)
x_test = Tensor(test_arr, dtype=mindspore.float32) x_test = Tensor(test_arr, dtype=mindspore.float32)

padded_shape = (5, 10, 32, 32)
dy = np.random.randn(*padded_shape).astype(np.float32) dy = np.random.randn(*padded_shape).astype(np.float32)
expected_dx = dy[:, 4:-3, 1:-1, :-2] expected_dx = dy[:, 4:-3, 1:-1, :-2]
dx = net(x_test, Tensor(dy))
dx = dx[0].asnumpy()
np.testing.assert_array_equal(dx, expected_dx)


net = Grad(Net())
# float16
test_arr = np.random.randn(5, 3, 30, 30).astype(np.float16)
x_test = Tensor(test_arr, dtype=mindspore.float16)
dy = np.random.randn(*padded_shape).astype(np.float16)
expected_dx = dy[:, 4:-3, 1:-1, :-2]
dx = net(x_test, Tensor(dy))
dx = dx[0].asnumpy()
np.testing.assert_array_equal(dx, expected_dx)

# int32
test_arr = np.random.randint(1, 3000, (5, 3, 30, 30)).astype(np.int32)
x_test = Tensor(test_arr, dtype=mindspore.int32)
dy = np.random.randn(*padded_shape).astype(np.int32)
expected_dx = dy[:, 4:-3, 1:-1, :-2]
dx = net(x_test, Tensor(dy)) dx = net(x_test, Tensor(dy))
dx = dx[0].asnumpy() dx = dx[0].asnumpy()
np.testing.assert_array_equal(dx, expected_dx) np.testing.assert_array_equal(dx, expected_dx)
@@ -167,7 +215,9 @@ def test_pad_3d_backprop():
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_pad_error_cases(): def test_pad_error_cases():
# Test against common errorneous inputs to catch correctly
"""
Test against common errorneous inputs to trigger correct errors
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


# TEST 1 - Neg padding values # TEST 1 - Neg padding values


Loading…
Cancel
Save