| @@ -192,6 +192,87 @@ class ReduceForward: public OperatorBase { | |||||
| }; | }; | ||||
| using Reduce = ReduceForward; | using Reduce = ReduceForward; | ||||
| class CorrelationBase : public OperatorBase { | |||||
| DEF_OPR_IMPL_CTOR(CorrelationBase, OperatorBase); | |||||
| DEF_OPR_PARAM(Correlation); | |||||
| protected: | |||||
| void deduce_layout_fwd(const TensorLayout& data1, const TensorLayout& data2, | |||||
| TensorLayout& dst); | |||||
| void check_layout_fwd(const TensorLayout& data1, const TensorLayout& data2, | |||||
| const TensorLayout& dst); | |||||
| }; | |||||
| class CorrelationForward : public CorrelationBase { | |||||
| DEF_OPR_IMPL(CorrelationForward, CorrelationBase, 2, 1); | |||||
| public: | |||||
| /** | |||||
| * \param[in] data1 (n, c, ih, iw) | |||||
| * \param[in] data2 (n, c, ih, iw) | |||||
| * \param[out] dst (n, q, oh, ow), q is the number of neighborhood | |||||
| * */ | |||||
| virtual void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& data1, const TensorLayout& data2, | |||||
| TensorLayout& dst); | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| const TensorLayout& dst) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout& data1, const TensorLayout& data2, | |||||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||||
| }; | |||||
| using Correlation = CorrelationForward; | |||||
| class CorrelationBackwardData1 : public CorrelationBase { | |||||
| DEF_OPR_IMPL(CorrelationBackwardData1, CorrelationBase, 3, 1); | |||||
| public: | |||||
| /** | |||||
| * \param[in] diff the backpropagated gradient wrt. dst | |||||
| * \param[in] data1 the `data1' parameter in CorrelationForward::exec | |||||
| * \param[in] data2 the `data2' parameter in CorrelationForward::exec | |||||
| * \param[out] grad1 the backpropagated gradient wrt. data1 | |||||
| */ | |||||
| virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out grad1, _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1, | |||||
| const TensorLayout& data2, TensorLayout& dst); | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||||
| const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| const TensorLayout& grad1) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, | |||||
| const TensorLayout& grad1, size_t workspace_in_bytes); | |||||
| }; | |||||
| class CorrelationBackwardData2 : public CorrelationBase { | |||||
| DEF_OPR_IMPL(CorrelationBackwardData2, CorrelationBase, 3, 1); | |||||
| public: | |||||
| /** | |||||
| * \param[in] diff the backpropagated gradient wrt. dst | |||||
| * \param[in] data1 the `data1' parameter in CorrelationForward::exec | |||||
| * \param[in] data2 the `data2' parameter in CorrelationForward::exec | |||||
| * \param[out] grad2 the backpropagated gradient wrt. data2 | |||||
| */ | |||||
| virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out grad2, _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1, | |||||
| const TensorLayout& data2, TensorLayout& dst); | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||||
| const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| const TensorLayout& grad2) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, | |||||
| const TensorLayout& grad2, size_t workspace_in_bytes); | |||||
| }; | |||||
| class CumsumForward: public OperatorBase { | class CumsumForward: public OperatorBase { | ||||
| DEF_OPR_PARAM(Cumsum); | DEF_OPR_PARAM(Cumsum); | ||||
| DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1); | DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1); | ||||
| @@ -1053,6 +1053,16 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||||
| 'sample_width', '2') | 'sample_width', '2') | ||||
| ) | ) | ||||
| (pdef('Correlation'). | |||||
| add_enum_alias('Format', 'ConvolutionV0'). | |||||
| add_fields('uint32', 'kernel_size', '1'). | |||||
| add_fields('uint32', 'max_displacement', '1'). | |||||
| add_fields('uint32', 'stride1', '1'). | |||||
| add_fields('uint32', 'stride2', '1'). | |||||
| add_fields('uint32', 'pad_size', '0'). | |||||
| add_fields('bool', 'is_multiply', 'true') | |||||
| ) | |||||
| (pdef('DeformablePSROIPooling'). | (pdef('DeformablePSROIPooling'). | ||||
| add_fields('bool', 'no_trans', 'true'). | add_fields('bool', 'no_trans', 'true'). | ||||
| add_fields('float32', 'spatial_scale', 1, | add_fields('float32', 'spatial_scale', 1, | ||||
| @@ -0,0 +1,132 @@ | |||||
| /** | |||||
| * \file dnn/src/common/correlation.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/common/utils.h" | |||||
| namespace megdnn { | |||||
| void CorrelationBase::deduce_layout_fwd(const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| TensorLayout& dst) { | |||||
| megdnn_assert_contiguous(data1); | |||||
| megdnn_assert_contiguous(data2); | |||||
| megdnn_assert_contiguous(dst); | |||||
| auto errmsg = [&]() { | |||||
| return megdnn_layout_msg(data1) + ", " + megdnn_layout_msg(data2) + | |||||
| ", " + megdnn_layout_msg(dst); | |||||
| }; | |||||
| MEGDNN_MARK_USED_VAR(errmsg); | |||||
| using Format = CorrelationBase::Param::Format; | |||||
| megdnn_assert(param().format == Format::NCHW); | |||||
| auto data1_dtype = data1.dtype, data2_dtype = data2.dtype; | |||||
| megdnn_assert(data1_dtype == data2_dtype && | |||||
| data1_dtype.category() == DTypeCategory::FLOAT); | |||||
| megdnn_assert(data1.ndim == 4_z, "%s", errmsg().c_str()); | |||||
| megdnn_assert(data2.ndim == 4_z, "%s", errmsg().c_str()); | |||||
| uint32_t pad_size = param().pad_size; | |||||
| uint32_t kernel_size = param().kernel_size; | |||||
| uint32_t stride1 = param().stride1; | |||||
| uint32_t stride2 = param().stride2; | |||||
| uint32_t max_displacement = param().max_displacement; | |||||
| int paddedbottomheight = data1[2] + 2 * pad_size; | |||||
| int paddedbottomwidth = data1[3] + 2 * pad_size; | |||||
| uint32_t kernel_radius = (kernel_size - 1) / 2; | |||||
| uint32_t border_size = max_displacement + kernel_radius; | |||||
| uint32_t top_width = | |||||
| ceil(static_cast<float>(paddedbottomwidth - border_size * 2) / | |||||
| static_cast<float>(stride1)); | |||||
| uint32_t top_height = | |||||
| ceil(static_cast<float>(paddedbottomheight - border_size * 2) / | |||||
| static_cast<float>(stride1)); | |||||
| uint32_t neighborhood_grid_radius = max_displacement / stride2; | |||||
| uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
| uint32_t top_channels = neighborhood_grid_width * neighborhood_grid_width; | |||||
| megdnn_assert(top_width >= 1 && top_height >= 1); | |||||
| dst = TensorLayout{{data1[0], top_channels, top_height, top_width}, | |||||
| data1.dtype}; | |||||
| } | |||||
| void CorrelationBase::check_layout_fwd(const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| const TensorLayout& dst) { | |||||
| TensorLayout dst_expected; | |||||
| megdnn_assert_eq_dtype(data1, dst); | |||||
| megdnn_assert_eq_shape(data1, data2); | |||||
| deduce_layout_fwd(data1, data2, dst_expected); | |||||
| megdnn_assert_eq_shape(dst_expected, dst); | |||||
| } | |||||
| void CorrelationForward::deduce_layout(const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| TensorLayout& dst) { | |||||
| deduce_layout_fwd(data1, data2, dst); | |||||
| } | |||||
| void CorrelationForward::check_exec(const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_in_bytes) { | |||||
| check_layout_fwd(data1, data2, dst); | |||||
| auto required_workspace_in_bytes = | |||||
| get_workspace_in_bytes(data1, data2, dst); | |||||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
| } | |||||
| void CorrelationBackwardData1::check_exec(const TensorLayout& diff, | |||||
| const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| const TensorLayout& grad1, | |||||
| size_t workspace_in_bytes) { | |||||
| check_layout_fwd(grad1, data2, diff); | |||||
| megdnn_assert_eq_shape(data1, data2); | |||||
| auto required_workspace_in_bytes = | |||||
| get_workspace_in_bytes(diff, data1, data2, grad1); | |||||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
| } | |||||
| void CorrelationBackwardData2::check_exec(const TensorLayout& diff, | |||||
| const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| const TensorLayout& grad2, | |||||
| size_t workspace_in_bytes) { | |||||
| check_layout_fwd(data1, grad2, diff); | |||||
| megdnn_assert_eq_shape(data1, data2); | |||||
| auto required_workspace_in_bytes = | |||||
| get_workspace_in_bytes(diff, data1, data2, grad2); | |||||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
| } | |||||
| void CorrelationBackwardData2::deduce_layout(const TensorLayout& diff, | |||||
| const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| TensorLayout& grad) { | |||||
| megdnn_assert_eq_shape(data1, data2); | |||||
| check_layout_fwd(data1, data2, diff); | |||||
| grad = data2; | |||||
| } | |||||
| void CorrelationBackwardData1::deduce_layout(const TensorLayout& diff, | |||||
| const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| TensorLayout& grad) { | |||||
| megdnn_assert_eq_shape(data1, data2); | |||||
| check_layout_fwd(data1, data2, diff); | |||||
| grad = data1; | |||||
| } | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -194,6 +194,9 @@ private: | |||||
| cb(LocalShareBackwardFilter) \ | cb(LocalShareBackwardFilter) \ | ||||
| cb(ROIAlignForward) \ | cb(ROIAlignForward) \ | ||||
| cb(ROIAlignBackward) \ | cb(ROIAlignBackward) \ | ||||
| cb(CorrelationForward) \ | |||||
| cb(CorrelationBackwardData1) \ | |||||
| cb(CorrelationBackwardData2) \ | |||||
| cb(BatchConvBiasForward) \ | cb(BatchConvBiasForward) \ | ||||
| cb(Remap) \ | cb(Remap) \ | ||||
| cb(RemapBackwardData) \ | cb(RemapBackwardData) \ | ||||
| @@ -54,6 +54,9 @@ DEF(BNForward, 8, true, true); | |||||
| DEF(BNBackward, 8, true, false); | DEF(BNBackward, 8, true, false); | ||||
| DEF(ROIPoolingForward, 4, true, false); | DEF(ROIPoolingForward, 4, true, false); | ||||
| DEF(ROIPoolingBackward, 5, true, false); | DEF(ROIPoolingBackward, 5, true, false); | ||||
| DEF(CorrelationForward, 3, true, true); | |||||
| DEF(CorrelationBackwardData1, 4, true, true); | |||||
| DEF(CorrelationBackwardData2, 4, true, true); | |||||
| DEF(WarpPerspectiveForward, 3, true, false); | DEF(WarpPerspectiveForward, 3, true, false); | ||||
| DEF(WarpPerspectiveBackwardData, 3, true, false); | DEF(WarpPerspectiveBackwardData, 3, true, false); | ||||
| DEF(WarpPerspectiveBackwardMat, 4, true, false); | DEF(WarpPerspectiveBackwardMat, 4, true, false); | ||||
| @@ -0,0 +1,371 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/roi_align/roi_align.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/correlation/correlation_cuda.cuh" | |||||
| #include <cfloat> | |||||
| #include "megdnn/dtype.h" | |||||
| #include "src/cuda/query_blocksize.cuh" | |||||
| #include "src/cuda/utils.cuh" | |||||
| #define ROUND_OFF 50000 | |||||
| using namespace megdnn; | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace correlation { | |||||
| #define CUDA_KERNEL_LOOP(vtid, vthreads) \ | |||||
| for (int vtid = blockIdx.x * blockDim.x + threadIdx.x; vtid < vthreads; \ | |||||
| vtid += blockDim.x * gridDim.x) | |||||
| template <typename T> | |||||
| __global__ void forward_kernel(const int nthreads, const T* data1, | |||||
| const T* data2, T* dst, const int bchannels, | |||||
| const int bheight, const int bwidth, | |||||
| const int tchannels, const int theight, | |||||
| const int twidth, const int kernel_size, | |||||
| const int max_displacement, const int stride1, | |||||
| const int stride2, const int pad_size, | |||||
| const bool is_multiply) { | |||||
| CUDA_KERNEL_LOOP(idx, nthreads) { | |||||
| int kernel_radius = (kernel_size - 1) / 2; | |||||
| int neighborhood_grid_radius = max_displacement / stride2; | |||||
| int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
| int x = idx % twidth; | |||||
| int y = (idx / twidth) % theight; | |||||
| int c = (idx / twidth / theight) % tchannels; | |||||
| int n = idx / twidth / theight / tchannels; | |||||
| // get src center position in image1 | |||||
| int x1 = x * stride1 + kernel_radius + max_displacement - pad_size; | |||||
| int y1 = y * stride1 + kernel_radius + max_displacement - pad_size; | |||||
| // get offset of center in image2 | |||||
| int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * | |||||
| stride2; | |||||
| int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * | |||||
| stride2; | |||||
| int x2 = x1 + s2o; | |||||
| int y2 = y1 + s2p; | |||||
| // compute kernel correlation | |||||
| T sum = T(0.f); | |||||
| for (int i = -kernel_radius; i <= kernel_radius; i++) { | |||||
| for (int j = -kernel_radius; j <= kernel_radius; j++) { | |||||
| int in_x1 = x1 + i; | |||||
| int in_y1 = y1 + j; | |||||
| int in_x2 = x2 + i; | |||||
| int in_y2 = y2 + j; | |||||
| for (int channel = 0; channel < bchannels; channel++) { | |||||
| T tmp1 = T(0.f); | |||||
| T tmp2 = T(0.f); | |||||
| if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 && | |||||
| in_y1 < bheight) { | |||||
| int idx1 = | |||||
| ((n * bchannels + channel) * bheight + in_y1) * | |||||
| bwidth + | |||||
| in_x1; | |||||
| tmp1 = data1[idx1]; | |||||
| } | |||||
| if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 && | |||||
| in_y2 < bheight) { | |||||
| int idx2 = | |||||
| ((n * bchannels + channel) * bheight + in_y2) * | |||||
| bwidth + | |||||
| in_x2; | |||||
| tmp2 = data2[idx2]; | |||||
| } | |||||
| if (is_multiply) { | |||||
| sum += tmp1 * tmp2; | |||||
| } else { | |||||
| sum += fabsf(tmp1 - tmp2); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| const int sumelems = | |||||
| (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
| dst[idx] = sum / sumelems; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void backward_kernel_data1( | |||||
| const int nthreads, const T* diff, const T* data1, const T* data2, | |||||
| T* grad1, const int bchannels, const int bheight, const int bwidth, | |||||
| const int tchannels, const int theight, const int twidth, | |||||
| const int kernel_size, const int max_displacement, const int stride1, | |||||
| const int stride2, const int pad_size, const bool is_multiply) { | |||||
| CUDA_KERNEL_LOOP(idx, nthreads) { | |||||
| int kernel_radius = (kernel_size - 1) / 2; | |||||
| int neighborhood_grid_radius = max_displacement / stride2; | |||||
| int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
| int x = idx % bwidth; | |||||
| int y = (idx / bwidth) % bheight; | |||||
| int c = (idx / bwidth / bheight) % bchannels; | |||||
| int n = idx / bwidth / bheight / bchannels; | |||||
| T tmp1 = data1[idx]; | |||||
| // Get X,Y ranges and clamp | |||||
| // round_off is a trick to enable integer division with ceil, even for | |||||
| // negative numbers We use a large offset, for the inner part not to | |||||
| // become negative. | |||||
| const int round_off = ROUND_OFF; | |||||
| const int round_off_s1 = stride1 * round_off; | |||||
| // we show cal the x_min,y_min,x_max,y_max of diff for grad1(x,y) | |||||
| // for diff_x_min, diff_y_min, x,y at the position of right-down | |||||
| // ceil (l - 2*kernel_radius - max_displacement + pad_size) / stride1 | |||||
| int xmin = (x + pad_size - 2 * kernel_radius - max_displacement + | |||||
| round_off_s1 - 1) / | |||||
| stride1 + | |||||
| 1 - round_off; | |||||
| int ymin = (y + pad_size - 2 * kernel_radius - max_displacement + | |||||
| round_off_s1 - 1) / | |||||
| stride1 + | |||||
| 1 - round_off; | |||||
| // floor (l - max_displacement + pad_size) / stride1 | |||||
| int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 - | |||||
| round_off; | |||||
| int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 - | |||||
| round_off; | |||||
| T sum = T(0.f); | |||||
| if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||||
| (ymin <= theight - 1)) { | |||||
| xmin = max(0, xmin); | |||||
| xmax = min(twidth - 1, xmax); | |||||
| ymin = max(0, ymin); | |||||
| ymax = min(theight - 1, ymax); | |||||
| for (int p = -neighborhood_grid_radius; | |||||
| p <= neighborhood_grid_radius; p++) { | |||||
| for (int o = -neighborhood_grid_radius; | |||||
| o <= neighborhood_grid_radius; o++) { | |||||
| // Get bottom1 data: | |||||
| int s2o = stride2 * o; | |||||
| int s2p = stride2 * p; | |||||
| int x2 = x + s2o, y2 = y + s2p; | |||||
| int idx2 = | |||||
| ((n * bchannels + c) * bheight + y2) * bwidth + x2; | |||||
| T tmp2 = T(0.f); | |||||
| if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) { | |||||
| tmp2 = data2[idx2]; | |||||
| } | |||||
| int op = (p + neighborhood_grid_radius) * | |||||
| neighborhood_grid_width + | |||||
| (o + neighborhood_grid_radius); | |||||
| int diff_channels_offset = (n * tchannels + op); | |||||
| for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||||
| for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||||
| int idxtopdiff = | |||||
| (diff_channels_offset * theight + diff_y) * | |||||
| twidth + | |||||
| diff_x; | |||||
| if (is_multiply) { | |||||
| sum += diff[idxtopdiff] * tmp2; | |||||
| } else { | |||||
| T sign = (tmp1 >= tmp2) ? T(1.f) : T(-1.f); | |||||
| sum += diff[idxtopdiff] * sign; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| const int sumelems = | |||||
| (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
| grad1[idx] = sum / sumelems; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void backward_kernel_data2( | |||||
| const int nthreads, const T* diff, const T* data1, const T* data2, | |||||
| T* grad2, const int bchannels, const int bheight, const int bwidth, | |||||
| const int tchannels, const int theight, const int twidth, | |||||
| const int kernel_size, const int max_displacement, const int stride1, | |||||
| const int stride2, const int pad_size, const bool is_multiply) { | |||||
| CUDA_KERNEL_LOOP(idx, nthreads) { | |||||
| int kernel_radius = (kernel_size - 1) / 2; | |||||
| int neighborhood_grid_radius = max_displacement / stride2; | |||||
| int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
| int x = idx % bwidth; | |||||
| int y = (idx / bwidth) % bheight; | |||||
| int c = (idx / bwidth / bheight) % bchannels; | |||||
| int n = idx / bwidth / bheight / bchannels; | |||||
| T tmp2 = data2[idx]; | |||||
| T sum = T(0.f); | |||||
| for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; | |||||
| p++) { | |||||
| for (int o = -neighborhood_grid_radius; | |||||
| o <= neighborhood_grid_radius; o++) { | |||||
| int s2o = o * stride2; | |||||
| int s2p = p * stride2; | |||||
| int x1 = x - s2o; | |||||
| int y1 = y - s2p; | |||||
| const int round_off = ROUND_OFF; | |||||
| const int round_off_s1 = stride1 * round_off; | |||||
| int xmin = (x1 + pad_size - 2 * kernel_radius - | |||||
| max_displacement + round_off_s1 - 1) / | |||||
| stride1 + | |||||
| 1 - round_off; | |||||
| int ymin = (y1 + pad_size - 2 * kernel_radius - | |||||
| max_displacement + round_off_s1 - 1) / | |||||
| stride1 + | |||||
| 1 - round_off; | |||||
| int xmax = (x1 + pad_size - max_displacement + round_off_s1) / | |||||
| stride1 - | |||||
| round_off; | |||||
| int ymax = (y1 + pad_size - max_displacement + round_off_s1) / | |||||
| stride1 - | |||||
| round_off; | |||||
| if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||||
| (ymin <= theight - 1)) { | |||||
| xmin = max(0, xmin); | |||||
| xmax = min(twidth - 1, xmax); | |||||
| ymin = max(0, ymin); | |||||
| ymax = min(theight - 1, ymax); | |||||
| int idx1 = | |||||
| ((n * bchannels + c) * bheight + y1) * bwidth + x1; | |||||
| T tmp1 = T(0.f); | |||||
| if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) { | |||||
| tmp1 = data1[idx1]; | |||||
| } | |||||
| int op = (p + neighborhood_grid_radius) * | |||||
| neighborhood_grid_width + | |||||
| (o + neighborhood_grid_radius); | |||||
| int diff_channels_offset = (n * tchannels + op); | |||||
| for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||||
| for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||||
| int idxtopdiff = | |||||
| (diff_channels_offset * theight + diff_y) * | |||||
| twidth + | |||||
| diff_x; | |||||
| if (is_multiply) { | |||||
| sum += diff[idxtopdiff] * tmp1; | |||||
| } else { | |||||
| T sign = (tmp1 >= tmp2) ? T(-1.f) : T(1.f); | |||||
| sum += diff[idxtopdiff] * sign; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| const int sumelems = | |||||
| (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
| grad2[idx] = sum / sumelems; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst, | |||||
| const int bchannels, const int bheight, const int bwidth, | |||||
| const int tchannels, const int theight, const int twidth, | |||||
| const int kernel_size, const int max_displacement, | |||||
| const int stride1, const int stride2, const int pad_size, | |||||
| const bool is_multiply, cudaStream_t stream) { | |||||
| int threads_block = query_blocksize_for_kernel(forward_kernel<T>); | |||||
| forward_kernel<T> | |||||
| <<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>( | |||||
| nthreads, data1, data2, dst, bchannels, bheight, bwidth, | |||||
| tchannels, theight, twidth, kernel_size, max_displacement, | |||||
| stride1, stride2, pad_size, is_multiply); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| template <typename T> | |||||
| void backward_proxy_data1(const int nthreads, const T* diff, const T* data1, | |||||
| const T* data2, T* grad1, const int bchannels, | |||||
| const int bheight, const int bwidth, | |||||
| const int tchannels, const int theight, | |||||
| const int twidth, const int kernel_size, | |||||
| const int max_displacement, const int stride1, | |||||
| const int stride2, const int pad_size, | |||||
| const bool is_multiply, cudaStream_t stream) { | |||||
| int threads_block = query_blocksize_for_kernel(backward_kernel_data1<T>); | |||||
| backward_kernel_data1<T> | |||||
| <<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>( | |||||
| nthreads, diff, data1, data2, grad1, bchannels, bheight, | |||||
| bwidth, tchannels, theight, twidth, kernel_size, | |||||
| max_displacement, stride1, stride2, pad_size, is_multiply); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| template <typename T> | |||||
| void backward_proxy_data2(const int nthreads, const T* diff, const T* data1, | |||||
| const T* data2, T* grad2, const int bchannels, | |||||
| const int bheight, const int bwidth, | |||||
| const int tchannels, const int theight, | |||||
| const int twidth, const int kernel_size, | |||||
| const int max_displacement, const int stride1, | |||||
| const int stride2, const int pad_size, | |||||
| const bool is_multiply, cudaStream_t stream) { | |||||
| int threads_block = query_blocksize_for_kernel(backward_kernel_data2<T>); | |||||
| backward_kernel_data2<T> | |||||
| <<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>( | |||||
| nthreads, diff, data1, data2, grad2, bchannels, bheight, | |||||
| bwidth, tchannels, theight, twidth, kernel_size, | |||||
| max_displacement, stride1, stride2, pad_size, is_multiply); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| #define INST(T) \ | |||||
| template void forward_proxy<T>( \ | |||||
| const int, const T*, const T*, T* dst, const int, const int, \ | |||||
| const int, const int, const int, const int, const int, const int, \ | |||||
| const int, const int, const int, const bool, cudaStream_t); \ | |||||
| template void backward_proxy_data1<T>( \ | |||||
| const int, const T*, const T*, const T*, T*, const int, const int, \ | |||||
| const int, const int, const int, const int, const int, const int, \ | |||||
| const int, const int, const int, const bool, cudaStream_t); \ | |||||
| template void backward_proxy_data2<T>( \ | |||||
| const int, const T*, const T*, const T*, T*, const int, const int, \ | |||||
| const int, const int, const int, const int, const int, const int, \ | |||||
| const int, const int, const int, const bool, cudaStream_t); | |||||
| INST(dt_float32) | |||||
| INST(dt_float16) | |||||
| INST(dt_bfloat16) | |||||
| #undef INST | |||||
| } // namespace roi_align | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/correlation/correlation.cuh | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <cuda_runtime_api.h> | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace correlation { | |||||
| template <typename T> | |||||
| void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst, | |||||
| const int bchannels, const int bheight, const int bwidth, | |||||
| const int tchannels, const int theight, const int twidth, | |||||
| const int kernel_size, const int max_displacement, | |||||
| const int stride1, const int stride2, const int pad_size, | |||||
| const bool is_multiply, cudaStream_t stream); | |||||
| template <typename T> | |||||
| void backward_proxy_data1(const int nthreads, const T* diff, const T* data1, | |||||
| const T* data2, T* grad1, const int bchannels, | |||||
| const int bheight, const int bwidth, | |||||
| const int tchannels, const int theight, | |||||
| const int twidth, const int kernel_size, | |||||
| const int max_displacement, const int stride1, | |||||
| const int stride2, const int pad_size, | |||||
| const bool is_multiply, cudaStream_t stream); | |||||
| template <typename T> | |||||
| void backward_proxy_data2(const int nthreads, const T* diff, const T* data1, | |||||
| const T* data2, T* grad2, const int bchannels, | |||||
| const int bheight, const int bwidth, | |||||
| const int tchannels, const int theight, | |||||
| const int twidth, const int kernel_size, | |||||
| const int max_displacement, const int stride1, | |||||
| const int stride2, const int pad_size, | |||||
| const bool is_multiply, cudaStream_t stream); | |||||
| } // namespace correlation | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,129 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/correlation/opr_impl.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/correlation/opr_impl.h" | |||||
| #include "src/cuda/correlation/correlation_cuda.cuh" | |||||
| #include "src/cuda/utils.h" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| void CorrelationForwardImpl::exec(_megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(data1.layout, data2.layout, dst.layout, workspace.size); | |||||
| auto p = param(); | |||||
| auto stream = cuda_stream(handle()); | |||||
| int nthreads = dst.layout.total_nr_elems(); | |||||
| int stride1 = p.stride1; | |||||
| int stride2 = p.stride2; | |||||
| int kernel_size = p.kernel_size; | |||||
| int max_displacement = p.max_displacement; | |||||
| int pad_size = p.pad_size; | |||||
| bool is_multiply = p.is_multiply; | |||||
| int tchannels = dst.layout[1]; | |||||
| int theight = dst.layout[2], twidth = dst.layout[3]; | |||||
| int bchannels = data1.layout[1]; | |||||
| int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||||
| using namespace ::megdnn::cuda::correlation; | |||||
| #define cb(DType) \ | |||||
| if (data1.layout.dtype == DType()) { \ | |||||
| using T = typename DTypeTrait<DType>::ctype; \ | |||||
| forward_proxy<T>(nthreads, data1.ptr<T>(), data2.ptr<T>(), \ | |||||
| dst.ptr<T>(), bchannels, bheight, bwidth, tchannels, \ | |||||
| theight, twidth, kernel_size, max_displacement, \ | |||||
| stride1, stride2, pad_size, is_multiply, stream); \ | |||||
| } | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
| #undef cb | |||||
| } | |||||
| void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff, | |||||
| _megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out grad1, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(diff.layout, data1.layout, data2.layout, grad1.layout, | |||||
| workspace.size); | |||||
| auto stream = cuda_stream(handle()); | |||||
| int nthreads = grad1.layout.total_nr_elems(); | |||||
| int stride1 = param().stride1; | |||||
| int stride2 = param().stride2; | |||||
| int kernel_size = param().kernel_size; | |||||
| int max_displacement = param().max_displacement; | |||||
| int pad_size = param().pad_size; | |||||
| bool is_multiply = param().is_multiply; | |||||
| int tchannels = diff.layout[1]; | |||||
| int theight = diff.layout[2], twidth = diff.layout[3]; | |||||
| int bchannels = data1.layout[1]; | |||||
| int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||||
| using namespace ::megdnn::cuda::correlation; | |||||
| #define cb(DType) \ | |||||
| if (diff.layout.dtype == DType()) { \ | |||||
| using T = typename DTypeTrait<DType>::ctype; \ | |||||
| backward_proxy_data1<T>(nthreads, diff.ptr<T>(), data1.ptr<T>(), \ | |||||
| data2.ptr<T>(), grad1.ptr<T>(), bchannels, \ | |||||
| bheight, bwidth, tchannels, theight, twidth, \ | |||||
| kernel_size, max_displacement, stride1, \ | |||||
| stride2, pad_size, is_multiply, stream); \ | |||||
| } | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
| #undef cb | |||||
| } | |||||
| void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff, | |||||
| _megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out grad2, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(diff.layout, data1.layout, data2.layout, grad2.layout, | |||||
| workspace.size); | |||||
| auto p = param(); | |||||
| auto stream = cuda_stream(handle()); | |||||
| int nthreads = grad2.layout.total_nr_elems(); | |||||
| int stride1 = p.stride1; | |||||
| int stride2 = p.stride2; | |||||
| int kernel_size = p.kernel_size; | |||||
| int max_displacement = p.max_displacement; | |||||
| int pad_size = p.pad_size; | |||||
| bool is_multiply = p.is_multiply; | |||||
| int tchannels = diff.layout[1]; | |||||
| int theight = diff.layout[2], twidth = diff.layout[3]; | |||||
| int bchannels = data1.layout[1]; | |||||
| int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||||
| using namespace ::megdnn::cuda::correlation; | |||||
| #define cb(DType) \ | |||||
| if (diff.layout.dtype == DType()) { \ | |||||
| using T = typename DTypeTrait<DType>::ctype; \ | |||||
| backward_proxy_data2<T>(nthreads, diff.ptr<T>(), data1.ptr<T>(), \ | |||||
| data2.ptr<T>(), grad2.ptr<T>(), bchannels, \ | |||||
| bheight, bwidth, tchannels, theight, twidth, \ | |||||
| kernel_size, max_displacement, stride1, \ | |||||
| stride2, pad_size, is_multiply, stream); \ | |||||
| } | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
| #undef cb | |||||
| } | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,61 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/correlation/opr_impl.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/cuda/cudnn_wrapper.h" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| class CorrelationForwardImpl final : public CorrelationForward { | |||||
| public: | |||||
| using CorrelationForward::CorrelationForward; | |||||
| void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout& data1, | |||||
| const TensorLayout& data2, | |||||
| const TensorLayout& dst) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 { | |||||
| public: | |||||
| using CorrelationBackwardData1::CorrelationBackwardData1; | |||||
| void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, _megdnn_tensor_out grad1, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 { | |||||
| public: | |||||
| using CorrelationBackwardData2::CorrelationBackwardData2; | |||||
| void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, _megdnn_tensor_out grad2, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -24,6 +24,7 @@ | |||||
| #include "src/cuda/convolution/opr_impl.h" | #include "src/cuda/convolution/opr_impl.h" | ||||
| #include "src/cuda/convolution3d/opr_impl.h" | #include "src/cuda/convolution3d/opr_impl.h" | ||||
| #include "src/cuda/convpooling/opr_impl.h" | #include "src/cuda/convpooling/opr_impl.h" | ||||
| #include "src/cuda/correlation/opr_impl.h" | |||||
| #include "src/cuda/cumsum/opr_impl.h" | #include "src/cuda/cumsum/opr_impl.h" | ||||
| #include "src/cuda/cvt_color/opr_impl.h" | #include "src/cuda/cvt_color/opr_impl.h" | ||||
| #include "src/cuda/dct/opr_impl.h" | #include "src/cuda/dct/opr_impl.h" | ||||
| @@ -0,0 +1,384 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/correlation/opr_impl.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/naive/correlation/opr_impl.h" | |||||
| #include <algorithm> | |||||
| #include "src/common/utils.h" | |||||
| #include "src/naive/handle.h" | |||||
| #define ROUND_OFF 50000 | |||||
| using namespace megdnn; | |||||
| using namespace naive; | |||||
| using namespace std; | |||||
| namespace { | |||||
| using Param = megdnn::Correlation::Param; | |||||
| template <typename T> | |||||
| void forward(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out dst, const Param& param) { | |||||
| // data1 treat as no-padding tensor | |||||
| int total_nr_elems = dst.layout.total_nr_elems(); | |||||
| int stride1 = param.stride1, stride2 = param.stride2; | |||||
| int kernel_size = param.kernel_size; | |||||
| int kernel_radius = (kernel_size - 1) / 2; | |||||
| int max_displacement = param.max_displacement; | |||||
| int pad_size = param.pad_size; | |||||
| int tchannels = dst.layout[1]; | |||||
| int theight = dst.layout[2], twidth = dst.layout[3]; | |||||
| int bchannels = data1.layout[1]; | |||||
| int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||||
| int neighborhood_grid_radius = max_displacement / stride2; | |||||
| int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
| for (int idx = 0; idx < total_nr_elems; ++idx) { | |||||
| int x = idx % twidth; | |||||
| int y = (idx / twidth) % theight; | |||||
| int c = (idx / twidth / theight) % tchannels; | |||||
| int n = idx / twidth / theight / tchannels; | |||||
| // get src center position in image1 | |||||
| int x1 = x * stride1 + kernel_radius + max_displacement - pad_size; | |||||
| int y1 = y * stride1 + kernel_radius + max_displacement - pad_size; | |||||
| // get offset of center in image2 | |||||
| int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * | |||||
| stride2; | |||||
| int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * | |||||
| stride2; | |||||
| int x2 = x1 + s2o; | |||||
| int y2 = y1 + s2p; | |||||
| // compute kernel correlation | |||||
| float sum = 0.; | |||||
| for (int i = -kernel_radius; i <= kernel_radius; i++) { | |||||
| for (int j = -kernel_radius; j <= kernel_radius; j++) { | |||||
| int in_x1 = x1 + i; | |||||
| int in_y1 = y1 + j; | |||||
| int in_x2 = x2 + i; | |||||
| int in_y2 = y2 + j; | |||||
| for (int channel = 0; channel < bchannels; channel++) { | |||||
| float tmp1 = 0.; | |||||
| float tmp2 = 0.; | |||||
| if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 && | |||||
| in_y1 < bheight) { | |||||
| int idx1 = | |||||
| ((n * bchannels + channel) * bheight + in_y1) * | |||||
| bwidth + | |||||
| in_x1; | |||||
| tmp1 = data1.ptr<T>()[idx1]; | |||||
| } | |||||
| if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 && | |||||
| in_y2 < bheight) { | |||||
| int idx2 = | |||||
| ((n * bchannels + channel) * bheight + in_y2) * | |||||
| bwidth + | |||||
| in_x2; | |||||
| tmp2 = data2.ptr<T>()[idx2]; | |||||
| } | |||||
| if (param.is_multiply) { | |||||
| sum += tmp1 * tmp2; | |||||
| } else { | |||||
| sum += fabsf(tmp1 - tmp2); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| const int sumelems = | |||||
| (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
| dst.ptr<T>()[idx] = sum / sumelems; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void backward_data1(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, _megdnn_tensor_out grad1, | |||||
| const Param& param) { | |||||
| // data1 treat as no-padding tensor | |||||
| // int total_nr_elems = diff.layout.total_nr_elems(); | |||||
| int total_nr_elems = grad1.layout.total_nr_elems(); | |||||
| int stride1 = param.stride1, stride2 = param.stride2; | |||||
| int kernel_size = param.kernel_size; | |||||
| int kernel_radius = (kernel_size - 1) / 2; | |||||
| int max_displacement = param.max_displacement; | |||||
| int pad_size = param.pad_size; | |||||
| int tchannels = diff.layout[1]; | |||||
| int theight = diff.layout[2], twidth = diff.layout[3]; | |||||
| int bchannels = grad1.layout[1]; | |||||
| int bheight = grad1.layout[2], bwidth = grad1.layout[3]; | |||||
| int neighborhood_grid_radius = max_displacement / stride2; | |||||
| int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
| for (int idx = 0; idx < total_nr_elems; ++idx) { | |||||
| // idx for grad1 | |||||
| int x = idx % bwidth; | |||||
| int y = (idx / bwidth) % bheight; | |||||
| int c = (idx / bwidth / bheight) % bchannels; | |||||
| int n = idx / bwidth / bheight / bchannels; | |||||
| float tmp1 = data1.ptr<T>()[idx]; | |||||
| // Get X,Y ranges and clamp | |||||
| // round_off is a trick to enable integer division with ceil, even for | |||||
| // negative numbers We use a large offset, for the inner part not to | |||||
| // become negative. | |||||
| const int round_off = ROUND_OFF; | |||||
| const int round_off_s1 = stride1 * round_off; | |||||
| // we show cal the x_min,y_min,x_max,y_max of diff for grad1(x,y) | |||||
| // for diff_x_min, diff_y_min, x,y at the position of right-down | |||||
| // ceil (l - 2*kernel_radius - max_displacement + pad_size) / stride1 | |||||
| int xmin = (x + pad_size - 2 * kernel_radius - max_displacement + | |||||
| round_off_s1 - 1) / | |||||
| stride1 + | |||||
| 1 - round_off; | |||||
| int ymin = (y + pad_size - 2 * kernel_radius - max_displacement + | |||||
| round_off_s1 - 1) / | |||||
| stride1 + | |||||
| 1 - round_off; | |||||
| // floor (l - max_displacement + pad_size) / stride1 | |||||
| int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 - | |||||
| round_off; | |||||
| int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 - | |||||
| round_off; | |||||
| float sum = 0.; | |||||
| if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||||
| (ymin <= theight - 1)) { | |||||
| xmin = max(0, xmin); | |||||
| xmax = min(twidth - 1, xmax); | |||||
| ymin = max(0, ymin); | |||||
| ymax = min(theight - 1, ymax); | |||||
| for (int p = -neighborhood_grid_radius; | |||||
| p <= neighborhood_grid_radius; p++) { | |||||
| for (int o = -neighborhood_grid_radius; | |||||
| o <= neighborhood_grid_radius; o++) { | |||||
| // Get bottom1 data: | |||||
| int s2o = stride2 * o; | |||||
| int s2p = stride2 * p; | |||||
| int x2 = x + s2p, y2 = y + s2o; | |||||
| int idx2 = | |||||
| ((n * bchannels + c) * bheight + y2) * bwidth + x2; | |||||
| float tmp2 = 0.; | |||||
| if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) { | |||||
| tmp2 = data2.ptr<T>()[idx2]; | |||||
| } | |||||
| int op = (p + neighborhood_grid_radius) * | |||||
| neighborhood_grid_width + | |||||
| (o + neighborhood_grid_radius); | |||||
| int diff_channels_offset = (n * tchannels + op); | |||||
| for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||||
| for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||||
| int idxtopdiff = | |||||
| (diff_channels_offset * theight + diff_y) * | |||||
| twidth + | |||||
| diff_x; | |||||
| if (param.is_multiply) { | |||||
| sum += diff.ptr<T>()[idxtopdiff] * tmp2; | |||||
| } else { | |||||
| T sign = (tmp1 > tmp2) ? T(1.) : T(-1.); | |||||
| sum += diff.ptr<T>()[idxtopdiff] * sign; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| const int sumelems = | |||||
| (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
| grad1.ptr<T>()[idx] = sum / sumelems; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void backward_data2(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, _megdnn_tensor_out grad2, | |||||
| const Param& param) { | |||||
| // data1 treat as no-padding tensor | |||||
| int total_nr_elems = grad2.layout.total_nr_elems(); | |||||
| int stride1 = param.stride1, stride2 = param.stride2; | |||||
| int kernel_size = param.kernel_size; | |||||
| int kernel_radius = (kernel_size - 1) / 2; | |||||
| int max_displacement = param.max_displacement; | |||||
| int pad_size = param.pad_size; | |||||
| int tchannels = diff.layout[1]; | |||||
| int theight = diff.layout[2], twidth = diff.layout[3]; | |||||
| int bchannels = grad2.layout[1]; | |||||
| int bheight = grad2.layout[2], bwidth = grad2.layout[3]; | |||||
| int neighborhood_grid_radius = max_displacement / stride2; | |||||
| int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
| for (int idx = 0; idx < total_nr_elems; ++idx) { | |||||
| int x = idx % bwidth; | |||||
| int y = (idx / bwidth) % bheight; | |||||
| int c = (idx / bwidth / bheight) % bchannels; | |||||
| int n = idx / bwidth / bheight / bchannels; | |||||
| T tmp2 = data2.ptr<T>()[idx]; | |||||
| T sum = T(0.f); | |||||
| for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; | |||||
| p++) { | |||||
| for (int o = -neighborhood_grid_radius; | |||||
| o <= neighborhood_grid_radius; o++) { | |||||
| int s2o = o * stride2; | |||||
| int s2p = p * stride2; | |||||
| int x1 = x - s2o; | |||||
| int y1 = y - s2p; | |||||
| const int round_off = ROUND_OFF; | |||||
| const int round_off_s1 = stride1 * round_off; | |||||
| int xmin = (x1 + pad_size - 2 * kernel_radius - | |||||
| max_displacement + round_off_s1 - 1) / | |||||
| stride1 + | |||||
| 1 - round_off; | |||||
| int ymin = (y1 + pad_size - 2 * kernel_radius - | |||||
| max_displacement + round_off_s1 - 1) / | |||||
| stride1 + | |||||
| 1 - round_off; | |||||
| int xmax = (x1 + pad_size - max_displacement + round_off_s1) / | |||||
| stride1 - | |||||
| round_off; | |||||
| int ymax = (y1 + pad_size - max_displacement + round_off_s1) / | |||||
| stride1 - | |||||
| round_off; | |||||
| if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||||
| (ymin <= theight - 1)) { | |||||
| xmin = max(0, xmin); | |||||
| xmax = min(twidth - 1, xmax); | |||||
| ymin = max(0, ymin); | |||||
| ymax = min(theight - 1, ymax); | |||||
| int idx1 = | |||||
| ((n * bchannels + c) * bheight + y1) * bwidth + x1; | |||||
| T tmp1 = T(0.f); | |||||
| if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) { | |||||
| tmp1 = data1.ptr<T>()[idx1]; | |||||
| } | |||||
| int op = (p + neighborhood_grid_radius) * | |||||
| neighborhood_grid_width + | |||||
| (o + neighborhood_grid_radius); | |||||
| int diff_channels_offset = (n * tchannels + op); | |||||
| for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||||
| for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||||
| int idxtopdiff = | |||||
| (diff_channels_offset * theight + diff_y) * | |||||
| twidth + | |||||
| diff_x; | |||||
| if (param.is_multiply) { | |||||
| sum += diff.ptr<T>()[idxtopdiff] * tmp1; | |||||
| } else { | |||||
| T sign = (tmp1 >= tmp2) ? T(-1.f) : T(1.f); | |||||
| sum += diff.ptr<T>()[idxtopdiff] * sign; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| const int sumelems = | |||||
| (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
| grad2.ptr<T>()[idx] = sum / sumelems; | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| namespace megdnn { | |||||
| namespace naive { | |||||
| void CorrelationForwardImpl::exec(_megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(data1.layout, data2.layout, dst.layout, workspace.size); | |||||
| #define cb(DType) \ | |||||
| if (data1.layout.dtype == DType()) { \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
| forward<typename DTypeTrait<DType>::ctype>(data1, data2, dst, \ | |||||
| param())); \ | |||||
| return; \ | |||||
| } | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
| #undef cb | |||||
| megdnn_throw("bad dtype"); | |||||
| } | |||||
| void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff, | |||||
| _megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out grad1, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(diff.layout, data1.layout, data2.layout, grad1.layout, | |||||
| workspace.size); | |||||
| #define cb(DType) \ | |||||
| if (diff.layout.dtype == DType()) { \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
| backward_data1<typename DTypeTrait<DType>::ctype>( \ | |||||
| diff, data1, data2, grad1, param())); \ | |||||
| return; \ | |||||
| } | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
| #undef cb | |||||
| megdnn_throw("bad dtype"); | |||||
| } | |||||
| void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff, | |||||
| _megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out grad2, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(diff.layout, data1.layout, data2.layout, grad2.layout, | |||||
| workspace.size); | |||||
| #define cb(DType) \ | |||||
| if (diff.layout.dtype == DType()) { \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
| backward_data2<typename DTypeTrait<DType>::ctype>( \ | |||||
| diff, data1, data2, grad2, param())); \ | |||||
| return; \ | |||||
| } | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
| #undef cb | |||||
| megdnn_throw("bad dtype"); | |||||
| } | |||||
| } // namespace naive | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/correlation/opr_impl.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/oprs.h" | |||||
| namespace megdnn { | |||||
| namespace naive { | |||||
| class CorrelationForwardImpl final : public CorrelationForward { | |||||
| public: | |||||
| using CorrelationForward::CorrelationForward; | |||||
| void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 { | |||||
| public: | |||||
| using CorrelationBackwardData1::CorrelationBackwardData1; | |||||
| void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, _megdnn_tensor_out grad1, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 { | |||||
| public: | |||||
| using CorrelationBackwardData2::CorrelationBackwardData2; | |||||
| void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
| _megdnn_tensor_in data2, _megdnn_tensor_out grad2, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| } // namespace naive | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -30,6 +30,7 @@ | |||||
| #include "src/naive/convpooling/opr_impl.h" | #include "src/naive/convpooling/opr_impl.h" | ||||
| #include "src/naive/cumsum/opr_impl.h" | #include "src/naive/cumsum/opr_impl.h" | ||||
| #include "src/naive/cvt_color/opr_impl.h" | #include "src/naive/cvt_color/opr_impl.h" | ||||
| #include "src/naive/correlation/opr_impl.h" | |||||
| #include "src/naive/dct/opr_impl.h" | #include "src/naive/dct/opr_impl.h" | ||||
| #include "src/naive/deformable_conv/opr_impl.h" | #include "src/naive/deformable_conv/opr_impl.h" | ||||
| #include "src/naive/deformable_ps_roi_pooling/opr_impl.h" | #include "src/naive/deformable_ps_roi_pooling/opr_impl.h" | ||||
| @@ -0,0 +1,73 @@ | |||||
| /** | |||||
| * \file dnn/test/common/correlation.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/basic_types.h" | |||||
| #include "megdnn/opr_param_defs.h" | |||||
| namespace megdnn { | |||||
| namespace test { | |||||
| namespace correlation { | |||||
| struct TestArg { | |||||
| param::Correlation param; | |||||
| TensorShape data1, data2; | |||||
| TestArg(param::Correlation param, TensorShape data1, TensorShape data2) | |||||
| : param(param), data1(data1), data2(data2) {} | |||||
| }; | |||||
| inline static std::vector<TestArg> get_args() { | |||||
| std::vector<TestArg> args; | |||||
| param::Correlation cur_param; | |||||
| for (size_t batch_size : {2}) { | |||||
| for (size_t channel : {2}) { | |||||
| for (size_t height : {160}) { | |||||
| for (size_t width : {160}) { | |||||
| cur_param.is_multiply = true; | |||||
| cur_param.kernel_size = 3; | |||||
| cur_param.max_displacement = 3; | |||||
| cur_param.pad_size = 0; | |||||
| cur_param.stride1 = 1; | |||||
| cur_param.stride2 = 1; | |||||
| cur_param.format = megdnn::param::Correlation::Format::NCHW; | |||||
| args.emplace_back( | |||||
| cur_param, | |||||
| TensorShape{batch_size, channel, height, width}, | |||||
| TensorShape{batch_size, channel, height, width}); | |||||
| // cur_param.is_multiply = false; | |||||
| // cur_param.kernel_size = 1; | |||||
| // cur_param.max_displacement = 2; | |||||
| // cur_param.pad_size = 1; | |||||
| // cur_param.stride1 = 1; | |||||
| // cur_param.stride2 = 1; | |||||
| // cur_param.format = | |||||
| // megdnn::param::Correlation::Format::NCHW; | |||||
| // args.emplace_back( | |||||
| // cur_param, | |||||
| // TensorShape{batch_size, channel, height, width}, | |||||
| // TensorShape{batch_size, channel, height, width}); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return args; | |||||
| } | |||||
| } // namespace correlation | |||||
| } // namespace test | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,160 @@ | |||||
| /** | |||||
| * \file dnn/test/cuda/correlation.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "test/cuda/fixture.h" | |||||
| #include "test/common/checker.h" | |||||
| #include "test/common/correlation.h" | |||||
| namespace megdnn { | |||||
| namespace test { | |||||
| TEST_F(CUDA, CORRELATION_FORWARD) { | |||||
| using namespace correlation; | |||||
| std::vector<TestArg> args = get_args(); | |||||
| Checker<Correlation> checker(handle_cuda()); | |||||
| for (auto&& arg : args) { | |||||
| checker.set_param(arg.param) | |||||
| .set_dtype(0, dtype::Float32()) | |||||
| .set_dtype(1, dtype::Float32()) | |||||
| .execs({arg.data1, arg.data2, {}}); | |||||
| } | |||||
| } | |||||
| TEST_F(CUDA, CORRELATION_BACKWARDDATA1) { | |||||
| ConstValue const_0{0}; | |||||
| using Param = CorrelationBackwardData1::Param; | |||||
| Param param; | |||||
| param.is_multiply = true; | |||||
| param.format = Param::Format::NCHW; | |||||
| param.stride1 = 2; | |||||
| param.stride2 = 2; | |||||
| param.kernel_size = 3; | |||||
| param.pad_size = 4; | |||||
| Checker<CorrelationBackwardData1> checker(handle_cuda()); | |||||
| checker.set_epsilon(1e-2); | |||||
| uint32_t pad_size = param.pad_size; | |||||
| uint32_t kernel_size = param.kernel_size; | |||||
| uint32_t stride1 = param.stride1; | |||||
| uint32_t stride2 = param.stride2; | |||||
| uint32_t max_displacement = param.max_displacement; | |||||
| auto run = [&](DType dtype) { | |||||
| for (size_t N : {1, 3}) | |||||
| for (size_t C : {1, 3}) | |||||
| for (size_t OH : {10, 100}) | |||||
| for (size_t OW : {10, 100}) { | |||||
| int paddedbottomheight = OH + 2 * pad_size; | |||||
| int paddedbottomwidth = OW + 2 * pad_size; | |||||
| uint32_t kernel_radius = (kernel_size - 1) / 2; | |||||
| uint32_t border_size = max_displacement + kernel_radius; | |||||
| uint32_t top_width = | |||||
| ceil(static_cast<float>(paddedbottomwidth - | |||||
| border_size * 2) / | |||||
| static_cast<float>(stride1)); | |||||
| uint32_t top_height = | |||||
| ceil(static_cast<float>(paddedbottomheight - | |||||
| border_size * 2) / | |||||
| static_cast<float>(stride1)); | |||||
| uint32_t neighborhood_grid_radius = | |||||
| max_displacement / stride2; | |||||
| uint32_t neighborhood_grid_width = | |||||
| neighborhood_grid_radius * 2 + 1; | |||||
| uint32_t top_channels = neighborhood_grid_width * | |||||
| neighborhood_grid_width; | |||||
| checker.set_param(param) | |||||
| .set_dtype(0, dtype) | |||||
| .set_dtype(1, dtype) | |||||
| .set_dtype(2, dtype) | |||||
| .set_dtype(3, dtype) | |||||
| .execs({{N, top_channels, top_height, | |||||
| top_width}, | |||||
| {N, C, OH, OW}, | |||||
| {N, C, OH, OW}, | |||||
| {N, C, OH, OW}}); | |||||
| } | |||||
| }; | |||||
| run(dtype::Float32()); | |||||
| run(dtype::Float16()); | |||||
| checker.set_epsilon(5e-2); | |||||
| run(dtype::BFloat16()); | |||||
| } | |||||
| TEST_F(CUDA, CORRELATION_BACKWARDDATA2) { | |||||
| ConstValue const_0{0}; | |||||
| using Param = CorrelationBackwardData2::Param; | |||||
| Param param; | |||||
| param.is_multiply = true; | |||||
| param.format = Param::Format::NCHW; | |||||
| param.stride1 = 2; | |||||
| param.stride2 = 2; | |||||
| param.kernel_size = 3; | |||||
| param.pad_size = 4; | |||||
| Checker<CorrelationBackwardData2> checker(handle_cuda()); | |||||
| checker.set_epsilon(1e-2); | |||||
| uint32_t pad_size = param.pad_size; | |||||
| uint32_t kernel_size = param.kernel_size; | |||||
| uint32_t stride1 = param.stride1; | |||||
| uint32_t stride2 = param.stride2; | |||||
| uint32_t max_displacement = param.max_displacement; | |||||
| auto run = [&](DType dtype) { | |||||
| for (size_t N : {1, 3}) | |||||
| for (size_t C : {1, 3}) | |||||
| for (size_t OH : {10, 100}) | |||||
| for (size_t OW : {10, 100}) { | |||||
| int paddedbottomheight = OH + 2 * pad_size; | |||||
| int paddedbottomwidth = OW + 2 * pad_size; | |||||
| uint32_t kernel_radius = (kernel_size - 1) / 2; | |||||
| uint32_t border_size = max_displacement + kernel_radius; | |||||
| uint32_t top_width = | |||||
| ceil(static_cast<float>(paddedbottomwidth - | |||||
| border_size * 2) / | |||||
| static_cast<float>(stride1)); | |||||
| uint32_t top_height = | |||||
| ceil(static_cast<float>(paddedbottomheight - | |||||
| border_size * 2) / | |||||
| static_cast<float>(stride1)); | |||||
| uint32_t neighborhood_grid_radius = | |||||
| max_displacement / stride2; | |||||
| uint32_t neighborhood_grid_width = | |||||
| neighborhood_grid_radius * 2 + 1; | |||||
| uint32_t top_channels = neighborhood_grid_width * | |||||
| neighborhood_grid_width; | |||||
| checker.set_param(param) | |||||
| .set_dtype(0, dtype) | |||||
| .set_dtype(1, dtype) | |||||
| .set_dtype(2, dtype) | |||||
| .set_dtype(3, dtype) | |||||
| .execs({{N, top_channels, top_height, | |||||
| top_width}, | |||||
| {N, C, OH, OW}, | |||||
| {N, C, OH, OW}, | |||||
| {N, C, OH, OW}}); | |||||
| } | |||||
| }; | |||||
| run(dtype::Float32()); | |||||
| run(dtype::Float16()); | |||||
| checker.set_epsilon(5e-2); | |||||
| run(dtype::BFloat16()); | |||||
| } | |||||
| } // namespace test | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||