#include "megdnn/oprs/nn.h" #include "src/common/utils.h" using namespace megdnn; void MaskConvForward::deduce_dtype(DType src, DType filter, DType, DType& dst) { check_or_deduce_dtype_fwd(src, filter, dst); } void MaskConvForward::deduce_layout( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& mask, TensorLayout& dst) { deduce_layout_fwd(src, filter, dst); megdnn_assert(dst[2] == mask[0]); megdnn_assert(dst[3] == mask[1]); } MaskConvForward::CanonizedFilterMeta MaskConvForward::check_exec( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& mask, const TensorLayout& dst, size_t workspace_in_bytes) { auto ret = check_layout_fwd(src, filter, dst); megdnn_assert(dst[2] == mask[0]); megdnn_assert(dst[3] == mask[1]); auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, mask, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); return ret; } void MaskPropagate::deduce_layout(const TensorLayout& src, TensorLayout& dst) { size_t oh, ow; auto p = param(); infer_conv_shape2d( src[0], src[1], (p.kernel_h - 1) * p.dilate_h + 1, (p.kernel_w - 1) * p.dilate_w + 1, p.stride_h, p.stride_w, p.pad_h, p.pad_w, oh, ow); dst = TensorLayout{{oh, ow}, src.dtype}; } // vim: syntax=cpp.doxygen