BREAKING CHANGE:
GitOrigin-RevId: 54d726d2fe
tags/v1.5.0
| @@ -666,7 +666,7 @@ public: | |||
| * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html | |||
| * | |||
| * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$, | |||
| * where \f$ ih=-pad_h+oh*stride_h, iw=-pad_w+ow*stride_w\f$. | |||
| * where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$. | |||
| */ | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| @@ -698,6 +698,53 @@ protected: | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| class SlidingWindowTransposeBase : public OperatorBase { | |||
| DEF_OPR_IMPL_CTOR(SlidingWindowTransposeBase, OperatorBase); | |||
| DEF_OPR_PARAM(SlidingWindowTranspose); | |||
| protected: | |||
| void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst); | |||
| void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst); | |||
| }; | |||
| class SlidingWindowTransposeForward : public SlidingWindowTransposeBase { | |||
| DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1); | |||
| public: | |||
| /** | |||
| * \param[in] src (N, C, IH, IW, window_h, window_w) | |||
| * \param[out] dst (N, C, OH, OW) | |||
| */ | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst) = 0; | |||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using SlidingWindowTranspose = SlidingWindowTransposeForward; | |||
| class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase { | |||
| DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1); | |||
| public: | |||
| /** | |||
| * \param[in] diff the backpropagated gradient wrt. dst | |||
| * \param[out] grad the backpropagated gradient wrt. src | |||
| */ | |||
| virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
| const TensorLayout& grad) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& diff, const TensorLayout& grad, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| /** | |||
| * \brief base class for Pooling | |||
| */ | |||
| @@ -224,6 +224,10 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||
| add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1, | |||
| 'dilate_h', 1, 'dilate_w', 1, 'window_h', 3, 'window_w', 3)) | |||
| (pdef('SlidingWindowTranspose'). | |||
| add_fields('uint32', 'out_h', 0, 'out_w', 0, 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1, | |||
| 'dilate_h', 1, 'dilate_w', 1, 'window_h', 3, 'window_w', 3)) | |||
| (pdef('Pooling', version=0, is_legacy=True). | |||
| add_enum( | |||
| 'Mode', | |||
| @@ -104,6 +104,8 @@ private: | |||
| cb(ConvBiasForward) \ | |||
| cb(Images2NeibsForward) \ | |||
| cb(Images2NeibsBackward) \ | |||
| cb(SlidingWindowTransposeForward) \ | |||
| cb(SlidingWindowTransposeBackward) \ | |||
| cb(ElemwiseForward) \ | |||
| cb(ElemwiseMultiType) \ | |||
| cb(AddUpdateForward) \ | |||
| @@ -39,6 +39,8 @@ DEF(SeparableConvForward, 4, true, true); | |||
| DEF(SeparableFilterForward, 4, true, true); | |||
| DEF(Images2NeibsForward, 2, true, true); | |||
| DEF(Images2NeibsBackward, 2, true, false); | |||
| DEF(SlidingWindowTransposeForward, 2, true, true); | |||
| DEF(SlidingWindowTransposeBackward, 2, true, false); | |||
| DEF(PoolingForward, 2, true, true); | |||
| DEF(PoolingBackward, 4, true, false); | |||
| DEF(AdaptivePoolingForward, 2, true, false); | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * \file dnn/src/common/sliding_window_transpose.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| namespace megdnn { | |||
| void SlidingWindowTransposeBase::deduce_layout_fwd(const TensorLayout &src, | |||
| TensorLayout &dst) | |||
| { | |||
| auto errmsg = [&]() { | |||
| return megdnn_layout_msg(src) + ", " + | |||
| "out_h=" + std::to_string(param().out_h) + ", " + | |||
| "out_w=" + std::to_string(param().out_w) + ", " + | |||
| "pad_h=" + std::to_string(param().pad_h) + ", " + | |||
| "pad_w=" + std::to_string(param().pad_w) + ", " + | |||
| "stride_h=" + std::to_string(param().stride_h) + ", " + | |||
| "stride_w=" + std::to_string(param().stride_w) + ", " + | |||
| "window_h=" + std::to_string(param().window_h) + ", " + | |||
| "window_w=" + std::to_string(param().window_w); | |||
| }; | |||
| MEGDNN_MARK_USED_VAR(errmsg); | |||
| megdnn_assert_contiguous(src); | |||
| megdnn_assert(src.ndim == 6_z, "%s", errmsg().c_str()); | |||
| size_t n = src[0], ic = src[1]; | |||
| size_t oh = this->param().out_h; | |||
| size_t ow = this->param().out_w; | |||
| dst = TensorLayout(TensorShape({n, ic, oh, ow}), src.dtype); | |||
| } | |||
| void SlidingWindowTransposeBase::check_layout_fwd(const TensorLayout &src, | |||
| const TensorLayout &dst) | |||
| { | |||
| TensorLayout dst_expected; | |||
| deduce_layout_fwd(src, dst_expected); | |||
| megdnn_assert_eq_layout(dst_expected, dst); | |||
| } | |||
| void SlidingWindowTransposeForward::deduce_layout(const TensorLayout &src, | |||
| TensorLayout &dst) | |||
| { | |||
| deduce_layout_fwd(src, dst); | |||
| } | |||
| void SlidingWindowTransposeForward::check_exec(const TensorLayout &src, | |||
| const TensorLayout &dst, | |||
| size_t workspace_in_bytes) | |||
| { | |||
| check_layout_fwd(src, dst); | |||
| auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| void SlidingWindowTransposeBackward::check_exec(const TensorLayout &diff, | |||
| const TensorLayout &grad, | |||
| size_t workspace_in_bytes) | |||
| { | |||
| check_layout_fwd(grad, diff); | |||
| auto required_workspace_in_bytes = get_workspace_in_bytes(grad, diff); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -71,6 +71,7 @@ | |||
| #include "src/cuda/separable_conv/opr_impl.h" | |||
| #include "src/cuda/separable_filter/opr_impl.h" | |||
| #include "src/cuda/sleep/opr_impl.h" | |||
| #include "src/cuda/sliding_window_transpose/opr_impl.h" | |||
| #include "src/cuda/split/opr_impl.h" | |||
| #include "src/cuda/svd/opr_impl.h" | |||
| #include "src/cuda/tensor_remap/opr_impl.h" | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * \file dnn/src/cuda/sliding_window_transpose/opr_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/cuda/sliding_window_transpose/opr_impl.h" | |||
| #include "src/cuda/utils.h" | |||
| #include "src/cuda/sliding_window_transpose/sliding_window_transpose.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| void SlidingWindowTransposeForwardImpl::exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) | |||
| { | |||
| check_exec(src.layout, dst.layout, workspace.size); | |||
| auto stream = cuda_stream(handle()); | |||
| int N = src.layout[0], C = src.layout[1], | |||
| OH = src.layout[2], OW = src.layout[3]; | |||
| int IH = dst.layout[2], IW = dst.layout[3]; | |||
| int ph = param().pad_h, pw = param().pad_w; | |||
| int sh = param().stride_h, sw = param().stride_w; | |||
| int dh = param().dilate_h, dw = param().dilate_w; | |||
| int wh = param().window_h, ww = param().window_w; | |||
| #define cb(DType) \ | |||
| if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \ | |||
| using T = DTypeTrait<DType>::ctype; \ | |||
| sliding_window_transpose::forward(src.ptr<T>(), dst.ptr<T>(), \ | |||
| N, C, IH, IW, OH, OW, \ | |||
| ph, pw, sh, sw, dh, dw, wh, ww, \ | |||
| stream); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb); | |||
| #undef cb | |||
| megdnn_assert_internal(0); | |||
| } | |||
| void SlidingWindowTransposeBackwardImpl::exec(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) | |||
| { | |||
| check_exec(diff.layout, grad.layout, workspace.size); | |||
| auto stream = cuda_stream(handle()); | |||
| int N = grad.layout[0], C = grad.layout[1], | |||
| OH = grad.layout[2], OW = grad.layout[3]; | |||
| int IH = diff.layout[2], IW = diff.layout[3]; | |||
| int ph = param().pad_h, pw = param().pad_w; | |||
| int sh = param().stride_h, sw = param().stride_w; | |||
| int dh = param().dilate_h, dw = param().dilate_w; | |||
| int wh = param().window_h, ww = param().window_w; | |||
| #define cb(DType) \ | |||
| if (diff.layout.dtype == DType()) { \ | |||
| using T = DTypeTrait<DType>::ctype; \ | |||
| sliding_window_transpose::backward(diff.ptr<T>(), grad.ptr<T>(), \ | |||
| N, C, IH, IW, OH, OW, \ | |||
| ph, pw, sh, sw, dh, dw, wh, ww, \ | |||
| stream); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb); | |||
| #undef cb | |||
| megdnn_assert_internal(0); | |||
| } | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * \file dnn/src/cuda/sliding_window_transpose/opr_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include <cuda_runtime_api.h> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| class SlidingWindowTransposeForwardImpl: public SlidingWindowTransposeForward { | |||
| public: | |||
| using SlidingWindowTransposeForward::SlidingWindowTransposeForward; | |||
| void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout &, | |||
| const TensorLayout &) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class SlidingWindowTransposeBackwardImpl: public SlidingWindowTransposeBackward { | |||
| public: | |||
| using SlidingWindowTransposeBackward::SlidingWindowTransposeBackward; | |||
| void exec(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout &, | |||
| const TensorLayout &) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,133 @@ | |||
| /** | |||
| * \file dnn/src/cuda/sliding_window_transpose/kernel.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/cuda/sliding_window_transpose/sliding_window_transpose.cuh" | |||
| #include "megdnn/dtype.h" | |||
| #include "src/cuda/utils.cuh" | |||
| #include <cstdio> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace sliding_window_transpose { | |||
| template <typename T> | |||
| __global__ void forward_kernel(const T *src, T *dst, | |||
| int N, int C, int IH, int IW, int OH, int OW, | |||
| int ph, int pw, int sh, int sw, int dh, int dw, int WH, int WW) | |||
| { | |||
| int id = threadIdx.x + blockIdx.x * blockDim.x; | |||
| if (id < N*C*IH*IW) { | |||
| int nc = id / (IH*IW); | |||
| int ih = id % (IH*IW) / IW; | |||
| int iw = id % (IH*IW) % IW; | |||
| dst[nc*IH*IW + ih*IW + iw] = 0.0f; | |||
| int oh_max = min((ih+ph) / sh, OH-1); | |||
| int oh_min = max((ih+ph-(WH-1)*dh+sh-1) / sh, 0); | |||
| int ow_max = min((iw+pw) / sw, OW-1); | |||
| int ow_min = max((iw+pw-(WW-1)*dw+sw-1) / sw, 0); | |||
| for (int oh = oh_min; oh <= oh_max; ++oh) | |||
| for (int ow = ow_min; ow <= ow_max; ++ow) | |||
| { | |||
| if ((ih+ph - sh*oh)%dh==0 && (iw+pw - sw*ow)%dw==0){ | |||
| int wh = ih+ph - sh*oh - (ih+ph - sh*oh)/dh * (dh-1); | |||
| int ww = iw+pw - sw*ow - (iw+pw - sw*ow)/dw * (dw-1); | |||
| dst[nc*IH*IW + ih*IW + iw] += | |||
| src[nc*OH*OW*WH*WW + oh*OW*WH*WW + ow*WH*WW + | |||
| wh*WW + ww]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void forward(const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW, | |||
| int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, | |||
| cudaStream_t stream) { | |||
| int threads = NR_THREADS; | |||
| int blocks = DIVUP(N*C*IH*IW, threads); | |||
| forward_kernel<<<blocks, threads, 0, stream>>>(src, dst, | |||
| N, C, IH, IW, OH, OW, | |||
| ph, pw, sh, sw, dh, dw, wh, ww); | |||
| after_kernel_launch(); | |||
| } | |||
| #define grid_y_max 512 | |||
| template <typename T> | |||
| __global__ void backward_kernel(const T *diff, T *grad, | |||
| int N, int C, int IH, int IW, int OH, int OW, | |||
| int ph, int pw, int sh, int sw, int dh, int dw, int WH, int WW) | |||
| { | |||
| int NC = N * C; | |||
| int WP = WH*WW; | |||
| for (int wp = threadIdx.x; wp < WP; wp += blockDim.x) { | |||
| int nc = blockIdx.y; | |||
| while (nc < NC) { | |||
| int wh = wp / WW; | |||
| int ww = wp % WW; | |||
| int op = threadIdx.y + blockIdx.x * blockDim.y; | |||
| if (op < OH * OW) { | |||
| int oh = op / OW; | |||
| int ow = op % OW; | |||
| int ih = -ph + sh * oh + wh* dh; | |||
| int iw = -pw + sw * ow + ww* dw; | |||
| int dst_pos = nc * OH * OW * WH * WW + op * WH * WW + wp; | |||
| int src_pos = nc * IH * IW + ih * IW + iw; | |||
| grad[dst_pos] = (ih >= 0 && ih < IH && iw >= 0 && iw < IW) | |||
| ? diff[src_pos] | |||
| : 0.0f; | |||
| } | |||
| nc += grid_y_max; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void backward(const T *diff, T *grad, | |||
| int N, int C, int IH, int IW, int OH, int OW, | |||
| int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, | |||
| cudaStream_t stream) | |||
| { | |||
| int spatial_size = OH * OW; | |||
| int kernel_size = wh * ww; | |||
| int tx = min(NR_THREADS, kernel_size); | |||
| int ty = NR_THREADS / tx; | |||
| megdnn_assert(ty > 0); | |||
| int bx = DIVUP(spatial_size, ty); | |||
| int by = N * C; | |||
| backward_kernel<<<dim3(bx, std::min(grid_y_max, by)), dim3(tx, ty), 0, | |||
| stream>>>(diff, grad, N, C, IH, IW, OH, OW, ph, pw, sh, sw, dh, dw, | |||
| wh, ww); | |||
| after_kernel_launch(); | |||
| } | |||
| #undef grid_y_max | |||
| #define INST(T) \ | |||
| template void forward<T>(const T *, T *, int, int, int, int, int, int, \ | |||
| int, int, int, int, int, int, int, int, \ | |||
| cudaStream_t); \ | |||
| template void backward<T>(const T *, T *, int, int, int, int, int, int, \ | |||
| int, int, int, int, int, int, int, int, \ | |||
| cudaStream_t); | |||
| #define cb(DType) \ | |||
| INST(DTypeTrait<DType>::ctype) | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| } // namespace sliding_window_transpose | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * \file dnn/src/cuda/sliding_window_transpose/kernel.cuh | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <cuda_runtime_api.h> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace sliding_window_transpose { | |||
| template <typename T> | |||
| void forward(const T *src, T *dst, | |||
| int N, int C, int IH, int IW, int OH, int OW, | |||
| int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, | |||
| cudaStream_t stream); | |||
| template <typename T> | |||
| void backward(const T *diff, T *grad, | |||
| int N, int C, int IH, int IW, int OH, int OW, | |||
| int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, | |||
| cudaStream_t stream); | |||
| } // namespace sliding_window_transpose | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -73,6 +73,7 @@ | |||
| #include "src/naive/separable_conv/opr_impl.h" | |||
| #include "src/naive/separable_filter/opr_impl.h" | |||
| #include "src/naive/sleep/opr_impl.h" | |||
| #include "src/naive/sliding_window_transpose/opr_impl.h" | |||
| #include "src/naive/split/opr_impl.h" | |||
| #include "src/naive/svd/opr_impl.h" | |||
| #include "src/naive/tensor_remap/opr_impl.h" | |||
| @@ -0,0 +1,141 @@ | |||
| /** | |||
| * \file dnn/src/naive/sliding_window_transpose/opr_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/naive/sliding_window_transpose/opr_impl.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include <cstring> | |||
| namespace megdnn { | |||
| namespace naive { | |||
| template <typename T> | |||
| void SlidingWindowTransposeForwardImpl::exec_internal(_megdnn_tensor_in src, | |||
| _megdnn_tensor_out dst) | |||
| { | |||
| int N = dst.layout.shape[0], C = dst.layout.shape[1], | |||
| IH = dst.layout.shape[2], IW = dst.layout.shape[3]; | |||
| auto sptr = src.ptr<T>(); | |||
| auto dptr = dst.ptr<T>(); | |||
| size_t idx = 0; | |||
| int window_h = static_cast<int>(param().window_h); | |||
| int window_w = static_cast<int>(param().window_w); | |||
| int pad_h = static_cast<int>(param().pad_h); | |||
| int pad_w = static_cast<int>(param().pad_w); | |||
| int stride_h = static_cast<int>(param().stride_h); | |||
| int stride_w = static_cast<int>(param().stride_w); | |||
| int dilate_h = static_cast<int>(param().dilate_h); | |||
| int dilate_w = static_cast<int>(param().dilate_w); | |||
| int equ_window_h = dilate_h * (window_h-1) + 1; | |||
| int equ_window_w = dilate_w * (window_w-1) + 1; | |||
| memset(dptr, 0, sizeof(T) * N*C*IH*IW); | |||
| for (int n = 0; n < N; ++n) | |||
| for (int c = 0; c < C; ++c) | |||
| { | |||
| int ih = -pad_h; | |||
| for (; ih+equ_window_h <= IH+pad_h; ih += stride_h) { | |||
| int iw = -pad_w; | |||
| for (; iw+equ_window_w <= IW+pad_w; iw += stride_w) { | |||
| for (int kh = 0; kh < window_h; ++kh) | |||
| for (int kw = 0; kw < window_w; ++kw) | |||
| { | |||
| int ih2 = ih+dilate_h*kh, iw2 = iw+dilate_w*kw; | |||
| if (ih2 >= 0 && ih2 < IH && iw2 >= 0 && iw2 < IW) { | |||
| dptr[n*C*IH*IW + c*IH*IW + ih2*IW + iw2] += | |||
| sptr[idx*window_h*window_w + kh*window_w + kw]; | |||
| } | |||
| } | |||
| ++idx; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void SlidingWindowTransposeForwardImpl::exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) | |||
| { | |||
| check_exec(src.layout, dst.layout, workspace.size); | |||
| #define cb(DType) \ | |||
| if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| exec_internal<typename DTypeTrait<DType>::ctype>(src, dst); \ | |||
| ); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb); | |||
| #undef cb | |||
| megdnn_assert_internal(0); | |||
| } | |||
| template <typename T> | |||
| void SlidingWindowTransposeBackwardImpl::exec_internal(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad) | |||
| { | |||
| int N = diff.layout.shape[0], C = diff.layout.shape[1], | |||
| IH = diff.layout.shape[2], IW = diff.layout.shape[3]; | |||
| auto sptr = grad.ptr<T>(); | |||
| auto dptr = diff.ptr<T>(); | |||
| size_t idx = 0; | |||
| int window_h = static_cast<int>(param().window_h); | |||
| int window_w = static_cast<int>(param().window_w); | |||
| int pad_h = static_cast<int>(param().pad_h); | |||
| int pad_w = static_cast<int>(param().pad_w); | |||
| int stride_h = static_cast<int>(param().stride_h); | |||
| int stride_w = static_cast<int>(param().stride_w); | |||
| int dilate_h = static_cast<int>(param().dilate_h); | |||
| int dilate_w = static_cast<int>(param().dilate_w); | |||
| int equ_window_h = dilate_h * (window_h-1) + 1; | |||
| int equ_window_w = dilate_w * (window_w-1) + 1; | |||
| for (int n = 0; n < N; ++n) | |||
| for (int c = 0; c < C; ++c) | |||
| { | |||
| int ih = -pad_h; | |||
| for (; ih+equ_window_h <= IH+pad_h; ih += stride_h) { | |||
| int iw = -pad_w; | |||
| for (; iw+equ_window_w <= IW+pad_w; iw += stride_w) { | |||
| for (int kh = 0; kh < window_h; ++kh) | |||
| for (int kw = 0; kw < window_w; ++kw) | |||
| { | |||
| int ih2 = ih+dilate_h*kh, iw2 = iw+dilate_w*kw; | |||
| sptr[idx*window_h*window_w + kh*window_w + kw] = | |||
| ih2 >= 0 && ih2 < IH && | |||
| iw2 >= 0 && iw2 < IW ? | |||
| dptr[n*C*IH*IW + c*IH*IW + ih2*IW + iw2] : 0.0f; | |||
| } | |||
| ++idx; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void SlidingWindowTransposeBackwardImpl::exec(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) | |||
| { | |||
| check_exec(diff.layout, grad.layout, workspace.size); | |||
| #define cb(DType) \ | |||
| if (diff.layout.dtype == DType()) { \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| exec_internal<typename DTypeTrait<DType>::ctype>(diff, grad); \ | |||
| ); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb); | |||
| #undef cb | |||
| megdnn_assert_internal(0); | |||
| } | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * \file dnn/src/naive/sliding_window_transpose/opr_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| namespace megdnn { | |||
| namespace naive { | |||
| class SlidingWindowTransposeForwardImpl: public SlidingWindowTransposeForward { | |||
| public: | |||
| using SlidingWindowTransposeForward::SlidingWindowTransposeForward; | |||
| void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout &, | |||
| const TensorLayout &) override { | |||
| return 0; | |||
| } | |||
| private: | |||
| template <typename T> | |||
| void exec_internal(_megdnn_tensor_in src, | |||
| _megdnn_tensor_out dst); | |||
| }; | |||
| class SlidingWindowTransposeBackwardImpl: public SlidingWindowTransposeBackward { | |||
| public: | |||
| using SlidingWindowTransposeBackward::SlidingWindowTransposeBackward; | |||
| void exec(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout &, | |||
| const TensorLayout &) override { | |||
| return 0; | |||
| } | |||
| private: | |||
| template <typename T> | |||
| void exec_internal(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad); | |||
| }; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,86 @@ | |||
| /** | |||
| * \file dnn/test/common/sliding_window_transpose.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megdnn/basic_types.h" | |||
| #include <cstddef> | |||
| namespace megdnn { | |||
| namespace test { | |||
| namespace sliding_window_transpose { | |||
| struct TestArg { | |||
| param::SlidingWindowTranspose param; | |||
| TensorShape ishape; | |||
| TestArg(param::SlidingWindowTranspose param, TensorShape ishape) | |||
| : param(param), ishape(ishape) {} | |||
| }; | |||
| inline std::vector<TestArg> get_args() { | |||
| std::vector<TestArg> args; | |||
| // clang-format off | |||
| for (uint32_t ih : {25, 96}) | |||
| for (uint32_t iw : {26, 128}) | |||
| for (uint32_t ph : {0, 1}) | |||
| for (uint32_t pw : {0, 1}) | |||
| for (uint32_t sh : {1, 2}) | |||
| for (uint32_t sw : {1, 2}) | |||
| for (uint32_t dh : {1, 2}) | |||
| for (uint32_t dw : {1, 2}) | |||
| for (uint32_t wh : {3, 4}) | |||
| for (uint32_t ww : {3, 4}) { | |||
| unsigned long int oh = (ih + 2 * ph - dh * (wh-1)-1) / sh + 1; | |||
| unsigned long int ow = (iw + 2 * pw - dw * (ww-1)-1) / sw + 1; | |||
| args.emplace_back(param::SlidingWindowTranspose{ih, iw, ph, pw, sh, sw, dh, dw, wh, ww}, | |||
| TensorShape{2, 3, oh, ow, wh, ww}); | |||
| } | |||
| // clang-format on | |||
| // large window case | |||
| args.emplace_back(param::SlidingWindowTranspose{96, 128, 0, 0, 1, 1, 1, 1, 32, 64}, | |||
| TensorShape{2, 3, 65, 65, 32, 64}); | |||
| // // large size | |||
| args.emplace_back(param::SlidingWindowTranspose{28, 24, 0, 0, 1, 1, 1, 1, 1, 1}, | |||
| TensorShape{128, 128, 28, 24, 1, 1}); | |||
| return args; | |||
| } | |||
| inline std::vector<TestArg> get_benchmark_args() { | |||
| std::vector<TestArg> args; | |||
| // clang-format off | |||
| for (uint32_t ph : {0, 1}) | |||
| for (uint32_t pw : {0, 1}) | |||
| for (uint32_t sh : {1, 2}) | |||
| for (uint32_t sw : {1, 2}) | |||
| for (uint32_t dh : {1, 2}) | |||
| for (uint32_t dw : {1, 2}) | |||
| for (uint32_t wh : {3, 4}) | |||
| for (uint32_t ww : {3, 4}) | |||
| for (uint32_t b : {1, 64}) | |||
| for (uint32_t c : {64, 128}) | |||
| for (uint32_t hw : {64, 128}) { | |||
| unsigned long int o_hw = (hw + 2 * ph - dh * (wh-1)-1) / sh + 1; | |||
| args.emplace_back(param::SlidingWindowTranspose{hw, hw, ph, pw, sh, sw, dh, dw, wh, ww}, | |||
| TensorShape{b, c, o_hw, o_hw, wh, ww}); | |||
| } | |||
| // clang-format on | |||
| // large size | |||
| args.emplace_back(param::SlidingWindowTranspose{28, 24, 0, 0, 1, 1, 1, 1, 1, 1}, | |||
| TensorShape{1024, 128, 28, 24, 1, 1}); | |||
| return args; | |||
| } | |||
| } // namespace sliding_window_transpose | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,96 @@ | |||
| /** | |||
| * \file dnn/test/cuda/sliding_window_transpose.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "test/cuda/fixture.h" | |||
| #include "test/common/checker.h" | |||
| #include "test/common/sliding_window_transpose.h" | |||
| #include "test/common/rng.h" | |||
| #include "test/cuda/benchmark.h" | |||
| namespace megdnn { | |||
| namespace test { | |||
| TEST_F(CUDA, SLIDINGWINDOWTRANSPOSE_FORWARD) | |||
| { | |||
| UniformFloatRNG rng(0, 1); | |||
| auto args = sliding_window_transpose::get_args(); | |||
| for (auto &&arg: args) { | |||
| Checker<SlidingWindowTransposeForward> checker(handle_cuda()); | |||
| checker.set_rng(0, &rng); | |||
| checker.set_epsilon(1e-2); | |||
| TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32()); | |||
| TensorLayout olayout; | |||
| { | |||
| auto opr = handle_cuda()->create_operator<SlidingWindowTransposeForward>(); | |||
| opr->param() = arg.param; | |||
| opr->deduce_layout(ilayout, olayout); | |||
| } | |||
| auto set_dtype = [&checker](DType dtype) | |||
| { | |||
| checker.set_dtype(0, dtype). | |||
| set_dtype(1, dtype); | |||
| }; | |||
| set_dtype(dtype::Float32()); | |||
| checker.set_param(arg.param).exec(TensorShapeArray{ | |||
| ilayout, olayout}); | |||
| set_dtype(dtype::Float16()); | |||
| checker.set_param(arg.param).exec(TensorShapeArray{ | |||
| ilayout, olayout}); | |||
| } | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(CUDA, BENCHMARK_SLIDINGWINDOWTRANSPOSE_FORWARD) | |||
| { | |||
| auto args = sliding_window_transpose::get_benchmark_args(); | |||
| for (auto &&arg: args) { | |||
| CUBenchmarker<SlidingWindowTransposeForward> bencher(handle_cuda()); | |||
| bencher.set_param(arg.param).set_dtype(0, dtype::Float32()). | |||
| exec(TensorShapeArray{ | |||
| arg.ishape, {}}); | |||
| } | |||
| } | |||
| #endif | |||
| TEST_F(CUDA, SLIDINGWINDOWTRANSPOSE_BACKWARD) | |||
| { | |||
| UniformFloatRNG rng(0, 1); | |||
| auto args = sliding_window_transpose::get_args(); | |||
| for (auto &&arg: args) { | |||
| Checker<SlidingWindowTransposeBackward> checker(handle_cuda()); | |||
| // checker.set_epsilon(1e-2); | |||
| checker.set_rng(0, &rng); | |||
| TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32()); | |||
| TensorLayout olayout; | |||
| { | |||
| auto opr = handle_cuda()->create_operator<SlidingWindowTranspose>(); | |||
| opr->param() = arg.param; | |||
| opr->deduce_layout(ilayout, olayout); | |||
| } | |||
| auto set_dtype = [&checker](DType dtype) | |||
| { | |||
| checker.set_dtype(0, dtype). | |||
| set_dtype(1, dtype); | |||
| }; | |||
| set_dtype(dtype::Float32()); | |||
| checker.set_param(arg.param).exec(TensorShapeArray{ | |||
| olayout, ilayout}); | |||
| set_dtype(dtype::Float16()); | |||
| checker.set_param(arg.param).exec(TensorShapeArray{ | |||
| olayout, ilayout}); | |||
| } | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * \file dnn/test/naive/sliding_window_transpose.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "test/naive/fixture.h" | |||
| #include "megdnn/oprs/nn.h" | |||
| #include "test/common/checker.h" | |||
| using namespace megdnn; | |||
| using namespace test; | |||
| TEST_F(NAIVE, SlidingWindowTranspose_FORWARD) { | |||
| Checker<SlidingWindowTranspose> checker(handle(), /* check_dispatch */false); | |||
| SlidingWindowTranspose::Param param(3,3,0,0,1,1,1,1,2,2); | |||
| checker.set_param(param).exect( | |||
| Testcase{TensorValue({1, 1, 2, 2, 2, 2}, dtype::Uint8(), | |||
| {0,1,3,4, | |||
| 1,2,4,5, | |||
| 3,4,6,7, | |||
| 4,5,7,8}), {}}, | |||
| Testcase{{}, | |||
| TensorValue({1, 1, 3, 3}, dtype::Uint8(), | |||
| {0,2,2, | |||
| 6,16,10, | |||
| 6,14,8})}); | |||
| param.out_h = 6; | |||
| param.out_w = 7; | |||
| param.pad_h = 1; | |||
| param.pad_w = 1; | |||
| param.stride_h = 2; | |||
| param.stride_w = 2; | |||
| param.dilate_h = 2; | |||
| param.dilate_w = 2; | |||
| param.window_h = 3; | |||
| param.window_w = 3; | |||
| checker.set_param(param).exect( | |||
| Testcase{TensorValue({1, 1, 2, 3, 3, 3}, dtype::Uint8(), | |||
| {0,0,0,0,8,10,0,22,24, | |||
| 0,0,0,8,10,12,22,24,26, | |||
| 0,0,0,10,12,0,24,26,0, | |||
| 0,8,10,0,22,24,0,36,38, | |||
| 8,10,12,22,24,26,36,38,40, | |||
| 10,12,0,24,26,0,38,40,0}), {}}, | |||
| Testcase{{}, | |||
| TensorValue({1, 1, 6, 7}, dtype::Uint8(), | |||
| {0,0,0,0,0,0,0, | |||
| 0,32,0,60,0,48,0, | |||
| 0,0,0,0,0,0,0, | |||
| 0,88,0,144,0,104,0, | |||
| 0,0,0,0,0,0,0, | |||
| 0,72,0,114,0,80,0})}); | |||
| } | |||
| @@ -71,6 +71,7 @@ __all__ = [ | |||
| "resize", | |||
| "sigmoid", | |||
| "sliding_window", | |||
| "sliding_window_transpose", | |||
| "softmax", | |||
| "softplus", | |||
| "sync_batch_norm", | |||
| @@ -1396,6 +1397,60 @@ def sliding_window( | |||
| return output | |||
| def sliding_window_transpose( | |||
| inp: Tensor, | |||
| output_size: Union[int, Tuple[int, int]], | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| ) -> Tensor: | |||
| """ | |||
| Sum over the sliding windows on the corresponding input location. | |||
| Refer to :class:`~.SlidingWindowTranspose` for more information. | |||
| :param inp: input tensor. | |||
| :param output_size: shape of output tensor. | |||
| :param kernel_size: size of the window. | |||
| :param padding: implicit zero padding added on both sides of input. Default: 0 | |||
| :param stride: stride of the window. Default: 1 | |||
| :param dilation: dilation of the window. Default: 1 | |||
| :return: output tensor. | |||
| """ | |||
| output_h, output_w = _pair_nonzero(output_size) | |||
| padding_h, padding_w = _pair(padding) | |||
| stride_h, stride_w = _pair_nonzero(stride) | |||
| dilation_h, dilation_w = _pair_nonzero(dilation) | |||
| window_h, window_w = _pair_nonzero(kernel_size) | |||
| expected_h = ( | |||
| output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1 | |||
| ) // stride_h + 1 | |||
| expected_w = ( | |||
| output_w + 2 * padding_w - dilation_w * (window_w - 1) - 1 | |||
| ) // stride_w + 1 | |||
| assert inp.ndim == 6, "the input dimension of sliding_window_transpose should be 6" | |||
| assert ( | |||
| inp.shape[2] == expected_h and inp.shape[3] == expected_w | |||
| ), "the input shape and output size do not match" | |||
| op = builtin.SlidingWindowTranspose( | |||
| out_h=output_h, | |||
| out_w=output_w, | |||
| pad_h=padding_h, | |||
| pad_w=padding_w, | |||
| stride_h=stride_h, | |||
| stride_w=stride_w, | |||
| dilate_h=dilation_h, | |||
| dilate_w=dilation_w, | |||
| window_h=window_h, | |||
| window_w=window_w, | |||
| ) | |||
| (output,) = apply(op, inp) | |||
| return output | |||
| interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True) | |||
| roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True) | |||
| roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True) | |||
| @@ -34,4 +34,4 @@ from .normalization import GroupNorm, InstanceNorm, LayerNorm | |||
| from .pooling import AvgPool2d, MaxPool2d | |||
| from .quant_dequant import DequantStub, QuantStub | |||
| from .sequential import Sequential | |||
| from .sliding_window import SlidingWindow | |||
| from .sliding_window import SlidingWindow, SlidingWindowTranspose | |||
| @@ -8,7 +8,7 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Tuple, Union | |||
| from ..functional import sliding_window | |||
| from ..functional import sliding_window, sliding_window_transpose | |||
| from .module import Module | |||
| @@ -86,3 +86,87 @@ class SlidingWindow(Module): | |||
| return sliding_window( | |||
| inp, self.kernel_size, self.padding, self.stride, self.dilation | |||
| ) | |||
| class SlidingWindowTranspose(Module): | |||
| r""" | |||
| Opposite opration of SlidingWindow, sum over the sliding windows on the | |||
| corresponding input location. Given an input of the size | |||
| :math:`(N, C, IH, IW, window_h, window_w)` and :attr:`output_size`, the | |||
| output shape would be :math:`(N, C, output\_size_{h}, output\_size_{w})` and the | |||
| arguments must satisfy | |||
| .. math:: | |||
| \text{IH} = \lfloor \frac{\text{output_size}_{h} + 2 * \text{padding}_{h} - | |||
| \text{dilation}_{h} * (\text{kernel_size}_{h} - 1) - 1}{\text{stride}_{h}} + 1 \rfloor | |||
| .. math:: | |||
| \text{IW} = \lfloor \frac{\text{output_size}_{w} + 2 * \text{padding}_{w} - | |||
| \text{dilation}_{w} * (\text{kernel_size}_{w} - 1) - 1}{\text{stride}_{w}} + 1 \rfloor | |||
| For each output location, we have: | |||
| .. math:: | |||
| \text{out}_{n, c, oh, ow} = \sum_{n,c,oh,ow=location(n, c, ih, iw, wh, ww)}\text{src}_{n, c, ih, iw, wh, ww} | |||
| .. math:: | |||
| \text{location}(n, c, ih, iw, wh, ww) &= (n, c, oh+wh, ow+ww) \\ | |||
| \text{where } & oh=-pad_h+ih \times stride_h + (wh-1) \times (dilation_h-1) \\ | |||
| & ow=-pad_w+iw \times stride_w + (ww-1) \times (dilation_w-1) | |||
| :param output_size: the size of the output tensor. | |||
| :param kernel_size: the size of the window to take a max over. | |||
| :param padding: implicit zero padding to be added on both sides. Default: 0 | |||
| :param stride: the stride of the window. Default: 1 | |||
| :param dilation: the dilation of the window. Default: 1 | |||
| Example: | |||
| .. testcode:: | |||
| from megengine import tensor | |||
| import megengine.module as M | |||
| import numpy as np | |||
| inp = tensor(np.arange(20).reshape(1,1,4,5)) | |||
| unfold = M.SlidingWindow(kernel_size=3, padding=0, stride=1, dilation=1) | |||
| fold = M.SlidingWindowTranspose((4,5), kernel_size=3, padding=0, stride=1, dilation=1) | |||
| out = fold(unfold(inp)) | |||
| print(out.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[[[ 0 2 6 6 4] | |||
| [10 24 42 32 18] | |||
| [20 44 72 52 28] | |||
| [15 32 51 36 19]]]] | |||
| """ | |||
| def __init__( | |||
| self, | |||
| output_size: Union[int, Tuple[int, int]], | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| **kwargs | |||
| ): | |||
| super(SlidingWindowTranspose, self).__init__(**kwargs) | |||
| self.output_size = output_size | |||
| self.kernel_size = kernel_size | |||
| self.padding = padding | |||
| self.stride = stride | |||
| self.dilation = dilation | |||
| def forward(self, inp): | |||
| return sliding_window_transpose( | |||
| inp, | |||
| self.output_size, | |||
| self.kernel_size, | |||
| self.padding, | |||
| self.stride, | |||
| self.dilation, | |||
| ) | |||
| @@ -953,3 +953,39 @@ def test_sliding_window(): | |||
| tensor(inp), (wh, ww), padding=(ph, pw), stride=(sh, sw), dilation=(dh, dw) | |||
| ) | |||
| np.testing.assert_equal(gt_out, out.numpy()) | |||
| def test_sliding_window_transpose(): | |||
| N, C, H, W = 2, 3, 7, 8 | |||
| ph, pw = 1, 2 | |||
| sh, sw = 2, 1 | |||
| wh, ww = 3, 2 | |||
| dh, dw = 1, 3 | |||
| s = lambda i, p, s, d, w: (i + p * 2 - (w - 1) * d - 1) // s + 1 | |||
| inp = np.random.normal( | |||
| size=(N, C, s(H, ph, sh, dh, wh), s(W, pw, sw, dw, ww), wh, ww) | |||
| ).astype(np.float32) | |||
| gt_out = np.zeros((N, C, H, W), dtype=np.float32) | |||
| for n, c in itertools.product(*map(range, inp.shape[:2])): | |||
| oh = 0 | |||
| for ih in range(-ph, H + ph - dh * (wh - 1), sh): | |||
| ow = 0 | |||
| for iw in range(-pw, W + pw - dw * (ww - 1), sw): | |||
| for kh, kw in itertools.product(*map(range, inp.shape[-2:])): | |||
| ih2 = ih + dh * kh | |||
| iw2 = iw + dw * kw | |||
| if ih2 >= 0 and ih2 < H and iw2 >= 0 and iw2 < W: | |||
| gt_out[n, c, ih2, iw2] += inp[n, c, oh, ow, kh, kw] | |||
| ow += 1 | |||
| oh += 1 | |||
| out = F.sliding_window_transpose( | |||
| tensor(inp), | |||
| (H, W), | |||
| (wh, ww), | |||
| padding=(ph, pw), | |||
| stride=(sh, sw), | |||
| dilation=(dh, dw), | |||
| ) | |||
| np.testing.assert_equal(gt_out, out.numpy()) | |||
| @@ -35,6 +35,8 @@ | |||
| #include "megbrain/opr/tensor_gen.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "megbrain/opr/dnn/images2neibs.h" | |||
| #include "megbrain/opr/dnn/sliding_window_transpose.h" | |||
| #include "../op_trait.h" | |||
| @@ -658,4 +660,17 @@ OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace lsq | |||
| } // namespace | |||
| } // namespace mgb::imperative | |||
| namespace { namespace sliding_window_transpose { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const SlidingWindowTranspose&>(def); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::SlidingWindowTranspose::make(inputs[0], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // sliding_window_transpose | |||
| } // namespace mgb::imperative | |||
| @@ -81,6 +81,8 @@ def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, Executio | |||
| def Images2Neibs : MgbHashableOp<"Images2Neibs", [Images2NeibsParam]>; | |||
| def SlidingWindowTranspose : MgbHashableOp<"SlidingWindowTranspose", [SlidingWindowTransposeParam]>; | |||
| def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; | |||
| def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; | |||
| @@ -16,6 +16,8 @@ | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/opr/dnn/fake_quant.h" | |||
| #include "megbrain/opr/dnn/images2neibs.h" | |||
| #include "megbrain/opr/dnn/sliding_window_transpose.h" | |||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/dnn/lrn.h" | |||
| #include "megbrain/opr/dnn/lsq.h" | |||
| @@ -531,6 +533,9 @@ MGB_SEREG_OPR(ConvolutionBackwardFilterV2, 0); | |||
| MGB_SEREG_OPR(Images2Neibs, 1); | |||
| MGB_SEREG_OPR(Images2NeibsBackward, 2); | |||
| MGB_SEREG_OPR(SlidingWindowTranspose, 1); | |||
| MGB_SEREG_OPR(SlidingWindowTransposeBackward, 2); | |||
| using LocalV2 = Local; | |||
| using LocalBackwardDataV2 = LocalBackwardData; | |||
| using LocalBackwardFilterV2 = LocalBackwardFilter; | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * \file src/opr/impl/dnn/sliding_window_transpose.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/opr/dnn/sliding_window_transpose.h" | |||
| #include "megbrain/graph/grad_impl.h" | |||
| #include "../internal/megdnn_opr_wrapper.inl" | |||
| using namespace mgb; | |||
| using namespace opr; | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(SlidingWindowTransposeForward); | |||
| MEGDNN_OPR_INIT1(SlidingWindowTransposeForward, "sliding_window_transpose") | |||
| #if MGB_ENABLE_GRAD | |||
| MGB_IMPL_OPR_GRAD(SlidingWindowTransposeForward) { | |||
| mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]); | |||
| return SlidingWindowTransposeBackward::make( | |||
| out_grad[0], opr.input(0), opr.param()).node(); | |||
| } | |||
| #endif | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(SlidingWindowTransposeBackward); | |||
| MEGDNN_OPR_INIT2(SlidingWindowTransposeBackward, "sliding_window_transpose_grad", 1, false); | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * \file src/opr/include/megbrain/opr/dnn/sliding_window_transpose.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megdnn/oprs.h" | |||
| namespace mgb { | |||
| namespace opr { | |||
| MGB_DEFINE_OPR_CLASS(SlidingWindowTransposeForward, | |||
| intl::MegDNNOprWrapperFwd<megdnn::SlidingWindowTransposeForward>) // { | |||
| public: | |||
| SlidingWindowTransposeForward(VarNode *src, | |||
| const Param ¶m, | |||
| const OperatorNodeConfig &config); | |||
| static SymbolVar make(SymbolVar src, | |||
| const Param ¶m = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| }; | |||
| using SlidingWindowTranspose = SlidingWindowTransposeForward; | |||
| MGB_DEFINE_OPR_CLASS(SlidingWindowTransposeBackward, | |||
| intl::MegDNNOprWrapperBwd<megdnn::SlidingWindowTransposeBackward>) // { | |||
| public: | |||
| SlidingWindowTransposeBackward(VarNode *diff, VarNode *src_for_shape, | |||
| const Param ¶m, | |||
| const OperatorNodeConfig &config); | |||
| static SymbolVar make(SymbolVar diff, SymbolVar src_for_shape, | |||
| const Param ¶m = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| }; | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,64 @@ | |||
| /** | |||
| * \file src/opr/test/dnn/sliding_window_transpose.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/test/helper.h" | |||
| #include "megbrain/test/autocheck.h" | |||
| #include "megbrain/test/megdnn_helper.h" | |||
| #include "megbrain/opr/dnn/sliding_window_transpose.h" | |||
| #include "megdnn/oprs.h" | |||
| using namespace mgb; | |||
| TEST(TestOprDNN, SlidingWindowTranspose) { | |||
| using Checker = AutoOprChecker<1, 1>; | |||
| opr::SlidingWindowTranspose::Param param; | |||
| param.pad_h = 1; | |||
| param.pad_w = 2; | |||
| param.stride_w = 2; | |||
| param.window_h = 4; | |||
| param.dilate_h = 2; | |||
| unsigned long ih = 16, iw = 15; | |||
| unsigned long oh = (ih + 2 * param.pad_h - param.dilate_h * (param.window_h-1)-1) / param.stride_h + 1; | |||
| unsigned long ow = (iw + 2 * param.pad_w - param.dilate_w * (param.window_w-1)-1) / param.stride_w + 1; | |||
| param.out_h = ih; | |||
| param.out_w = iw; | |||
| auto make_graph = [&](const Checker::SymInpArray &inputs) -> | |||
| Checker::SymOutArray { | |||
| return {opr::SlidingWindowTranspose::make(inputs[0], param)}; | |||
| }; | |||
| auto fwd = [&](Checker::NumOutArray &dest, Checker::NumInpArray inp) { | |||
| auto opr = megdnn_naive_handle()-> | |||
| create_operator<megdnn::SlidingWindowTranspose>(); | |||
| opr->param() = param; | |||
| TensorLayout dest_layout; | |||
| opr->deduce_layout(inp[0]->layout(), dest_layout); | |||
| std::vector<dt_byte> workspace( | |||
| opr->get_workspace_in_bytes(inp[0]->layout(), dest_layout)); | |||
| dest[0].dtype(dtype::Float32()). | |||
| comp_node(inp[0]->comp_node()).resize(dest_layout); | |||
| opr->exec(inp[0]->as_megdnn(), dest[0].as_megdnn(), | |||
| {workspace.data(), workspace.size()}); | |||
| }; | |||
| Checker::RunOptions opt; | |||
| opt.numdiff_eps = 1; | |||
| Checker checker{make_graph, fwd}; | |||
| checker. | |||
| run({TensorShape{2, 3, oh, ow, param.window_h, param.window_w}}, opt). | |||
| run({TensorShape{4, 5, oh, ow, param.window_h, param.window_w}}, opt). | |||
| run({TensorShape{3, 2, oh, ow, param.window_h, param.window_w}}, opt); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -112,6 +112,7 @@ union OperatorParam { | |||
| param.PoissonRNG = 78, | |||
| param.PermutationRNG = 79, | |||
| param.BetaRNG = 80, | |||
| param.SlidingWindowTranspose = 81, | |||
| } | |||
| table Operator { | |||