diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu index 466bdfd25d..a9dbde7c42 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu @@ -22,7 +22,7 @@ template __global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height, const int old_width, const int padded_height, const int padded_width, const int pad_top, - const int pad_left, float pad_value, T* output) { + const int pad_left, const float pad_value, T* output) { T pad_value_ = static_cast(pad_value); for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { int block_num = pos / padded_width / padded_height; @@ -64,8 +64,7 @@ template __global__ void PadGeneral(const size_t size, const T *input, const int num, const int channels_orig, const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, const int padded_height, const int padded_width, const int pad_top, - const int pad_left, float pad_value, T *output) { - T pad_value_template = static_cast(pad_value); + const int pad_left, const T pad_value, T *output) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { int block_num = (pos / padded_width) / padded_height; // total blocks = (batch * channels) const int padded_w = pos % padded_width; // x coordinate refered to by cur 'pos' @@ -79,7 +78,7 @@ __global__ void PadGeneral(const size_t size, const T *input, const int num, con if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height || padded_w - pad_left >= old_width || channel_num <= pad_channel_before - 1 || channel_num > channels_orig + pad_channel_before - 1) { - output[pos] = pad_value_template; + output[pos] = pad_value; } else { // on a block/x,y positon that isn't padding, copy data from the correct block/x,y pos the input // calculate from number of blocks of padding (due to channel padding) inserted prior @@ -139,7 +138,7 @@ template void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig, const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, - float pad_value, T *output, cudaStream_t cuda_stream) { + const T pad_value, T *output, cudaStream_t cuda_stream) { PadGeneral<<>>(size, input, num, channels_orig, pad_channel_before, pad_channel_after, old_height, old_width, padded_height, padded_width, pad_top, pad_left, pad_value, output); @@ -199,15 +198,15 @@ template void CalPadGradNHWC(const size_t size, const half* dy, const int template void CalPadGeneral(const size_t size, const float *input, const int num, const int channels_orig, const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, const int padded_height, const int padded_width, - const int pad_top, const int pad_left, float pad_value, float *output, + const int pad_top, const int pad_left, const float pad_value, float *output, cudaStream_t cuda_stream); template void CalPadGeneral(const size_t size, const half *input, const int num, const int channels_orig, const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, const int padded_height, const int padded_width, - const int pad_top, const int pad_left, float pad_value, half *output, + const int pad_top, const int pad_left, const half pad_value, half *output, cudaStream_t cuda_stream); template void CalPadGeneral(const size_t size, const int *input, const int num, const int channels_orig, const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, const int padded_height, const int padded_width, - const int pad_top, const int pad_left, float pad_value, int *output, + const int pad_top, const int pad_left, const int pad_value, int *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh index 3c12ab5376..f387779b5c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh @@ -39,5 +39,5 @@ template void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig, const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, - float pad_value, T *output, cudaStream_t cuda_stream); + const T pad_value, T *output, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h index 84bd46a4ff..9ad4ee3195 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h @@ -45,7 +45,7 @@ class PadGpuFwdKernel : public GpuKernel { int pad_top = paddings[2][0]; int pad_channel_before = paddings[1][0]; int pad_channel_after = paddings[1][1]; - T pad_value = 0.0; + const T pad_value = 0.0; CalPadGeneral(size, input, input_shape_[0], input_shape_[1], pad_channel_before, pad_channel_after, input_shape_[2], input_shape_[3], output_shape_[2], output_shape_[3], pad_top, pad_left, pad_value, output, reinterpret_cast(stream_ptr)); diff --git a/tests/st/ops/gpu/test_pad.py b/tests/st/ops/gpu/test_pad.py index c0d9cf6d9c..4a756f6110 100644 --- a/tests/st/ops/gpu/test_pad.py +++ b/tests/st/ops/gpu/test_pad.py @@ -28,17 +28,27 @@ from mindspore.ops.composite import GradOperation @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_pad_basic(): - # confirm array is being padded with 0's + """ + Test array is being padded with 0's + """ context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + # float32 test_arr = np.array([[1, 2], [3, 4]]).astype(np.float32) test_arr_expected = np.array( [[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]).astype(np.float32) x_test = Tensor(test_arr, dtype=mindspore.float32) - pad_op = nn.Pad(mode='CONSTANT', paddings=((1, 1), (1, 1))) y_test = pad_op(x_test).asnumpy() + np.testing.assert_array_equal(y_test, test_arr_expected) + # float16 + test_arr = np.array([[1, 2], [3, 4]]).astype(np.float16) + test_arr_expected = np.array( + [[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]).astype(np.float16) + x_test = Tensor(test_arr, dtype=mindspore.float16) + pad_op = nn.Pad(mode='CONSTANT', paddings=((1, 1), (1, 1))) + y_test = pad_op(x_test).asnumpy() np.testing.assert_array_equal(y_test, test_arr_expected) @@ -46,12 +56,13 @@ def test_pad_basic(): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_pad_row(): - # Confirm correct row padding + """ + Test correct row padding + """ context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") test_arr_1 = np.random.rand(40, 40).astype(np.float32) test_paddings_1 = ((2, 3), (0, 0)) - test_arr_2 = np.random.randn(3, 10, 30, 30).astype(np.float32) test_paddings_2 = ((0, 0), (0, 0), (3, 0), (0, 0)) @@ -60,7 +71,6 @@ def test_pad_row(): x_test_1 = Tensor(np.array(test_arr_1), dtype=mindspore.float32) x_test_2 = Tensor(np.array(test_arr_2), dtype=mindspore.float32) - y_test_1 = pad_op_row_1(x_test_1).asnumpy() y_test_2 = pad_op_row_2(x_test_2).asnumpy() @@ -77,12 +87,13 @@ def test_pad_row(): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_pad_column(): - # Confirm correct column padding + """ + Test correct column padding + """ context.set_context(mode=context.GRAPH_MODE, device_target="GPU") test_arr_1 = np.random.randn(40, 40).astype(np.float32) test_paddings_1 = ((0, 0), (3, 3)) - test_arr_2 = np.random.randn(3, 10, 30, 30).astype(np.float32) test_paddings_2 = ((0, 0), (0, 0), (0, 0), (6, 1)) @@ -91,7 +102,6 @@ def test_pad_column(): x_test_1 = Tensor(np.array(test_arr_1), dtype=mindspore.float32) x_test_2 = Tensor(np.array(test_arr_2), dtype=mindspore.float32) - y_test_1 = pad_op_col_1(x_test_1).asnumpy() y_test_2 = pad_op_col_2(x_test_2).asnumpy() @@ -108,15 +118,34 @@ def test_pad_column(): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_pad_3d_pad(): - # Confirm correct 3d padding - row, column, channel - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + """ + Test full 3d padding, with all 3 input types + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + # float32 test_arr = np.random.randn(5, 3, 30, 30).astype(np.float32) test_paddings = ((0, 0), (2, 1), (0, 1), (0, 2)) # padding 3 dims now - pad_op_3d = nn.Pad(mode='CONSTANT', paddings=test_paddings) x_test = Tensor(np.array(test_arr), dtype=mindspore.float32) + y_test = pad_op_3d(x_test).asnumpy() + assert y_test.shape == (5, 6, 31, 32) + np.testing.assert_equal(test_arr, y_test[:, 2:-1, :-1, :-2]) + # float16 + test_arr = np.random.randn(5, 3, 30, 30).astype(np.float16) + test_paddings = ((0, 0), (2, 1), (0, 1), (0, 2)) + pad_op_3d = nn.Pad(mode='CONSTANT', paddings=test_paddings) + x_test = Tensor(np.array(test_arr), dtype=mindspore.float16) + y_test = pad_op_3d(x_test).asnumpy() + assert y_test.shape == (5, 6, 31, 32) + np.testing.assert_equal(test_arr, y_test[:, 2:-1, :-1, :-2]) + + # int32 + test_arr = np.random.randint(1, 3000, (5, 3, 30, 30)).astype(np.int32) + test_paddings = ((0, 0), (2, 1), (0, 1), (0, 2)) + pad_op_3d = nn.Pad(mode='CONSTANT', paddings=test_paddings) + x_test = Tensor(np.array(test_arr), dtype=mindspore.int32) y_test = pad_op_3d(x_test).asnumpy() assert y_test.shape == (5, 6, 31, 32) np.testing.assert_equal(test_arr, y_test[:, 2:-1, :-1, :-2]) @@ -147,17 +176,36 @@ class Net(nn.Cell): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_pad_3d_backprop(): - # Confirm correct 3d padding backprop + """ + Confirm correct 3d padding backprop + """ context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + net = Grad(Net()) + padded_shape = (5, 10, 32, 32) + # float32 test_arr = np.random.randn(5, 3, 30, 30).astype(np.float32) x_test = Tensor(test_arr, dtype=mindspore.float32) - - padded_shape = (5, 10, 32, 32) dy = np.random.randn(*padded_shape).astype(np.float32) expected_dx = dy[:, 4:-3, 1:-1, :-2] + dx = net(x_test, Tensor(dy)) + dx = dx[0].asnumpy() + np.testing.assert_array_equal(dx, expected_dx) - net = Grad(Net()) + # float16 + test_arr = np.random.randn(5, 3, 30, 30).astype(np.float16) + x_test = Tensor(test_arr, dtype=mindspore.float16) + dy = np.random.randn(*padded_shape).astype(np.float16) + expected_dx = dy[:, 4:-3, 1:-1, :-2] + dx = net(x_test, Tensor(dy)) + dx = dx[0].asnumpy() + np.testing.assert_array_equal(dx, expected_dx) + + # int32 + test_arr = np.random.randint(1, 3000, (5, 3, 30, 30)).astype(np.int32) + x_test = Tensor(test_arr, dtype=mindspore.int32) + dy = np.random.randn(*padded_shape).astype(np.int32) + expected_dx = dy[:, 4:-3, 1:-1, :-2] dx = net(x_test, Tensor(dy)) dx = dx[0].asnumpy() np.testing.assert_array_equal(dx, expected_dx) @@ -167,7 +215,9 @@ def test_pad_3d_backprop(): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_pad_error_cases(): - # Test against common errorneous inputs to catch correctly + """ + Test against common errorneous inputs to trigger correct errors + """ context.set_context(mode=context.GRAPH_MODE, device_target="GPU") # TEST 1 - Neg padding values