@@ -18,130 +18,34 @@
 #include "maxpool_with_argmax_grad_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 #include "include/cuda_fp16.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
 template <typename T, typename S>
-__global__ void MaxPoolWithArgmaxGrad(const T* x,
-                                      const T* dy,
+__global__ void MaxPoolWithArgmaxGrad(const T* dy,
                                       const S* index,
                                       const int n,
                                       const int c,
                                       const int xHeight,
                                       const int xWidth,
                                       const int dyHeight,
                                       const int dyWidth,
-                                      const int windowHeight,
-                                      const int windowWidth,
-                                      const int strideHeight,
-                                      const int strideWidth,
-                                      const int padTop,
-                                      const int padLeft,
                                       const int xNCHW,
                                       const int xCHW,
                                       const int xHW,
                                       const int dyCHW,
                                       const int dyHW,
+                                      const int dyNCHW,
                                       T* dx) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
-       pos < (xNCHW);
-       pos += blockDim.x * gridDim.x) {
-    const int posn = pos / xCHW;
-    const int posc = pos / xHW % c;
-    const int posh = pos / xHeight % xHeight;
-    const int posw = pos % xWidth;
-    const S posIdx = posh*xWidth + posw;
-    int hstart = posh+padTop;
-    if (hstart < windowHeight) {
-      hstart = 0;
-    } else {
-      hstart = (hstart-windowHeight)/strideHeight + 1;
-    }
-    int wstart = posw+padLeft;
-    if (wstart < windowWidth) {
-      wstart = 0;
-    } else {
-      wstart = (wstart-windowWidth)/strideWidth + 1;
-    }
-    const int hend = min((posh+padTop)/strideHeight + 1, dyHeight);
-    const int wend = min((posw+padLeft)/strideWidth + 1, dyWidth);
-    const int channelStart = posn*dyCHW + posc*dyHW;
-    T dySum = static_cast<T>(0.0);
-    for (int hcur = hstart; hcur < hend; ++hcur) {
-      for (int wcur = wstart; wcur < wend; ++wcur) {
-        const int curIdx = hcur*dyWidth + wcur;
-        S maxIdx = index[channelStart+curIdx];
-        if (maxIdx == posIdx) {
-          dySum += dy[channelStart+curIdx];
-        }
-      }
-    }
-    dx[pos] = dySum;
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (dyNCHW); pos += blockDim.x * gridDim.x) {
+    const S idx = index[pos];
+    const int posn = pos / dyCHW;
+    MsAtomicAdd(dx + posn*xCHW + static_cast<int>(idx), dy[pos]);
   }
   return;
 }
-template <>
-__global__ void MaxPoolWithArgmaxGrad(const half* x,
-                                      const half* dy,
-                                      const int* index,
-                                      const int n,
-                                      const int c,
-                                      const int xHeight,
-                                      const int xWidth,
-                                      const int dyHeight,
-                                      const int dyWidth,
-                                      const int windowHeight,
-                                      const int windowWidth,
-                                      const int strideHeight,
-                                      const int strideWidth,
-                                      const int padTop,
-                                      const int padLeft,
-                                      const int xNCHW,
-                                      const int xCHW,
-                                      const int xHW,
-                                      const int dyCHW,
-                                      const int dyHW,
-                                      half* dx) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
-       pos < (xNCHW);
-       pos += blockDim.x * gridDim.x) {
-    const int posn = pos / xCHW;
-    const int posc = pos / xHW % c;
-    const int posh = pos / xHeight % xHeight;
-    const int posw = pos % xWidth;
-    const int posIdx = posh*xWidth + posw;
-    int hstart = posh+padTop;
-    if (hstart < windowHeight) {
-      hstart = 0;
-    } else {
-      hstart = (hstart-windowHeight)/strideHeight + 1;
-    }
+template <typename T>
+__global__ void InitOutput(const int size, T *output) {
+  T zero = 0;
+  for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) {
+    output[id] = zero;
   }
-    int wstart = posw+padLeft;
-    if (wstart < windowWidth) {
-      wstart = 0;
-    } else {
-      wstart = (wstart-windowWidth)/strideWidth + 1;
-    }
-    const int hend = min((posh+padTop)/strideHeight + 1, dyHeight);
-    const int wend = min((posw+padLeft)/strideWidth + 1, dyWidth);
-    const int channelStart = posn*dyCHW + posc*dyHW;
-    float dySum = 0.0f;
-    for (int hcur = hstart; hcur < hend; ++hcur) {
-      for (int wcur = wstart; wcur < wend; ++wcur) {
-        const int curIdx = hcur*dyWidth + wcur;
-        int maxIdx = index[channelStart+curIdx];
-        if (maxIdx == posIdx) {
-          dySum += __half2float(dy[channelStart+curIdx]);
-        }
-      }
-    }
-    dx[pos] = __float2half(dySum);
-  }
-  return;
+  return;
 }
 template <typename T, typename S>
-void CalMaxPoolWithArgmaxGrad(const T* x,
-                              const T* dy,
+void CalMaxPoolWithArgmaxGrad(const T* dy,
                               const S* index,
                               const int n,
                               const int c,
@@ -149,12 +53,6 @@ void CalMaxPoolWithArgmaxGrad(const T* x,
                               const int xWidth,
                               const int dyHeight,
                               const int dyWidth,
-                              const int windowHeight,
-                              const int windowWidth,
-                              const int strideHeight,
-                              const int strideWidth,
-                              const int padTop,
-                              const int padLeft,
                               T* dx,
                               cudaStream_t cuda_stream) {
   const int xHW = xHeight*xWidth;
@@ -162,36 +60,22 @@ void CalMaxPoolWithArgmaxGrad(const T* x,
   const int xNCHW = n*xCHW;
   const int dyHW = dyHeight*dyWidth;
   const int dyCHW = c*dyHW;
-  MaxPoolWithArgmaxGrad<<<GET_BLOCKS(xNCHW),
+  const int dyNCHW = n*dyCHW;
+  InitOutput<<<GET_BLOCKS(xNCHW), GET_THREADS, 0, cuda_stream>>>(xNCHW, dx);
+  MaxPoolWithArgmaxGrad<<<GET_BLOCKS(dyNCHW),
                           GET_THREADS,
                           0,
                           cuda_stream>>>(
-                            x,
                             dy,
                             index,
                             n,
                             c,
                             xHeight,
                             xWidth,
                             dyHeight,
                             dyWidth,
-                            windowHeight,
-                            windowWidth,
-                            strideHeight,
-                            strideWidth,
-                            padTop,
-                            padLeft,
                             xNCHW,
                             xCHW,
                             xHW,
                             dyCHW,
                             dyHW,
+                            dyNCHW,
                             dx);
   return;
 }
-template void CalMaxPoolWithArgmaxGrad<float, int>(const float* x,
-                                                   const float* dy,
+template void CalMaxPoolWithArgmaxGrad<float, int>(const float* dy,
                                                    const int* index,
                                                    const int n,
                                                    const int c,
@@ -199,16 +83,9 @@ template void CalMaxPoolWithArgmaxGrad<float, int>(const float* x,
                                                    const int xWidth,
                                                    const int dyHeight,
                                                    const int dyWidth,
-                                                   const int windowHeight,
-                                                   const int windowWidth,
-                                                   const int strideHeight,
-                                                   const int strideWidth,
-                                                   const int padTop,
-                                                   const int padLeft,
                                                    float* dx,
                                                    cudaStream_t cuda_stream);
-template void CalMaxPoolWithArgmaxGrad<half, int>(const half* x,
-                                                  const half* dy,
+template void CalMaxPoolWithArgmaxGrad<half, int>(const half* dy,
                                                   const int* index,
                                                   const int n,
                                                   const int c,
@@ -216,11 +93,5 @@ template void CalMaxPoolWithArgmaxGrad<half, int>(const half* x,
                                                   const int xWidth,
                                                   const int dyHeight,
                                                   const int dyWidth,
-                                                  const int windowHeight,
-                                                  const int windowWidth,
-                                                  const int strideHeight,
-                                                  const int strideWidth,
-                                                  const int padTop,
-                                                  const int padLeft,
                                                   half* dx,
                                                   cudaStream_t cuda_stream);
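
The backward kernel above no longer searches every pooling window per input pixel; it zero-fills dx (InitOutput) and then atomically scatters each dy element to the position its argmax recorded. A minimal NumPy sketch of the same semantics, assuming indices that carry the channel offset as the fixed forward kernel below produces them (the helper name is illustrative, not part of the codebase):

```python
import numpy as np

def maxpool_with_argmax_grad_ref(dy, index, n, c, x_height, x_width):
    """Reference for the scatter-based backward pass.

    dy, index: (n, c, dy_h, dy_w); each index value is an offset into the
    flattened c*h*w block of its sample, channel offset included.
    """
    x_chw = c * x_height * x_width
    dx = np.zeros(n * x_chw, dtype=dy.dtype)        # InitOutput: zero-fill dx
    dy_flat, idx_flat = dy.reshape(n, -1), index.reshape(n, -1)
    for posn in range(n):
        # MsAtomicAdd(dx + posn*xCHW + idx, dy[pos]); np.add.at accumulates
        # duplicate indices the same way atomicAdd does.
        np.add.at(dx, posn * x_chw + idx_flat[posn], dy_flat[posn])
    return dx.reshape(n, c, x_height, x_width)
```
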
@@ -17,9 +17,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
 template <typename T, typename S>
-void CalMaxPoolWithArgmaxGrad(const T* x, const T* dy, const S* index, const int n, const int c, const int xHeight,
-                              const int xWidth, const int dyHeight, const int dyWidth, const int windowHeight,
-                              const int windowWidth, const int strideHeight, const int strideWidth, const int padTop,
-                              const int padLeft, T* dx, cudaStream_t cuda_stream);
+void CalMaxPoolWithArgmaxGrad(const T* dy, const S* index, const int n, const int c, const int xHeight,
+                              const int xWidth, const int dyHeight, const int dyWidth, T* dx, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
@@ -42,7 +42,7 @@ __global__ void MaxPoolWithArgmax(const T* input,
        pos += blockDim.x * gridDim.x) {
     const int posn = pos / outputCHW;
     const int posc = pos / outputHW % c;
-    const int posh = pos / outputHeight % outputHeight;
+    const int posh = pos / outputWidth % outputHeight;
     const int posw = pos % outputWidth;
     int hstart = posh * strideHeight - padTop;
     int wstart = posw * strideWidth - padLeft;
@@ -50,12 +50,12 @@ __global__ void MaxPoolWithArgmax(const T* input,
     const int wend = min(wstart + windowWidth, w);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
-    S inputStart = posn*c*h*w + posc*h*w;
-    S maxIdx = hstart*w + wstart;
+    S inputStart = posn*c*h*w;
+    S maxIdx = posc*h*w + hstart*w + wstart;
     T maxData = input[inputStart+maxIdx];
     for (int hcur = hstart; hcur < hend; ++hcur) {
       for (int wcur = wstart; wcur < wend; ++wcur) {
-        S inputIdx = hcur*w + wcur;
+        S inputIdx = posc*h*w + hcur*w + wcur;
         T inputData = input[inputStart+inputIdx];
         if (inputData > maxData) {
           maxIdx = inputIdx;
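
Two forward-kernel fixes land here: posh is now derived from outputWidth (the old pos / outputHeight % outputHeight decomposition was wrong whenever the output plane is not square), and the stored index now includes the posc*h*w channel offset, which is exactly what the scatter-based backward pass needs to route a gradient anywhere inside its sample. A loop-based NumPy sketch of the fixed indexing, assuming a square window and a single symmetric pad value (not the CUDA code itself):

```python
import numpy as np

def maxpool_with_argmax_ref(x, ksize, stride, pad):
    """Sketch of the fixed forward indexing for an (n, c, h, w) input."""
    n, c, h, w = x.shape
    out_h = (h + 2 * pad - ksize) // stride + 1
    out_w = (w + 2 * pad - ksize) // stride + 1
    out = np.empty((n, c, out_h, out_w), x.dtype)
    argmax = np.empty((n, c, out_h, out_w), np.int32)
    for pn, pc, ph, pw in np.ndindex(n, c, out_h, out_w):
        hs, ws = ph * stride - pad, pw * stride - pad
        he, we = min(hs + ksize, h), min(ws + ksize, w)
        hs, ws = max(hs, 0), max(ws, 0)
        win = x[pn, pc, hs:he, ws:we]
        r, q = np.unravel_index(np.argmax(win), win.shape)  # first max wins, like '>'
        out[pn, pc, ph, pw] = win[r, q]
        # maxIdx = posc*h*w + hcur*w + wcur: channel offset included
        argmax[pn, pc, ph, pw] = pc * h * w + (hs + r) * w + (ws + q)
    return out, argmax
```
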
@@ -48,12 +48,10 @@ class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) {
-    T *x_addr = GetDeviceAddress<T>(inputs, 0);
     T *dy_addr = GetDeviceAddress<T>(inputs, 1);
     S *index_addr = GetDeviceAddress<S>(inputs, 2);
     T *dx_addr = GetDeviceAddress<T>(outputs, 0);
-    CalMaxPoolWithArgmaxGrad(x_addr, dy_addr, index_addr, n_, c_, x_height_, x_width_, dy_height_, dy_width_,
-                             window_height_, window_width_, stride_height_, stride_width_, pad_top_, pad_left_, dx_addr,
+    CalMaxPoolWithArgmaxGrad(dy_addr, index_addr, n_, c_, x_height_, x_width_, dy_height_, dy_width_, dx_addr,
                              reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -95,57 +93,19 @@ class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
     x_width_ = SizeToInt(x_shape[3]);
     dy_height_ = SizeToInt(dy_shape[2]);
     dy_width_ = SizeToInt(dy_shape[3]);
-    std::vector<int> window;
-    std::vector<int64_t> window_me =
-      GetValue<std::vector<int64_t>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
-    (void)std::transform(window_me.begin(), window_me.end(), std::back_inserter(window),
-                         [](const int64_t &value) { return static_cast<int>(value); });
-    window_height_ = window[1];
-    window_width_ = window[2];
-    std::vector<int> stride;
-    std::vector<int64_t> stride_me =
-      GetValue<std::vector<int64_t>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
-    (void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride),
-                         [](const int64_t &value) { return static_cast<int>(value); });
-    stride_height_ = stride[1];
-    stride_width_ = stride[2];
-    pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding"));
-    pad_top_ = 0;
-    pad_left_ = 0;
-    if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
-      SetPad();
-    }
     InitSizeLists();
     return true;
   }

 protected:
   void InitSizeLists() override {
     input_size_list_.push_back(x_size_);
     input_size_list_.push_back(dy_size_);
     input_size_list_.push_back(index_size_);
     output_size_list_.push_back(dx_size_);
   }

 private:
-  void SetPad() {
-    pad_height_ = std::max<int>(
-      0, (((x_height_ / stride_height_) * stride_height_ == x_height_ ? (x_height_ / stride_height_)
-                                                                      : (x_height_ / stride_height_) + 1) -
-          1) *
-           stride_height_ +
-           window_height_ - x_height_);
-    pad_width_ =
-      std::max<int>(0, (((x_width_ / stride_width_) * stride_width_ == x_width_ ? (x_width_ / stride_width_)
-                                                                                : (x_width_ / stride_width_) + 1) -
-                        1) *
-                         stride_width_ +
-                         window_width_ - x_width_);
-    pad_top_ = pad_height_ / 2;
-    pad_left_ = pad_width_ / 2;
-  }
-  std::string pad_mode_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
@@ -156,14 +116,6 @@ class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
   int x_width_;
   int dy_height_;
   int dy_width_;
-  int window_height_;
-  int window_width_;
-  int pad_height_;
-  int pad_width_;
-  int pad_top_;
-  int pad_left_;
-  int stride_height_;
-  int stride_width_;
   size_t x_size_;
   size_t dy_size_;
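
Because the argmax alone now decides where each gradient lands, the grad kernel class can drop its ksize/strides/padding attributes and the SetPad helper. For reference, the deleted SetPad computed TensorFlow-style SAME padding; an equivalent sketch under that assumption:

```python
import math

def same_pad_top_left(size, stride, window):
    """What the removed SetPad computed: TF-style SAME padding, halved."""
    out = math.ceil(size / stride)                      # ceil-div output size
    total = max(0, (out - 1) * stride + window - size)  # total pad needed
    return total // 2                                   # pad_top_ / pad_left_
```
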
@@ -16,28 +16,83 @@
 import numpy as np
 import pytest
-import mindspore.context as context
-import mindspore.nn as nn
-from mindspore import Tensor
-from mindspore.ops import operations as P
+import mindspore.ops.operations as P
+from mindspore import context, Tensor
+from mindspore.nn import Cell
+from mindspore.ops import composite as C

-class Net_Pool(nn.Cell):
-    def __init__(self):
-        super(Net_Pool, self).__init__()
-        self.maxpool_fun = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="VALID")
-
-    def construct(self, x):
-        return self.maxpool_fun(x)
+class MaxPoolWithArgMax_Net(Cell):
+    def __init__(self, padding, ksize, strides):
+        super(MaxPoolWithArgMax_Net, self).__init__()
+        self.maxpool_with_argmax = P.MaxPoolWithArgmax(padding=padding, ksize=ksize, strides=strides)
+
+    def construct(self, input_data):
+        output, argmax = self.maxpool_with_argmax(input_data)
+        return output, argmax

-class Net_Pool2(nn.Cell):
-    def __init__(self):
-        super(Net_Pool2, self).__init__()
-        self.maxpool_fun = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME")
-
-    def construct(self, x):
-        return self.maxpool_fun(x)
+class Grad(Cell):
+    def __init__(self, network, argmax):
+        super(Grad, self).__init__()
+        self.grad = C.GradOperation(get_all=True, sens_param=True)
+        self.network = network
+        self.sens = (Tensor(np.ones(argmax.shape).astype(np.float32)),
+                     Tensor(np.ones(argmax.shape).astype(np.int32)))
+
+    def construct(self, input_data):
+        gout = self.grad(self.network)(input_data, self.sens)
+        return gout
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_train_forward_backward():
+    x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float32)
+    expect_output = np.array([[[[5, 6, 7, 7],
+                                [9, 10, 11, 11],
+                                [9, 10, 11, 11]],
+                               [[17, 18, 19, 19],
+                                [21, 22, 23, 23],
+                                [21, 22, 23, 23]],
+                               [[29, 30, 31, 31],
+                                [33, 34, 35, 35],
+                                [33, 34, 35, 35]]]]).astype(np.float32)
+    expect_argmax = np.array([[[[5, 6, 7, 7],
+                                [9, 10, 11, 11],
+                                [9, 10, 11, 11]],
+                               [[17, 18, 19, 19],
+                                [21, 22, 23, 23],
+                                [21, 22, 23, 23]],
+                               [[29, 30, 31, 31],
+                                [33, 34, 35, 35],
+                                [33, 34, 35, 35]]]]).astype(np.int32)
+    expect_dx = np.array([[[[0, 0, 0, 0],
+                            [0, 1, 1, 2],
+                            [0, 2, 2, 4]],
+                           [[0, 0, 0, 0],
+                            [0, 1, 1, 2],
+                            [0, 2, 2, 4]],
+                           [[0, 0, 0, 0],
+                            [0, 1, 1, 2],
+                            [0, 2, 2, 4]]]]).astype(np.float32)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = MaxPoolWithArgMax_Net(padding="SAME", ksize=2, strides=1)
+    output_tensor, argmax_tensor = net(Tensor(x))
+    assert output_tensor.shape == expect_output.shape
+    assert argmax_tensor.shape == expect_argmax.shape
+    error = np.ones(shape=expect_output.shape) * 1.0e-5
+    diff_output = output_tensor.asnumpy() - expect_output
+    assert np.all(diff_output < error)
+    net_grad = Grad(net, argmax_tensor)
+    dx = net_grad(Tensor(x))[0].asnumpy()
+    assert dx.shape == expect_dx.shape
+    diff = dx - expect_dx
+    assert np.all(diff < error)

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
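
A quick cross-check of expect_dx in the new test above: with an all-ones sens, dx[p] is simply the number of windows whose argmax is p, so scattering ones by the channel-0 argmax block reproduces the expected plane:

```python
import numpy as np

argmax_c0 = np.array([5, 6, 7, 7, 9, 10, 11, 11, 9, 10, 11, 11])
dx_c0 = np.zeros(12, dtype=np.float32)
np.add.at(dx_c0, argmax_c0, 1.0)   # count argmax hits per input position
print(dx_c0.reshape(3, 4))
# [[0. 0. 0. 0.]
#  [0. 1. 1. 2.]
#  [0. 2. 2. 4.]]
```
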
@@ -73,8 +128,8 @@ def test_maxpool_with_argmax_2d():
                          ]]]))
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    maxpool2d = Net_Pool()
-    maxpool2d2 = Net_Pool2()
+    maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
+    maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
     output2, index2 = maxpool2d2(x)
     output, index = maxpool2d(x)
     assert (output.asnumpy() == expect_result).all()
@@ -83,8 +138,8 @@ def test_maxpool_with_argmax_2d():
     assert (index2.asnumpy() == expect__index_result2).all()

     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    maxpool2d = Net_Pool()
-    maxpool2d2 = Net_Pool2()
+    maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
+    maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
     output2, index2 = maxpool2d2(x)
     output, index = maxpool2d(x)
     assert (output.asnumpy() == expect_result).all()
@@ -126,8 +181,8 @@ def test_maxpool_with_argmax_2d_fp16():
                          ]]]))
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    maxpool2d = Net_Pool()
-    maxpool2d2 = Net_Pool2()
+    maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
+    maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
     output2, index2 = maxpool2d2(x)
     output, index = maxpool2d(x)
     assert (output.asnumpy() == expect_result).all()
@@ -136,12 +191,11 @@ def test_maxpool_with_argmax_2d_fp16():
     assert (index2.asnumpy() == expect__index_result2).all()

     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    maxpool2d = Net_Pool()
-    maxpool2d2 = Net_Pool2()
+    maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
+    maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
     output2, index2 = maxpool2d2(x)
     output, index = maxpool2d(x)
     assert (output.asnumpy() == expect_result).all()
     assert (output2.asnumpy() == expect_result2).all()
     assert (index.asnumpy() == expect_index_result).all()
     assert (index2.asnumpy() == expect__index_result2).all()