style fix lint fixes added check in NN layer for > 4 paddings, plus lint fix fix python lint lint fix lint fix updating to pytest asserts to improve testing removed unnecc vars from test file fail checkstags/v1.0.0
| @@ -18,6 +18,7 @@ | |||||
| #include <stdint.h> | #include <stdint.h> | ||||
| #include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" | #include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" | ||||
| // For internal OP use, not user facing | |||||
| template <typename T> | template <typename T> | ||||
| __global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height, | __global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height, | ||||
| const int old_width, const int padded_height, const int padded_width, const int pad_top, | const int old_width, const int padded_height, const int padded_width, const int pad_top, | ||||
| @@ -37,6 +38,7 @@ __global__ void Pad(const size_t size, const T* input, const int num, const int | |||||
| return; | return; | ||||
| } | } | ||||
| // For internal OP use, not user facing | |||||
| template <typename T> | template <typename T> | ||||
| __global__ void PadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width, | __global__ void PadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width, | ||||
| const int channels, const int padded_height, const int padded_width, const int pad_top, | const int channels, const int padded_height, const int padded_width, const int pad_top, | ||||
| @@ -57,6 +59,37 @@ __global__ void PadNHWC(const size_t size, const T* input, const int num, const | |||||
| return; | return; | ||||
| } | } | ||||
| // Used by user facing 'Pad' API | |||||
| template <typename T> | |||||
| __global__ void PadGeneral(const size_t size, const T *input, const int num, const int channels_orig, | |||||
| const int pad_channel_before, const int pad_channel_after, const int old_height, | |||||
| const int old_width, const int padded_height, const int padded_width, const int pad_top, | |||||
| const int pad_left, float pad_value, T *output) { | |||||
| T pad_value_template = static_cast<T>(pad_value); | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { | |||||
| int block_num = (pos / padded_width) / padded_height; // total blocks = (batch * channels) | |||||
| const int padded_w = pos % padded_width; // x coordinate refered to by cur 'pos' | |||||
| const int padded_h = (pos / padded_width) % padded_height; // y coordinate refered to by cur 'pos' | |||||
| int channels_new = channels_orig + pad_channel_after + pad_channel_before; // new number of channels from padding | |||||
| int channel_num = block_num % channels_new; // current channel | |||||
| int batch_item = block_num / channels_new; // current item in batch | |||||
| int equiv_block_num = 0; // init variable to select equivalent block to copy data from from input | |||||
| if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height || | |||||
| padded_w - pad_left >= old_width || channel_num <= pad_channel_before - 1 || | |||||
| channel_num > channels_orig + pad_channel_before - 1) { | |||||
| output[pos] = pad_value_template; | |||||
| } else { | |||||
| // on a block/x,y positon that isn't padding, copy data from the correct block/x,y pos the input | |||||
| // calculate from number of blocks of padding (due to channel padding) inserted prior | |||||
| equiv_block_num = block_num - (batch_item * (pad_channel_before + pad_channel_after)) - pad_channel_before; | |||||
| output[pos] = input[(equiv_block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left]; | |||||
| } | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| __global__ void PadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width, | __global__ void PadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width, | ||||
| const int channels, const int padded_height, const int padded_width, const int pad_top, | const int channels, const int padded_height, const int padded_width, const int pad_top, | ||||
| @@ -102,6 +135,17 @@ void CalPadNHWC(const size_t size, const T* input, const int num, const int old_ | |||||
| return; | return; | ||||
| } | } | ||||
| template <typename T> | |||||
| void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig, | |||||
| const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, | |||||
| const int padded_height, const int padded_width, const int pad_top, const int pad_left, | |||||
| float pad_value, T *output, cudaStream_t cuda_stream) { | |||||
| PadGeneral<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, channels_orig, pad_channel_before, | |||||
| pad_channel_after, old_height, old_width, padded_height, | |||||
| padded_width, pad_top, pad_left, pad_value, output); | |||||
| return; | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| void CalPadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width, | void CalPadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width, | ||||
| const int channels, const int padded_height, const int padded_width, const int pad_top, | const int channels, const int padded_height, const int padded_width, const int pad_top, | ||||
| @@ -152,3 +196,13 @@ template void CalPadGradNHWC<half>(const size_t size, const half* dy, const int | |||||
| const int old_width, const int channels, const int padded_height, | const int old_width, const int channels, const int padded_height, | ||||
| const int padded_width, const int pad_top, const int pad_left, half* dx, | const int padded_width, const int pad_top, const int pad_left, half* dx, | ||||
| cudaStream_t cuda_stream); | cudaStream_t cuda_stream); | ||||
| template void CalPadGeneral<float>(const size_t size, const float *input, const int num, const int channels_orig, | |||||
| const int pad_channel_before, const int pad_channel_after, const int old_height, | |||||
| const int old_width, const int padded_height, const int padded_width, | |||||
| const int pad_top, const int pad_left, float pad_value, float *output, | |||||
| cudaStream_t cuda_stream); | |||||
| template void CalPadGeneral<half>(const size_t size, const half *input, const int num, const int channels_orig, | |||||
| const int pad_channel_before, const int pad_channel_after, const int old_height, | |||||
| const int old_width, const int padded_height, const int padded_width, | |||||
| const int pad_top, const int pad_left, float pad_value, half *output, | |||||
| cudaStream_t cuda_stream); | |||||
| @@ -31,9 +31,13 @@ template <typename T> | |||||
| void CalPadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width, | void CalPadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width, | ||||
| const int channels, const int padded_height, const int padded_width, const int pad_top, const int pad_left, | const int channels, const int padded_height, const int padded_width, const int pad_top, const int pad_left, | ||||
| float pad_value, T* output, cudaStream_t cuda_stream); | float pad_value, T* output, cudaStream_t cuda_stream); | ||||
| template <typename T> | template <typename T> | ||||
| void CalPadGradNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width, | void CalPadGradNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width, | ||||
| const int channels, const int padded_height, const int padded_width, const int pad_top, | const int channels, const int padded_height, const int padded_width, const int pad_top, | ||||
| const int pad_left, T* output, cudaStream_t cuda_stream); | const int pad_left, T* output, cudaStream_t cuda_stream); | ||||
| template <typename T> | |||||
| void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig, | |||||
| const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, | |||||
| const int padded_height, const int padded_width, const int pad_top, const int pad_left, | |||||
| float pad_value, T *output, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ | ||||
| @@ -42,9 +42,12 @@ class PadGpuFwdKernel : public GpuKernel { | |||||
| size_t size = output_size_ / sizeof(T); | size_t size = output_size_ / sizeof(T); | ||||
| int pad_left = paddings[3][0]; | int pad_left = paddings[3][0]; | ||||
| int pad_top = paddings[2][0]; | int pad_top = paddings[2][0]; | ||||
| int pad_channel_before = paddings[1][0]; | |||||
| int pad_channel_after = paddings[1][1]; | |||||
| T pad_value = 0.0; | T pad_value = 0.0; | ||||
| CalPad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output_shape_[2], | |||||
| output_shape_[3], pad_top, pad_left, pad_value, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| CalPadGeneral(size, input, input_shape_[0], input_shape_[1], pad_channel_before, pad_channel_after, input_shape_[2], | |||||
| input_shape_[3], output_shape_[2], output_shape_[3], pad_top, pad_left, pad_value, output, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -470,6 +470,8 @@ class Pad(Cell): | |||||
| for item in paddings: | for item in paddings: | ||||
| if len(item) != 2: | if len(item) != 2: | ||||
| raise ValueError('The shape of paddings must be (n, 2).') | raise ValueError('The shape of paddings must be (n, 2).') | ||||
| if len(paddings) > 4: | |||||
| raise ValueError('Only padding up to 4 dims is supported') | |||||
| if mode == "CONSTANT": | if mode == "CONSTANT": | ||||
| self.pad = P.Pad(self.paddings) | self.pad = P.Pad(self.paddings) | ||||
| else: | else: | ||||
| @@ -0,0 +1,204 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import pytest | |||||
| import numpy as np | |||||
| import mindspore | |||||
| import mindspore.nn as nn | |||||
| import mindspore.context as context | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops.composite import GradOperation | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_pad_basic(): | |||||
| # confirm array is being padded with 0's | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| test_arr = np.array([[1, 2], [3, 4]]).astype(np.float32) | |||||
| test_arr_expected = np.array( | |||||
| [[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]).astype(np.float32) | |||||
| x_test = Tensor(test_arr, dtype=mindspore.float32) | |||||
| pad_op = nn.Pad(mode='CONSTANT', paddings=((1, 1), (1, 1))) | |||||
| y_test = pad_op(x_test).asnumpy() | |||||
| np.testing.assert_array_equal(y_test, test_arr_expected) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_pad_row(): | |||||
| # Confirm correct row padding | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| test_arr_1 = np.random.rand(40, 40).astype(np.float32) | |||||
| test_paddings_1 = ((2, 3), (0, 0)) | |||||
| test_arr_2 = np.random.randn(3, 10, 30, 30).astype(np.float32) | |||||
| test_paddings_2 = ((0, 0), (0, 0), (3, 0), (0, 0)) | |||||
| pad_op_row_1 = nn.Pad(mode='CONSTANT', paddings=test_paddings_1) | |||||
| pad_op_row_2 = nn.Pad(mode='CONSTANT', paddings=test_paddings_2) | |||||
| x_test_1 = Tensor(np.array(test_arr_1), dtype=mindspore.float32) | |||||
| x_test_2 = Tensor(np.array(test_arr_2), dtype=mindspore.float32) | |||||
| y_test_1 = pad_op_row_1(x_test_1).asnumpy() | |||||
| y_test_2 = pad_op_row_2(x_test_2).asnumpy() | |||||
| # check size | |||||
| assert y_test_1.shape == (45, 40) | |||||
| assert y_test_2.shape == (3, 10, 33, 30) | |||||
| # check values - select correct sections | |||||
| np.testing.assert_equal(y_test_1[2:-3, :], test_arr_1) | |||||
| np.testing.assert_equal(y_test_2[:, :, 3:, :], test_arr_2) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_pad_column(): | |||||
| # Confirm correct column padding | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| test_arr_1 = np.random.randn(40, 40).astype(np.float32) | |||||
| test_paddings_1 = ((0, 0), (3, 3)) | |||||
| test_arr_2 = np.random.randn(3, 10, 30, 30).astype(np.float32) | |||||
| test_paddings_2 = ((0, 0), (0, 0), (0, 0), (6, 1)) | |||||
| pad_op_col_1 = nn.Pad(mode='CONSTANT', paddings=test_paddings_1) | |||||
| pad_op_col_2 = nn.Pad(mode='CONSTANT', paddings=test_paddings_2) | |||||
| x_test_1 = Tensor(np.array(test_arr_1), dtype=mindspore.float32) | |||||
| x_test_2 = Tensor(np.array(test_arr_2), dtype=mindspore.float32) | |||||
| y_test_1 = pad_op_col_1(x_test_1).asnumpy() | |||||
| y_test_2 = pad_op_col_2(x_test_2).asnumpy() | |||||
| # check size | |||||
| assert y_test_1.shape == (40, 46) | |||||
| assert y_test_2.shape == (3, 10, 30, 37) | |||||
| # check values - select correct sections - should match | |||||
| np.testing.assert_equal(y_test_1[:, 3:-3], test_arr_1) | |||||
| np.testing.assert_equal(y_test_2[:, :, :, 6:-1], test_arr_2) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_pad_3d_pad(): | |||||
| # Confirm correct 3d padding - row, column, channel | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| test_arr = np.random.randn(5, 3, 30, 30).astype(np.float32) | |||||
| test_paddings = ((0, 0), (2, 1), (0, 1), (0, 2)) # padding 3 dims now | |||||
| pad_op_3d = nn.Pad(mode='CONSTANT', paddings=test_paddings) | |||||
| x_test = Tensor(np.array(test_arr), dtype=mindspore.float32) | |||||
| y_test = pad_op_3d(x_test).asnumpy() | |||||
| assert y_test.shape == (5, 6, 31, 32) | |||||
| np.testing.assert_equal(test_arr, y_test[:, 2:-1, :-1, :-2]) | |||||
| # For testing backprop | |||||
| class Grad(nn.Cell): | |||||
| def __init__(self, network): | |||||
| super(Grad, self).__init__() | |||||
| self.grad = GradOperation(get_all=True, sens_param=True) | |||||
| self.network = network | |||||
| def construct(self, input_, output_grad): | |||||
| return self.grad(self.network)(input_, output_grad) | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.pad = nn.Pad(mode="CONSTANT", paddings=( | |||||
| (0, 0), (4, 3), (1, 1), (0, 2))) | |||||
| def construct(self, x): | |||||
| return self.pad(x) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_pad_3d_backprop(): | |||||
| # Confirm correct 3d padding backprop | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| test_arr = np.random.randn(5, 3, 30, 30).astype(np.float32) | |||||
| x_test = Tensor(test_arr, dtype=mindspore.float32) | |||||
| padded_shape = (5, 10, 32, 32) | |||||
| dy = np.random.randn(*padded_shape).astype(np.float32) | |||||
| expected_dx = dy[:, 4:-3, 1:-1, :-2] | |||||
| net = Grad(Net()) | |||||
| dx = net(x_test, Tensor(dy)) | |||||
| dx = dx[0].asnumpy() | |||||
| np.testing.assert_array_equal(dx, expected_dx) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_pad_error_cases(): | |||||
| # Test against common errorneous inputs to catch correctly | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| # TEST 1 - Neg padding values | |||||
| test_op = nn.Pad(paddings=((0, 0), (-1, -1)), mode="CONSTANT") | |||||
| test_arr = np.random.randn(3, 3) | |||||
| test_arr_ms = Tensor(test_arr, dtype=mindspore.float32) | |||||
| with pytest.raises(ValueError): | |||||
| test_op(test_arr_ms) | |||||
| # TEST 2 - Mismatched input size and paddings - 1D tensor | |||||
| test_op = nn.Pad(paddings=((0, 0), (1, 0)), mode="CONSTANT") | |||||
| test_arr = np.random.randn(3) # 1D Tensor | |||||
| test_arr_ms = Tensor(test_arr, dtype=mindspore.float32) | |||||
| with pytest.raises(ValueError): | |||||
| test_op(test_arr_ms) | |||||
| # TEST 3 - Mismatched input size and paddings - 2D tensor, 3D padding | |||||
| test_op = nn.Pad(paddings=((0, 0), (1, 0)), mode="CONSTANT") # 2D Padding | |||||
| test_arr = np.random.randn(1, 3, 3) # 3D Tensor | |||||
| test_arr_ms = Tensor(test_arr, dtype=mindspore.float32) | |||||
| with pytest.raises(ValueError): | |||||
| test_op(test_arr_ms) | |||||
| # TEST 4 - 1D Paddings should not work | |||||
| with pytest.raises(TypeError): | |||||
| test_op = nn.Pad(paddings=((0, 2)), mode="CONSTANT") | |||||
| # TEST 5 - Padding beyond 4d - (added check in nn file in PR) | |||||
| with pytest.raises(ValueError): | |||||
| _ = nn.Pad(paddings=((0, 0), (0, 0,), (0, 0), (0, 0), | |||||
| (1, 0)), mode="CONSTANT") # 2D Padding | |||||