| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * \file dnn/src/cuda/roi_align/roi_align.cu | |||
| * \file dnn/src/cuda/correlation/correlation_cuda.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -28,9 +28,9 @@ | |||
| #include "src/naive/convolution/opr_impl.h" | |||
| #include "src/naive/convolution3d/opr_impl.h" | |||
| #include "src/naive/convpooling/opr_impl.h" | |||
| #include "src/naive/correlation/opr_impl.h" | |||
| #include "src/naive/cumsum/opr_impl.h" | |||
| #include "src/naive/cvt_color/opr_impl.h" | |||
| #include "src/naive/correlation/opr_impl.h" | |||
| #include "src/naive/dct/opr_impl.h" | |||
| #include "src/naive/deformable_conv/opr_impl.h" | |||
| #include "src/naive/deformable_ps_roi_pooling/opr_impl.h" | |||
| @@ -38,6 +38,7 @@ | |||
| #include "src/naive/elemwise/opr_impl.h" | |||
| #include "src/naive/elemwise_multi_type/opr_impl.h" | |||
| #include "src/naive/eye/opr_impl.h" | |||
| #include "src/naive/fake_quant/opr_impl.h" | |||
| #include "src/naive/flip/opr_impl.h" | |||
| #include "src/naive/gaussian_blur/opr_impl.h" | |||
| #include "src/naive/group_local/opr_impl.h" | |||
| @@ -75,13 +76,11 @@ | |||
| #include "src/naive/tensor_remap/opr_impl.h" | |||
| #include "src/naive/tile/opr_impl.h" | |||
| #include "src/naive/topk/opr_impl.h" | |||
| #include "src/naive/tqt/opr_impl.h" | |||
| #include "src/naive/transpose/opr_impl.h" | |||
| #include "src/naive/type_cvt/opr_impl.h" | |||
| #include "src/naive/warp_affine/opr_impl.h" | |||
| #include "src/naive/warp_perspective/opr_impl.h" | |||
| #include "src/naive/remap/opr_impl.h" | |||
| #include "src/naive/fake_quant/opr_impl.h" | |||
| #include "src/naive/tqt/opr_impl.h" | |||
| static size_t g_image2d_pitch_alignment = 1; | |||
| @@ -45,19 +45,6 @@ inline static std::vector<TestArg> get_args() { | |||
| TensorShape{batch_size, channel, height, width}, | |||
| TensorShape{batch_size, channel, height, width}); | |||
| // cur_param.is_multiply = false; | |||
| // cur_param.kernel_size = 1; | |||
| // cur_param.max_displacement = 2; | |||
| // cur_param.pad_size = 1; | |||
| // cur_param.stride1 = 1; | |||
| // cur_param.stride2 = 1; | |||
| // cur_param.format = | |||
| // megdnn::param::Correlation::Format::NCHW; | |||
| // args.emplace_back( | |||
| // cur_param, | |||
| // TensorShape{batch_size, channel, height, width}, | |||
| // TensorShape{batch_size, channel, height, width}); | |||
| } | |||
| } | |||
| } | |||
| @@ -106,6 +106,43 @@ def roi_pooling( | |||
| return result | |||
| def correlation( | |||
| data1: Tensor, | |||
| data2: Tensor, | |||
| kernel_size: int = 1, | |||
| max_displacement: int = 1, | |||
| stride1: int = 1, | |||
| stride2: int = 1, | |||
| pad_size: int = 0, | |||
| is_multiply: bool = True, | |||
| ) -> Tensor: | |||
| """ Applies correlation to inputs. | |||
| :param data1: Input data1 to the correlation. format must be nchw | |||
| :param data2: Input data2 to the correlation. format must be nchw | |||
| :param kernel_size: (int (non-negative), optional, default=1) – kernel size for Correlation must be an odd number | |||
| :param max_displacement: (int (non-negative), optional, default=1) – Max displacement of Correlation | |||
| :param stride1: (int (non-negative), optional, default=1) – stride1 quantize data1 globally | |||
| :param stride2: (int (non-negative), optional, default=1) – stride2 quantize data2 within the neighborhood centered around data1 | |||
| :param pad_size: (int (non-negative), optional, default=0) – pad for Correlation | |||
| :param is_multiply: (boolean, optional, default=True) – operation type is either multiplication or absolute difference | |||
| """ | |||
| op = builtin.Correlation( | |||
| format="NCHW", | |||
| kernel_size=kernel_size, | |||
| max_displacement=max_displacement, | |||
| stride1=stride1, | |||
| stride2=stride2, | |||
| pad_size=pad_size, | |||
| is_multiply=is_multiply, | |||
| ) | |||
| result, *_ = apply(op, data1, data2) | |||
| return result | |||
| def roi_align( | |||
| inp: Tensor, | |||
| rois: Tensor, | |||
| @@ -228,6 +228,106 @@ def test_roi_align(): | |||
| assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | |||
| def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)): | |||
| if random: | |||
| inp_feat1 = np.random.randn( | |||
| image_shape[0], image_shape[1], image_shape[2], image_shape[3] | |||
| ) | |||
| inp_feat2 = np.random.randn( | |||
| image_shape[0], image_shape[1], image_shape[2], image_shape[3] | |||
| ) | |||
| else: | |||
| inp_feat1 = np.ones(image_shape) * constant | |||
| inp_feat2 = np.ones(image_shape) * constant | |||
| return tensor(inp_feat1), tensor(inp_feat2) | |||
| def test_correlation(): | |||
| ##test case 0 check the grad shape | |||
| data1, data2 = _gen_correlation() | |||
| grad = Grad().wrt(data1, callback=_save_to(data1)) | |||
| out_feat = F.vision.correlation( | |||
| data1, | |||
| data2, | |||
| kernel_size=5, | |||
| max_displacement=4, | |||
| stride1=2, | |||
| stride2=2, | |||
| pad_size=2, | |||
| is_multiply=True, | |||
| ) | |||
| grad(out_feat, tensor(F.ones_like(out_feat))) | |||
| assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape) | |||
| ##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194 | |||
| data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3)) | |||
| out_feat = F.vision.correlation( | |||
| data1, | |||
| data2, | |||
| kernel_size=3, | |||
| max_displacement=0, | |||
| stride1=1, | |||
| stride2=1, | |||
| pad_size=0, | |||
| is_multiply=True, | |||
| ) | |||
| assert abs(out_feat.sum() - 1) < 1e-9 | |||
| ##test case 2 check same image subduction | |||
| data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3)) | |||
| out_feat = F.vision.correlation( | |||
| data1, | |||
| data2, | |||
| kernel_size=3, | |||
| max_displacement=0, | |||
| stride1=1, | |||
| stride2=1, | |||
| pad_size=0, | |||
| is_multiply=False, | |||
| ) | |||
| assert out_feat.sum() < 1e-9 | |||
| ##test case 3 check same image subduction | |||
| data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3)) | |||
| out_feat = F.vision.correlation( | |||
| data1, | |||
| data2, | |||
| kernel_size=3, | |||
| max_displacement=0, | |||
| stride1=1, | |||
| stride2=1, | |||
| pad_size=0, | |||
| is_multiply=False, | |||
| ) | |||
| assert out_feat.sum() < 1e-9 | |||
| ##test case 4 check correlation | |||
| data1, _ = _gen_correlation( | |||
| random=False, image_shape=(1, 1, 220, 220), constant=2.0 | |||
| ) | |||
| _, data2 = _gen_correlation( | |||
| random=False, image_shape=(1, 1, 220, 220), constant=1.0 | |||
| ) | |||
| out_feat = F.vision.correlation( | |||
| data1, | |||
| data2, | |||
| kernel_size=3, | |||
| max_displacement=2, | |||
| stride1=1, | |||
| stride2=2, | |||
| pad_size=0, | |||
| is_multiply=False, | |||
| ) | |||
| assert abs(out_feat.mean() - 1) < 1e-9 | |||
| def test_roi_pooling(): | |||
| inp_feat, rois = _gen_roi_inp() | |||
| grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) | |||
| @@ -19,6 +19,7 @@ | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/dnn/roi_align.h" | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/opr/dnn/roi_pooling.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/blas.h" | |||
| @@ -445,6 +446,21 @@ OP_TRAIT_REG(ROIAlign, ROIAlign) | |||
| .fallback(); | |||
| }} // roi_align | |||
| namespace { namespace correlation { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Correlation&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Correlation::make( | |||
| inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Correlation, Correlation) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // correlation | |||
| #if MGB_CUDA | |||
| namespace { namespace nvof { | |||
| auto apply_on_var_node( | |||
| @@ -82,6 +82,7 @@ def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, Executio | |||
| def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; | |||
| def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; | |||
| def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>; | |||
| def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>; | |||
| @@ -0,0 +1,109 @@ | |||
| /** | |||
| * \file src/opr/impl/dnn/correlation.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/graph/grad_impl.h" | |||
| #include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "../internal/megdnn_opr_wrapper.inl" | |||
| using namespace mgb; | |||
| using namespace opr; | |||
| /* ==================== CorrelationForward ==================== */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationForward); | |||
| CorrelationForward::CorrelationForward(VarNode* data1, VarNode* data2, | |||
| const Param& param, | |||
| const OperatorNodeConfig& config) | |||
| : Super{data1->owner_graph(), config, "correlation", {data1, data2}} { | |||
| init_megdnn_opr(*this, param); | |||
| mgb_assert(data1->dtype() == data2->dtype()); | |||
| mgb_assert(data1->dtype().category() == DTypeCategory::FLOAT); | |||
| add_input({data1, data2}); | |||
| output(0)->dtype(data1->dtype()); | |||
| } | |||
| SymbolVar CorrelationForward::make(SymbolVar data1, SymbolVar data2, | |||
| const Param& param, | |||
| const OperatorNodeConfig& config) { | |||
| return data1.insert_single_output_opr<CorrelationForward>( | |||
| data1.node(), data2.node(), param, config); | |||
| } | |||
| #if MGB_ENABLE_GRAD | |||
| MGB_IMPL_OPR_GRAD(CorrelationForward) { | |||
| if (wrt_idx == 0) { | |||
| // wrt src | |||
| SymbolVar grad = CorrelationBackwardData1::make( | |||
| out_grad[0], opr.input(0), opr.input(1), opr.param(), | |||
| opr.config()); | |||
| return grad.node(); | |||
| } else { | |||
| mgb_assert(wrt_idx == 1); | |||
| SymbolVar grad = CorrelationBackwardData2::make( | |||
| out_grad[0], opr.input(0), opr.input(1), opr.param(), | |||
| opr.config()); | |||
| return grad.node(); | |||
| } | |||
| } | |||
| #endif | |||
| /* ==================== CorrelationBackwardData1 ==================== */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationBackwardData1); | |||
| MEGDNN_OPR_INIT3(CorrelationBackwardData1, "correlation_backward_data1", 1, | |||
| true); | |||
| void CorrelationBackwardData1::scn_do_execute() { | |||
| megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), | |||
| input(1)->dev_tensor().as_megdnn(), | |||
| input(2)->dev_tensor().as_megdnn(), | |||
| output(0)->dev_tensor().as_megdnn(), | |||
| intl::get_megdnn_workspace_from_var(output(1))); | |||
| } | |||
| size_t CorrelationBackwardData1::get_workspace_size_bytes( | |||
| const TensorShapeArray& inp_shapes, | |||
| const TensorShapeArray& out_shapes) const { | |||
| TensorLayout diff{inp_shapes[0], input(0)->dtype(), input(0)->format()}, | |||
| data1{inp_shapes[1], input(1)->dtype(), input(1)->format()}, | |||
| data2{inp_shapes[2], input(2)->dtype(), input(2)->format()}, | |||
| grad1{out_shapes[0], output(0)->dtype(), output(0)->format()}; | |||
| return megdnn_opr()->get_workspace_in_bytes(diff, data1, data2, grad1); | |||
| } | |||
| /* ==================== CorrelationBackwardData2 ==================== */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationBackwardData2); | |||
| MEGDNN_OPR_INIT3(CorrelationBackwardData2, "correlation_backward_data2", 1, | |||
| true); | |||
| void CorrelationBackwardData2::scn_do_execute() { | |||
| megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), | |||
| input(1)->dev_tensor().as_megdnn(), | |||
| input(2)->dev_tensor().as_megdnn(), | |||
| output(0)->dev_tensor().as_megdnn(), | |||
| intl::get_megdnn_workspace_from_var(output(1))); | |||
| } | |||
| size_t CorrelationBackwardData2::get_workspace_size_bytes( | |||
| const TensorShapeArray& inp_shapes, | |||
| const TensorShapeArray& out_shapes) const { | |||
| TensorLayout diff{inp_shapes[0], input(0)->dtype(), input(0)->format()}, | |||
| data1{inp_shapes[1], input(1)->dtype(), input(1)->format()}, | |||
| data2{inp_shapes[2], input(2)->dtype(), input(2)->format()}, | |||
| grad2{out_shapes[0], output(0)->dtype(), output(0)->format()}; | |||
| return megdnn_opr()->get_workspace_in_bytes(diff, data1, data2, grad2); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -11,6 +11,7 @@ | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/opr/dnn/images2neibs.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||
| @@ -573,6 +574,10 @@ MGB_SEREG_OPR(DeformableConvForwardV1, 0); | |||
| MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0); | |||
| MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0); | |||
| MGB_SEREG_OPR(CorrelationForward, 2); | |||
| MGB_SEREG_OPR(CorrelationBackwardData1, 3); | |||
| MGB_SEREG_OPR(CorrelationBackwardData2, 3); | |||
| MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3); | |||
| MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * \file src/opr/include/megbrain/opr/dnn/correlation.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megdnn/oprs.h" | |||
| namespace mgb { | |||
| namespace opr { | |||
| MGB_DEFINE_OPR_CLASS(CorrelationForward, | |||
| intl::MegDNNOprWrapperFwd<megdnn::CorrelationForward>) // { | |||
| public: | |||
| CorrelationForward(VarNode* data1, VarNode* data2, const Param& param, | |||
| const OperatorNodeConfig& config); | |||
| static SymbolVar make(SymbolVar data1, SymbolVar data2, | |||
| const Param& param = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| }; | |||
| using Correlation = CorrelationForward; | |||
| MGB_DEFINE_OPR_CLASS( | |||
| CorrelationBackwardData1, intl::MegDNNOprWrapperBwd<megdnn::CorrelationBackwardData1>) // { | |||
| public: | |||
| CorrelationBackwardData1(VarNode* diff, VarNode* data1, VarNode* data2, | |||
| const Param& param, const OperatorNodeConfig& config); | |||
| static SymbolVar make(SymbolVar diff, SymbolVar data1, SymbolVar data2, | |||
| const Param& param = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| private: | |||
| void scn_do_execute() override; | |||
| size_t get_workspace_size_bytes( | |||
| const TensorShapeArray& input_shapes, | |||
| const TensorShapeArray& output_shapes) const override; | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS( | |||
| CorrelationBackwardData2, intl::MegDNNOprWrapperBwd<megdnn::CorrelationBackwardData2>) // { | |||
| public: | |||
| CorrelationBackwardData2(VarNode* diff, VarNode* data1, VarNode* data2, | |||
| const Param& param, const OperatorNodeConfig& config); | |||
| static SymbolVar make(SymbolVar diff, SymbolVar data1, SymbolVar data2, | |||
| const Param& param = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| private: | |||
| void scn_do_execute() override; | |||
| size_t get_workspace_size_bytes( | |||
| const TensorShapeArray& input_shapes, | |||
| const TensorShapeArray& output_shapes) const override; | |||
| }; | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,134 @@ | |||
| /** | |||
| * \file src/opr/test/dnn/correlation.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/test/autocheck.h" | |||
| #include "megbrain/test/helper.h" | |||
| #include "megbrain/test/megdnn_helper.h" | |||
| #include "megdnn/oprs.h" | |||
| #include <cmath> | |||
| #include <iomanip> | |||
| #include <random> | |||
| #include <sstream> | |||
| using namespace mgb; | |||
| namespace { | |||
| using Param = opr::CorrelationForward::Param; | |||
| void run_forward(bool is_multiply) { | |||
| RNGxorshf rng{next_rand_seed()}; | |||
| using Checker = AutoOprChecker<2, 1>; | |||
| Param param; | |||
| param.format = Param::Format::NCHW; | |||
| param.is_multiply = is_multiply; | |||
| param.kernel_size = 3; | |||
| param.max_displacement = 2; | |||
| param.pad_size = 1; | |||
| param.stride1 = 2; | |||
| param.stride2 = 2; | |||
| auto make_graph = | |||
| [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||
| auto o0 = opr::CorrelationForward::make(inputs[0], inputs[1], param); | |||
| return {o0}; | |||
| }; | |||
| auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||
| auto opr = megdnn_naive_handle() | |||
| ->create_operator<megdnn::CorrelationForward>(); | |||
| opr->param() = param; | |||
| auto inp_shape = inp[0]->shape(); | |||
| auto num = inp_shape[0]; | |||
| auto height = inp_shape[2]; | |||
| auto width = inp_shape[3]; | |||
| uint32_t pad_size = param.pad_size; | |||
| uint32_t kernel_size = param.kernel_size; | |||
| uint32_t stride1 = param.stride1; | |||
| uint32_t stride2 = param.stride2; | |||
| uint32_t max_displacement = param.max_displacement; | |||
| int paddedbottomheight = height + 2 * pad_size; | |||
| int paddedbottomwidth = width + 2 * pad_size; | |||
| uint32_t kernel_radius = (kernel_size - 1) / 2; | |||
| uint32_t border_size = max_displacement + kernel_radius; | |||
| uint32_t top_width = | |||
| ceil(static_cast<float>(paddedbottomwidth - border_size * 2) / | |||
| static_cast<float>(stride1)); | |||
| uint32_t top_height = | |||
| ceil(static_cast<float>(paddedbottomheight - border_size * 2) / | |||
| static_cast<float>(stride1)); | |||
| uint32_t neighborhood_grid_radius = max_displacement / stride2; | |||
| uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||
| uint32_t top_channels = | |||
| neighborhood_grid_width * neighborhood_grid_width; | |||
| megdnn::TensorShape target_shape{num, top_channels, top_height, | |||
| top_width}; | |||
| dest[0].dtype(dtype::Float32()) | |||
| .comp_node(inp[0]->comp_node()) | |||
| .resize(target_shape); | |||
| opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), | |||
| {}); | |||
| }; | |||
| auto rand_real = [&](float lo, float hi) { | |||
| std::uniform_real_distribution<float> dist(lo, hi); | |||
| return dist(rng); | |||
| }; | |||
| auto gen_inp1 = [&](HostTensorND &inp) { | |||
| auto ptr = inp.ptr<float>(); | |||
| for (size_t i = 0; i < inp.shape().total_nr_elems(); ++i) { | |||
| ptr[i] = rand_real(0.06f, 0.1f); | |||
| }; | |||
| }; | |||
| auto gen_inp2 = [&](HostTensorND &inp) { | |||
| auto ptr = inp.ptr<float>(); | |||
| for (size_t i = 0; i < inp.shape().total_nr_elems(); ++i) { | |||
| ptr[i] = rand_real(0.01f, 0.04f); | |||
| }; | |||
| }; | |||
| Checker::RunOptions option; | |||
| option.numdiff_eps = 1e-3; | |||
| option.numdiff_max_err = 1e-2; | |||
| Checker checker{make_graph, fwd}; | |||
| checker.set_input_generator(0, gen_inp1); | |||
| checker.set_input_generator(1, gen_inp2); | |||
| checker.run({TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 10, 10}}, option) | |||
| .run({TensorShape{1, 3, 50, 50}, TensorShape{1, 3, 50, 50}}, option) | |||
| .run({TensorShape{1, 1, 100, 100}, TensorShape{1, 1, 100, 100}}, | |||
| option); | |||
| } | |||
| TEST(TestOprDNN, CorrelationForwardMultiply) { | |||
| // TODO: fix me, add correct backward of cpu | |||
| REQUIRE_GPU(1); | |||
| run_forward(true); | |||
| } | |||
| TEST(TestOprDNN, CorrelationForwardSubstract) { | |||
| // TODO: fix me, add correct backward of cpu | |||
| REQUIRE_GPU(1); | |||
| run_forward(false); | |||
| } | |||
| } // anonymous namespace | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -106,6 +106,7 @@ union OperatorParam { | |||
| param.DctChannelSelect = 72, | |||
| param.FakeQuant = 73, | |||
| param.TQT = 74, | |||
| param.Correlation = 75, | |||
| } | |||
| table Operator { | |||