@@ -259,6 +259,9 @@ public:
        BATCH_CONV_FORWARD,
        POOLING_FORWARD,
        POOLING_BACKWARD,
        REGIONRESTRICTEDCONVOLUTION_FORWARD,
        REGIONRESTRICTEDCONVOLUTION_BACKWARD_DATA,
        REGIONRESTRICTEDCONVOLUTION_BACKWARD_FILTER,
    };

    struct SearchItem {

@@ -535,6 +535,131 @@ protected:
};
using ConvBias = ConvBiasForward;
/**
 * \brief RegionRestrictedConvolutionForward operator.
 */
class RegionRestrictedConvolutionForward : public ConvolutionBase<param::Convolution> {
    DEF_OPR_IMPL(RegionRestrictedConvolutionForward, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] rin (n, ih, iw)
     * \param[in] rout (n, oh, ow)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in rin,
            _megdnn_tensor_in rout, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_dtype(DType src, DType filter, DType rin, DType rout, DType& dst);
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& rin, const TensorLayout& rout, TensorLayout& dst);

    /**
     * \brief query the workspace needed when executing the opr
     * \return the size of workspace needed when executing
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& rin, const TensorLayout& rout,
            const TensorLayout& dst) = 0;

    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::REGIONRESTRICTEDCONVOLUTION_FORWARD;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& rin, const TensorLayout& rout, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using RegionRestrictedConvolution = RegionRestrictedConvolutionForward;
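// A minimal usage sketch of the forward operator (illustrative only; assumes
// an existing megdnn::Handle* named `handle`, default Param, and Float32 data):
//
//     auto opr = handle->create_operator<RegionRestrictedConvolutionForward>();
//     TensorLayout src{{2, 8, 32, 32}, dtype::Float32()};
//     TensorLayout filter{{16, 8, 3, 3}, dtype::Float32()};
//     TensorLayout rin{{2, 32, 32}, dtype::Int32()};
//     TensorLayout rout{{2, 30, 30}, dtype::Int32()};
//     TensorLayout dst;
//     opr->deduce_layout(src, filter, rin, rout, dst);  // -> (2, 16, 30, 30)
//     size_t ws = opr->get_workspace_in_bytes(src, filter, rin, rout, dst);
//     // allocate the tensors and a workspace of `ws` bytes, then call opr->exec().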
/**
 * \brief RegionRestrictedConvolutionBackwardData operator.
 *
 * Computes the gradient w.r.t. the convolution input data.
 */
class RegionRestrictedConvolutionBackwardData
        : public ConvolutionBase<param::Convolution> {
    DEF_OPR_IMPL(RegionRestrictedConvolutionBackwardData, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[in] rin (n, ih, iw)
     * \param[in] rout (n, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
            _megdnn_tensor_in rout, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& rin, const TensorLayout& rout,
            const TensorLayout& grad) = 0;
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(
            DType filter, DType diff, DType rin, DType rout, DType& grad);
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& rin, const TensorLayout& rout, TensorLayout& grad);

    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::REGIONRESTRICTEDCONVOLUTION_BACKWARD_DATA;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& rin, const TensorLayout& rout, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
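// For the NCHW case, deduce_layout() recovers the input spatial size via the
// usual transposed-convolution relation (see the `deduce` lambda in its
// implementation):
//
//     ih = (oh - 1) * stride_h + dilated_fh - 2 * pad_h
//
// and likewise for iw.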
/**
 * \brief RegionRestrictedConvolutionBackwardFilter operator.
 *
 * Computes the gradient w.r.t. the convolution filter.
 */
class RegionRestrictedConvolutionBackwardFilter
        : public ConvolutionBase<param::Convolution> {
    DEF_OPR_IMPL(RegionRestrictedConvolutionBackwardFilter, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[in] rin (n, ih, iw)
     * \param[in] rout (n, oh, ow)
     * \param[out] grad (oc, ic, fh, fw)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
            _megdnn_tensor_in rout, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& rin,
            const TensorLayout& rout, const TensorLayout& grad) = 0;

    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::REGIONRESTRICTEDCONVOLUTION_BACKWARD_FILTER;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& rin,
            const TensorLayout& rout, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
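// Note: the filter gradient's layout equals the forward filter layout
// ((oc, ic, fh, fw) above), so there is no deduce_layout() here and callers
// pass the grad layout in explicitly; check_exec() currently accepts float
// dtypes only.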
/**
 * \brief base class for Conv - Nonline - Pooling
 */

@@ -213,7 +213,10 @@ private:
    cb(LSTMBackward) \
    cb(SoftmaxForward) \
    cb(SoftmaxBackward) \
    cb(NormForward) \
    cb(RegionRestrictedConvolutionForward) \
    cb(RegionRestrictedConvolutionBackwardData) \
    cb(RegionRestrictedConvolutionBackwardFilter)
// clang-format on

/*!
@@ -139,6 +139,9 @@ DEF(LSTMForward, 8, true, true);
DEF(LSTMBackward, 13, true, true);
DEF(SoftmaxForward, 2, true, true);
DEF(SoftmaxBackward, 3, true, false);
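// Each entry appears to follow DEF(opr, arity, has_workspace, can_deduce_layout);
// arity 5 below corresponds to 4 inputs plus 1 output.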
DEF(RegionRestrictedConvolutionForward, 5, true, true);
DEF(RegionRestrictedConvolutionBackwardData, 5, true, false);
DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false);
}  // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,216 @@
#include "megdnn/oprs/nn.h"

#include "src/common/utils.cuh"
#include "src/common/utils.h"

using namespace megdnn;

namespace {
template <typename Param>
std::string get_errmsg(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        const Param& param) {
    MEGDNN_MARK_USED_VAR(src);
    MEGDNN_MARK_USED_VAR(filter);
    MEGDNN_MARK_USED_VAR(dst);
    return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
           megdnn_layout_msg(dst) + ", " + "is_nchw=" +
           std::to_string(param.format == param::Convolution::Format::NCHW) + ", " +
           "is_xcorr=" +
           std::to_string((param.mode == Convolution::Mode::CROSS_CORRELATION)) +
           ", " + "pad_h=" + std::to_string(param.pad_h) + ", " +
           "pad_w=" + std::to_string(param.pad_w) + ", " +
           "stride_h=" + std::to_string(param.stride_h) + ", " +
           "stride_w=" + std::to_string(param.stride_w) + ", " +
           "dilate_h=" + std::to_string(param.dilate_h) + ", " +
           "dilate_w=" + std::to_string(param.dilate_w);
}
}  // namespace

namespace megdnn {

void RegionRestrictedConvolutionForward::deduce_dtype(
        DType src, DType filter, DType rin, DType rout, DType& dst) {
    check_or_deduce_dtype_fwd(src, filter, dst);
    megdnn_assert(
            rin == rout && rin == dtype::Int32(),
            "the dtype of rin/rout should be Int32, got %s.", rin.name());
}
void RegionRestrictedConvolutionForward::deduce_layout(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& rin,
        const TensorLayout& rout, TensorLayout& dst) {
    MEGDNN_MARK_USED_VAR(rin);
    MEGDNN_MARK_USED_VAR(rout);
    deduce_layout_fwd(src, filter, dst);
}

RegionRestrictedConvolutionForward::CanonizedFilterMeta
RegionRestrictedConvolutionForward::check_exec(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& rin,
        const TensorLayout& rout, const TensorLayout& dst, size_t workspace_in_bytes) {
    auto ret = check_layout_fwd(src, filter, dst);
    megdnn_assert(
            param().format == Param::Format::NCHW,
            "RegionRestrictedConv only supports the NCHW format for now.");
#define err_msg(lhs, rhs) \
    megdnn_assert(lhs == rhs, "shape mismatch, " #lhs ":%zu, " #rhs ":%zu", lhs, rhs);
    err_msg(rin.shape[0], src.shape[0]);
    err_msg(rin.shape[1], src.shape[2]);
    err_msg(rin.shape[2], src.shape[3]);
    err_msg(rout.shape[0], dst.shape[0]);
    err_msg(rout.shape[1], dst.shape[2]);
    err_msg(rout.shape[2], dst.shape[3]);
#undef err_msg
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(src, filter, rin, rout, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}
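// Example of the shape contract checked above: for src (20, 8, 30, 30) and a
// dense (4, 8, 3, 3) filter with stride 1 and no padding, dst is
// (20, 4, 28, 28), so rin must be (20, 30, 30) and rout (20, 28, 28).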
RegionRestrictedConvolutionBackwardData::CanonizedFilterMeta
RegionRestrictedConvolutionBackwardData::check_exec(
        const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& rin,
        const TensorLayout& rout, const TensorLayout& grad, size_t workspace_in_bytes) {
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;
    std::swap(grad_fwd.dtype, diff_fwd.dtype);
    grad_fwd.init_contiguous_stride();
    diff_fwd.init_contiguous_stride();
    auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
#define err_msg(lhs, rhs) \
    megdnn_assert(lhs == rhs, "shape mismatch, " #lhs ":%zu, " #rhs ":%zu", lhs, rhs);
    err_msg(rin.shape[0], grad_fwd.shape[0]);
    err_msg(rin.shape[1], grad_fwd.shape[2]);
    err_msg(rin.shape[2], grad_fwd.shape[3]);
    err_msg(rout.shape[0], diff_fwd.shape[0]);
    err_msg(rout.shape[1], diff_fwd.shape[2]);
    err_msg(rout.shape[2], diff_fwd.shape[3]);
#undef err_msg
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(filter, diff, rin, rout, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}
void RegionRestrictedConvolutionBackwardData::deduce_dtype(
        DType filter, DType diff, DType rin, DType rout, DType& grad) {
    SmallVector<DType> supported_dst_dtype;
    if (filter.category() == diff.category() &&
        filter.category() == DTypeCategory::FLOAT) {
        supported_dst_dtype.push_back(filter);
    } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) {
        supported_dst_dtype.push_back(dtype::Int32());
    } else if (
            (filter.enumv() == DTypeEnum::QuantizedS8 &&
             diff.enumv() == DTypeEnum::QuantizedS8) ||
            (filter.enumv() == DTypeEnum::Quantized8Asymm &&
             diff.enumv() == DTypeEnum::Quantized8Asymm)) {
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff)));
        if (grad.valid() && grad.enumv() == diff.enumv()) {
            supported_dst_dtype.push_back(grad);
        }
    } else {
        megdnn_throw(ssprintf(
                "unsupported input / diff DType: %s x %s", filter.name(), diff.name()));
    }
    if (!grad.valid()) {
        grad = supported_dst_dtype.at(0);
    } else {
        megdnn_assert(
                vec_contains(supported_dst_dtype, grad),
                "unsupported RegionRestrictedConvBwdData(%s, %s) -> %s", filter.name(),
                diff.name(), grad.name());
    }
    megdnn_assert(
            param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
                    || filter.enumv() == DTypeEnum::Float16 ||
                    filter.enumv() == DTypeEnum::BFloat16
#endif
            ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
    megdnn_assert(
            rin == rout && rin == dtype::Int32(),
            "the dtype of rin/rout should be Int32, got %s.", rin.name());
}
void RegionRestrictedConvolutionBackwardData::deduce_layout(
        const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& rin,
        const TensorLayout& rout, TensorLayout& grad) {
    auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
    MEGDNN_MARK_USED_VAR(errmsg);
    megdnn_assert_contiguous(filter);
    megdnn_assert_contiguous(diff);
    megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
    megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());
    deduce_dtype(filter.dtype, diff.dtype, rin.dtype, rout.dtype, grad.dtype);

    auto cflt = make_canonized_filter_meta(diff.ndim, filter);
    auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
        MEGDNN_MARK_USED_VAR(errmsg);
        auto i = (out - 1) * stride + filter;
        megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
        return i - pad * 2;
    };
    megdnn_assert(
            param().format == Param::Format::NCHW,
            "RegionRestrictedConvolutionBackwardData only supports the NCHW format "
            "for now.");
    size_t src_or_dst_c_pos = 1;
    size_t src_or_dst_spatial_start = 2;
    megdnn_assert(
            cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s", errmsg().c_str());
    grad.ndim = diff.ndim;
    grad[0] = diff[0];
    grad[src_or_dst_c_pos] = cflt.icpg * cflt.group;
    for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
        grad[i + src_or_dst_spatial_start] =
                deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                       cflt.stride[i], cflt.padding[i]);
    }
    grad.format = diff.format;
    grad.init_contiguous_stride();
}
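// Worked example: diff (n, oc, 28, 28) with a 3x3 filter (dilation 1),
// stride 1 and pad 0 gives i = (28 - 1) * 1 + 3 = 30 per spatial dim, so
// grad is deduced as (n, ic, 30, 30), matching a 30x30 forward input.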
RegionRestrictedConvolutionBackwardFilter::CanonizedFilterMeta
RegionRestrictedConvolutionBackwardFilter::check_exec(
        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& rin,
        const TensorLayout& rout, const TensorLayout& grad, size_t workspace_in_bytes) {
    megdnn_assert(
            src.dtype.category() == DTypeCategory::FLOAT &&
                    diff.dtype.category() == DTypeCategory::FLOAT &&
                    grad.dtype.category() == DTypeCategory::FLOAT,
            "only float type is supported for conv backward filter");
    auto src_fwd = src;
    auto diff_fwd = diff;
    src_fwd.init_contiguous_stride();
    diff_fwd.init_contiguous_stride();
    auto ret = check_layout_fwd(src_fwd, grad, diff_fwd);
    megdnn_assert(
            param().format == Param::Format::NCHW,
            "RegionRestrictedConvolutionBackwardFilter only supports the NCHW format "
            "for now.");
#define err_msg(lhs, rhs) \
    megdnn_assert(lhs == rhs, "shape mismatch, " #lhs ":%zu, " #rhs ":%zu", lhs, rhs);
    err_msg(rin.shape[0], src_fwd.shape[0]);
    err_msg(rin.shape[1], src_fwd.shape[2]);
    err_msg(rin.shape[2], src_fwd.shape[3]);
    err_msg(rout.shape[0], diff_fwd.shape[0]);
    err_msg(rout.shape[1], diff_fwd.shape[2]);
    err_msg(rout.shape[2], diff_fwd.shape[3]);
#undef err_msg
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(src, diff, rin, rout, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -877,6 +877,141 @@ void forward_bias(
    }
}
template <
        typename stype, typename ftype, typename dtype, typename comp_type,
        class Strategy, typename FilterMeta, typename FilterVisitor = ConvFilterVisitor>
void region_restricted_compute(
        _megdnn_tensor_in src, ftype* __restrict fptr, _megdnn_tensor_in rin,
        _megdnn_tensor_in rout, _megdnn_tensor_out dst, const FilterMeta& filter_meta) {
    size_t spatial_start = 2, channel_pos = 1, batch_pos = 0;
    auto N = src.layout.shape[batch_pos], IH = src.layout.shape[spatial_start],
         IW = src.layout.shape[spatial_start + 1];
    auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1];
    size_t OC = dst.layout.shape[channel_pos], OH = dst.layout.shape[spatial_start],
           OW = dst.layout.shape[spatial_start + 1];
    size_t FS_SPATIAL = 1, FS_IC = FH * FW, FS_OC = FS_IC * filter_meta.icpg,
           FS_G = FS_OC * filter_meta.ocpg;
    int ph = filter_meta.padding[0], pw = filter_meta.padding[1];
    size_t sh = filter_meta.stride[0], sw = filter_meta.stride[1];
    int dh = filter_meta.dilation[0], dw = filter_meta.dilation[1];
    stype* __restrict sptr = src.compatible_ptr<stype>();
    dtype* __restrict dptr = dst.compatible_ptr<dtype>();
    int32_t* __restrict rinptr = rin.ptr<int32_t>();
    int32_t* __restrict routptr = rout.ptr<int32_t>();

    int h_offset = -ph, w_offset = -pw;
    if (filter_meta.should_flip) {
        h_offset += filter_meta.dilated_spatial[0] - 1;
        w_offset += filter_meta.dilated_spatial[1] - 1;
        dh = -dh;
        dw = -dw;
    }

    auto get_linear_addr = [](ptrdiff_t n, ptrdiff_t c, ptrdiff_t h, ptrdiff_t w,
                              const TensorLayout& layout) -> ptrdiff_t {
        return n * layout.stride[0] + c * layout.stride[1] + h * layout.stride[2] +
               w * layout.stride[3];
    };

    auto get_region_addr = [](ptrdiff_t n, ptrdiff_t h, ptrdiff_t w,
                              const TensorLayout& layout) -> ptrdiff_t {
        return n * layout.stride[0] + h * layout.stride[1] + w * layout.stride[2];
    };

    auto get_filter_addr = [&](GroupCounter& gc_out, size_t ic, size_t ic0, size_t fh,
                               size_t fw) {
        return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + (ic - ic0) * FS_IC +
               (fh * FW + fw) * FS_SPATIAL;
    };

    size_t filter_sizes = filter_meta.ocpg * filter_meta.icpg * FH * FW;
    for (size_t n = 0; n < N; ++n) {
        GroupCounter gc_out{filter_meta.ocpg};
        for (size_t oc = 0; oc < OC; ++oc, gc_out.next())
            for (size_t oh = 0; oh < OH; ++oh)
                for (size_t ow = 0; ow < OW; ++ow) {
                    comp_type dval = dptr[get_linear_addr(n, oc, oh, ow, dst.layout)];
                    ftype* fptr_cur = FilterVisitor::template get_current_ptr(
                            fptr, n, oc, oh, ow, filter_sizes);
                    Strategy::init_dval(dval);
                    int32_t routval = routptr[get_region_addr(n, oh, ow, rout.layout)];

                    for (size_t fh = 0; fh < FH; ++fh)
                        for (size_t fw = 0; fw < FW; ++fw) {
                            size_t ih = sh * oh + fh * dh + h_offset,
                                   iw = sw * ow + fw * dw + w_offset;
                            // ih and iw are unsigned; on underflow they wrap to
                            // very large values, so the bounds check below also
                            // rejects negative (out-of-range) coordinates
                            if (ih < IH && iw < IW) {
                                size_t ic0 = gc_out.cur_grp * filter_meta.icpg,
                                       ic1 = ic0 + filter_meta.icpg;
                                for (size_t ic = ic0; ic < ic1; ++ic) {
                                    stype& sval = sptr[get_linear_addr(
                                            n, ic, ih, iw, src.layout)];
                                    ftype& fval = fptr_cur[get_filter_addr(
                                            gc_out, ic, ic0, fh, fw)];
                                    int32_t rinval = rinptr[get_region_addr(
                                            n, ih, iw, rin.layout)];
                                    if (routval == rinval) {
                                        Strategy::on(
                                                sval, fval, dval, src.layout.dtype,
                                                filter_meta.dtype, dst.layout.dtype);
                                    }
                                }
                            }
                        }
                    Strategy::write(
                            dval, dptr[get_linear_addr(n, oc, oh, ow, dst.layout)]);
                }
    }
}
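// Semantics note: region_restricted_compute() mirrors the dense convolution
// loop, except that the accumulation at output (n, oc, oh, ow) is gated on
// routptr[n, oh, ow] == rinptr[n, ih, iw]. Each output pixel therefore only
// convolves input pixels carrying the same region id; with a single region
// id everywhere this degenerates to an ordinary convolution.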
//! forward with only filter ptr
template <typename stype, typename ftype, typename dtype, typename comp_type>
void region_restricted_forward(
        _megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_in rin,
        _megdnn_tensor_in rout, _megdnn_tensor_out dst,
        const RegionRestrictedConvolution::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(filter_meta.spatial_ndim == 2);
    megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW);
    region_restricted_compute<stype, ftype, dtype, comp_type, StrategyFwd>(
            src, const_cast<ftype*>(fptr), rin, rout, dst, filter_meta);
}

//! forward with full filter (for API compatibility)
template <typename stype, typename ftype, typename dtype, typename comp_type>
void region_restricted_forward(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in rin,
        _megdnn_tensor_in rout, _megdnn_tensor_out dst,
        const RegionRestrictedConvolution::CanonizedFilterMeta& filter_meta) {
    return region_restricted_forward<stype, ftype, dtype, comp_type>(
            src, filter.compatible_ptr<ftype>(), rin, rout, dst, filter_meta);
}
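// The two backward helpers below reuse region_restricted_compute() with the
// tensor roles swapped: for the data gradient, grad takes the src slot and
// diff the dst slot (StrategyBwdData accumulates into grad); for the filter
// gradient, grad supplies the filter pointer (StrategyBwdFlt accumulates
// into it). Both therefore zero grad up front via memset.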
template <typename ftype, typename dtype, typename gtype>
void region_restricted_backward_data(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
        _megdnn_tensor_in rout, _megdnn_tensor_out grad,
        const Convolution::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW);
    memset(grad.raw_ptr(), 0, grad.layout.span().dist_byte());
    megdnn_assert(filter_meta.spatial_ndim == 2);
    region_restricted_compute<gtype, ftype, dtype, dtype, StrategyBwdData>(
            grad, filter.compatible_ptr<ftype>(), rin, rout, diff, filter_meta);
}

template <typename stype, typename dtype, typename gtype>
void region_restricted_backward_filter(
        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
        _megdnn_tensor_in rout, _megdnn_tensor_out grad,
        const Convolution::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW);
    memset(grad.raw_ptr(), 0, grad.layout.span().dist_byte());
    megdnn_assert(filter_meta.spatial_ndim == 2);
    region_restricted_compute<stype, gtype, dtype, dtype, StrategyBwdFlt>(
            src, grad.compatible_ptr<gtype>(), rin, rout, diff, filter_meta);
}
}  // namespace convolution
}  // namespace naive
}  // namespace megdnn

@@ -57,6 +57,7 @@
#include "src/naive/pooling/opr_impl.h"
#include "src/naive/powc/opr_impl.h"
#include "src/naive/reduce/opr_impl.h"
#include "src/naive/region_restricted_convolution/opr_impl.h"
#include "src/naive/relayout/opr_impl.h"
#include "src/naive/relayout_format/opr_impl.h"
#include "src/naive/remap/opr_impl.h"
@@ -0,0 +1,180 @@
#include "./opr_impl.h"
#include "../convolution/helper.h"
#include "megdnn/dtype.h"
#include "src/common/utils.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"

#include <cstring>

#include "midout.h"
MIDOUT_DECL(megdnn_naive_region_restricted_conv_fwd)

using namespace megdnn;
using namespace naive;

void RegionRestrictedConvolutionForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in rin,
        _megdnn_tensor_in rout, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    MIDOUT_BEGIN(megdnn_naive_region_restricted_conv_fwd) {
        auto filter_meta = check_exec(
                src.layout, filter.layout, rin.layout, rout.layout, dst.layout,
                workspace.size);
        using ComputeMode = Param::ComputeMode;
#define DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, cmode)              \
    do {                                                                          \
        using namespace dtype;                                                    \
        if (src.layout.dtype.enumv() == DTypeTrait<in_dt>::enumv &&               \
            dst.layout.dtype.enumv() == DTypeTrait<out_dt>::enumv &&              \
            param().compute_mode == cmode) {                                      \
            MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_forward< \
                                          in_ct, in_ct, out_ct, comp_ct>(         \
                    src, filter, rin, rout, dst, filter_meta)););                 \
            return;                                                               \
        }                                                                         \
    } while (0);
#define DISPATCH(in_dt, out_dt, in_ct, out_ct, comp_ct) \
    DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, ComputeMode::DEFAULT)
#define cb(dt)                                                    \
    DISPATCH(                                                     \
            dt, dt, DTypeTrait<dt>::ctype, DTypeTrait<dt>::ctype, \
            DTypeTrait<dt>::ctype)
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
        DNN_INC_FLOAT16(DISPATCH_CMODE(
                Float16, Float16, dt_float16, dt_float16, dt_float32,
                ComputeMode::FLOAT32));
#undef DISPATCH
#undef DISPATCH_CMODE
        megdnn_throw(ssprintf(
                "unsupported RegionRestrictedConv(%s, %s, %s, %s) -> %s with cmode = "
                "%d",
                src.layout.dtype.name(), filter.layout.dtype.name(),
                rin.layout.dtype.name(), rout.layout.dtype.name(),
                dst.layout.dtype.name(), static_cast<int>(param().compute_mode)));
    }
    MIDOUT_END();
}
size_t RegionRestrictedConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& rin,
        const TensorLayout& rout, const TensorLayout& grad) {
    size_t workspace_size = 0;
    auto flt_dt = filter.dtype.enumv();
    auto grad_dt = grad.dtype.enumv();
    auto diff_dt = diff.dtype.enumv();
    MEGDNN_MARK_USED_VAR(rin);
    MEGDNN_MARK_USED_VAR(rout);
#if !MEGDNN_DISABLE_FLOAT16
    if (flt_dt == DTypeEnum::Float16 || flt_dt == DTypeEnum::BFloat16) {
        megdnn_assert(flt_dt == grad_dt && flt_dt == diff_dt);
        workspace_size = grad.span().dist_elem() * dtype::Float32().size();
    }
#else
    MEGDNN_MARK_USED_VAR(flt_dt);
    MEGDNN_MARK_USED_VAR(grad_dt);
    MEGDNN_MARK_USED_VAR(diff_dt);
#endif
    return workspace_size;
}
void RegionRestrictedConvolutionBackwardDataImpl::exec(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
        _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
    auto filter_meta = check_exec(
            filter.layout, diff.layout, rin.layout, rout.layout, grad.layout,
            workspace.size);
    using ComputeMode = Param::ComputeMode;
    auto cmode = param().compute_mode;
#define cb(dt)                                                              \
    do {                                                                    \
        if (filter.layout.dtype == dt() && cmode == ComputeMode::DEFAULT) { \
            using ctype = DTypeTrait<dt>::ctype;                            \
            MEGDNN_DISPATCH_CPU_KERN_OPR(                                   \
                    (convolution::region_restricted_backward_data<         \
                            ctype, ctype, ctype>(                           \
                            filter, diff, rin, rout, grad, filter_meta));); \
            return;                                                         \
        }                                                                   \
    } while (0);
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
#if !MEGDNN_DISABLE_FLOAT16
    if (filter.layout.dtype == dtype::Float16() && cmode == ComputeMode::FLOAT32) {
        TensorND grad_fp32{
                workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}};
        auto&& type_cvt = handle()->create_operator<TypeCvt>();
        type_cvt->exec(grad, grad_fp32);
        MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_backward_data<
                                      dt_float16, dt_float16, dt_float32>(
                filter, diff, rin, rout, grad_fp32, filter_meta)););
        type_cvt->exec(grad_fp32, grad);
        return;
    }
#endif
    megdnn_throw(ssprintf(
            "unsupported RegionRestrictedConvolutionBackwardData(%s, %s, %s, %s) -> %s "
            "with cmode = %d",
            filter.layout.dtype.name(), diff.layout.dtype.name(),
            rin.layout.dtype.name(), rout.layout.dtype.name(), grad.layout.dtype.name(),
            static_cast<int>(cmode)));
}
size_t RegionRestrictedConvolutionBackwardFilterImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& diff, const TensorLayout&,
        const TensorLayout&, const TensorLayout& grad) {
    size_t workspace_size = 0;
#if !MEGDNN_DISABLE_FLOAT16
    auto src_dt = src.dtype.enumv();
    auto grad_dt = grad.dtype.enumv();
    auto diff_dt = diff.dtype.enumv();
    if (src_dt == DTypeEnum::Float16 || src_dt == DTypeEnum::BFloat16) {
        megdnn_assert(src_dt == grad_dt && src_dt == diff_dt);
        workspace_size = grad.span().dist_elem() * dtype::Float32().size();
    }
#endif
    return workspace_size;
}
void RegionRestrictedConvolutionBackwardFilterImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
        _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
    auto filter_meta = check_exec(
            src.layout, diff.layout, rin.layout, rout.layout, grad.layout,
            workspace.size);
    using ComputeMode = Param::ComputeMode;
    auto cmode = param().compute_mode;
#define cb(dt)                                                             \
    do {                                                                   \
        if (src.layout.dtype == dt() && cmode == ComputeMode::DEFAULT) {   \
            using ctype = DTypeTrait<dt>::ctype;                           \
            MEGDNN_DISPATCH_CPU_KERN(                                      \
                    static_cast<HandleImpl*>(handle()),                    \
                    convolution::region_restricted_backward_filter<       \
                            ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>(  \
                            src, diff, rin, rout, grad, filter_meta););    \
            return;                                                        \
        }                                                                  \
    } while (0);
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
#if !MEGDNN_DISABLE_FLOAT16
    if (src.layout.dtype == dtype::Float16() && cmode == ComputeMode::FLOAT32) {
        TensorND grad_fp32{
                workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}};
        auto&& type_cvt = handle()->create_operator<TypeCvt>();
        type_cvt->exec(grad, grad_fp32);
        MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_backward_filter<
                                      dt_float16, dt_float16, dt_float32>(
                src, diff, rin, rout, grad_fp32, filter_meta)););
        type_cvt->exec(grad_fp32, grad);
        return;
    }
#endif
    megdnn_throw(ssprintf(
            "unsupported RegionRestrictedConvolutionBackwardFilter(%s, %s, %s, %s) -> "
            "%s with cmode = %d",
            src.layout.dtype.name(), diff.layout.dtype.name(),
            rin.layout.dtype.name(), rout.layout.dtype.name(),
            grad.layout.dtype.name(), static_cast<int>(cmode)));
}

// vim: syntax=cpp.doxygen
@@ -0,0 +1,53 @@
#pragma once

#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {
namespace naive {

class RegionRestrictedConvolutionForwardImpl
        : public RegionRestrictedConvolutionForward {
public:
    using RegionRestrictedConvolutionForward::RegionRestrictedConvolutionForward;
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in rin,
            _megdnn_tensor_in rout, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(
            const TensorLayout&, const TensorLayout&, const TensorLayout&,
            const TensorLayout&, const TensorLayout&) override {
        return 0;
    }
};

class RegionRestrictedConvolutionBackwardDataImpl
        : public RegionRestrictedConvolutionBackwardData {
public:
    using RegionRestrictedConvolutionBackwardData::
            RegionRestrictedConvolutionBackwardData;
    void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
            _megdnn_tensor_in rout, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(
            const TensorLayout&, const TensorLayout&, const TensorLayout&,
            const TensorLayout&, const TensorLayout&) override;
};

class RegionRestrictedConvolutionBackwardFilterImpl
        : public RegionRestrictedConvolutionBackwardFilter {
public:
    using RegionRestrictedConvolutionBackwardFilter::
            RegionRestrictedConvolutionBackwardFilter;
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
            _megdnn_tensor_in rout, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(
            const TensorLayout&, const TensorLayout&, const TensorLayout&,
            const TensorLayout&, const TensorLayout&) override;
};

}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,198 @@
#include "test/naive/fixture.h"

#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"
#include "test/common/convolution.h"
// #include "test/common/regin_restricted_convolution.h"
#include "test/common/extra_impl_helper.h"
#include "test/common/random_state.h"

using namespace megdnn;
using namespace test;

namespace {

void mask_tensor(
        const TensorND& in, TensorND& out, const TensorND& mask,
        const int32_t mask_val) {
    megdnn_assert(
            in.layout.ndim == out.layout.ndim && in.layout.ndim == 4 &&
            mask.layout.ndim == 3);
    megdnn_assert_eq_layout(in.layout, out.layout);
    megdnn_assert(
            mask.layout[0] == in.layout[0] && mask.layout[1] == in.layout[2] &&
            mask.layout[2] == in.layout[3]);

    int32_t* mask_ptr = mask.ptr<int32_t>();
    float* src_ptr = in.compatible_ptr<float>();
    float* dst_ptr = out.compatible_ptr<float>();

    for (size_t n = 0; n < in.layout[0]; ++n) {
        for (size_t c = 0; c < in.layout[1]; ++c) {
            for (size_t h = 0; h < in.layout[2]; ++h) {
                for (size_t w = 0; w < in.layout[3]; ++w) {
                    size_t mask_off = n * mask.layout.stride[0] +
                                      h * mask.layout.stride[1] +
                                      w * mask.layout.stride[2];
                    size_t src_dst_off =
                            n * in.layout.stride[0] + c * in.layout.stride[1] +
                            h * in.layout.stride[2] + w * in.layout.stride[3];
                    if (mask_ptr[mask_off] == mask_val) {
                        dst_ptr[src_dst_off] = src_ptr[src_dst_off];
                    } else {
                        dst_ptr[src_dst_off] = 0.f;
                    }
                }
            }
        }
    }
}

}  // namespace
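// Reference implementation used below: with region ids drawn from [0, N), a
// region-restricted convolution decomposes as
//
//     dst = sum over r of mask(rout == r) * Conv(mask(rin == r) * src, filter)
//
// which extra_impl assembles from the ordinary Convolution operator,
// mask_tensor() and elementwise ADD.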
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
    Checker<RegionRestrictedConvolution> checker(handle());
    RegionRestrictedConvolution::Param param;
    constexpr int N = 3;
    UniformIntRNG rng{0, N - 1};

    auto extra_impl = [&, this](const TensorNDArray& tensors) {
        auto conv = handle()->create_operator<Convolution>();
        conv->param() = param;
        auto workspace_size = conv->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[4].layout, nullptr);
        dt_byte* workspace_ptr = static_cast<dt_byte*>(malloc(workspace_size));
        Workspace workspace{workspace_ptr, workspace_size};

        TensorND masked_src(
                malloc(tensors[0].layout.span().dist_byte()), tensors[0].layout);
        TensorNDArray dst_tensors;
        for (int i = 0; i < N; ++i) {
            dst_tensors.emplace_back(
                    malloc(tensors[4].layout.span().dist_byte()), tensors[4].layout);
        }
        for (int i = 0; i < N; ++i) {
            mask_tensor(tensors[0], masked_src, tensors[2], i);
            conv->exec(masked_src, tensors[1], dst_tensors[i], nullptr, workspace);
            mask_tensor(dst_tensors[i], dst_tensors[i], tensors[3], i);
        }
        free(workspace_ptr);

        using Mode = ElemwiseForward::Param::Mode;
        auto add = handle()->create_operator<ElemwiseForward>();
        add->param().mode = Mode::ADD;
        add->exec({dst_tensors[0], dst_tensors[1]}, tensors[4]);
        for (int i = 2; i < N; ++i) {
            add->exec({dst_tensors[i], tensors[4]}, tensors[4]);
        }
        // release the scratch buffers allocated above
        free(masked_src.raw_ptr());
        for (int i = 0; i < N; ++i) {
            free(dst_tensors[i].raw_ptr());
        }
    };

    checker.set_extra_opr_impl(extra_impl)
            .set_rng(2, &rng)
            .set_rng(3, &rng)
            .set_dtype(2, dtype::Int32())
            .set_dtype(3, dtype::Int32());
    checker.execs({{1, 8, 2, 2}, {4, 8, 1, 1}, {1, 2, 2}, {1, 2, 2}, {}})
            .execs({{20, 12, 30, 30}, {4, 12, 1, 1}, {20, 30, 30}, {20, 30, 30}, {}})
            .execs({{20, 8, 30, 30}, {4, 8, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}});

    param.sparse = Convolution::Param::Sparse::GROUP;
    checker.set_param(param)
            .execs({{20, 15, 30, 30}, {5, 4, 3, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}})
            .execs({{20, 25, 30, 30}, {25, 1, 1, 3, 3}, {20, 30, 30}, {20, 28, 28}, {}});
}
#if 0
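// Disabled for now: the checker calls below only pass (filter, diff, grad)
// layouts and never supply the rin/rout tensors the operator requires.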
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BACKWARD_DATA) {
    Checker<RegionRestrictedConvolutionBackwardData> checker(handle());
    using Param = RegionRestrictedConvolutionBackwardData::Param;
    Param param;

    auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh,
                   size_t fw, size_t stride, size_t padding, size_t dilate = 1,
                   size_t group = 1) {
        param.pad_h = param.pad_w = padding;
        param.stride_h = param.stride_w = stride;
        param.dilate_h = param.dilate_w = dilate;

        TensorLayout diff = TensorLayout{{n, oc * group, oh, ow}, dtype::Float32()};
        TensorLayout grad;
        TensorLayout filter;
        if (group == 1) {
            param.sparse = Param::Sparse::DENSE;
            filter = {{oc, ic, fh, fw}, dtype::Float32()};
        } else {
            param.sparse = Param::Sparse::GROUP;
            filter = {{group, oc, ic, fh, fw}, dtype::Float32()};
        }
        {
            auto opr = handle()->create_operator<ConvolutionBackwardData>();
            opr->param() = param;
            opr->deduce_layout(filter, diff, grad);
        }
        checker.set_param(param);
        checker.exec(TensorLayoutArray{filter, diff, grad});
    };

    for (auto mode : {Param::Mode::CONVOLUTION, Param::Mode::CROSS_CORRELATION}) {
        param.mode = mode;
        run(4, 3, 10, 13, 5, 1, 1, 1, 0, 1, 1);
        run(5, 5, 24, 43, 11, 9, 3, 3, 12, 1, 2);
        run(4, 3, 10, 45, 2, 1, 1, 1, 0, 4, 3);
        run(2, 3, 9, 12, 2, 4, 6, 1, 0, 1, 2);
        run(3, 4, 17, 32, 2, 3, 2, 5, 4, 4, 3);
        run(5, 5, 24, 43, 11, 9, 3, 3, 12, 2, 2);
        run(2, 3, 20, 33, 3, 5, 7, 4, 15, 2, 3);
        run(4, 4, 6, 7, 9, 3, 2, 2, 1, 3, 2);
    }
}
#endif

// vim: syntax=cpp.doxygen