|
- /**
- * \file dnn/src/common/roi_align.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
- #include "megdnn/oprs.h"
-
- #include "src/common/utils.h"
-
- namespace megdnn {
-
- void ROIAlignBase::deduce_layout_fwd(const TensorLayout& src,
- const TensorLayout& rois,
- TensorLayout& dst, TensorLayout& index) {
- megdnn_assert_contiguous(src);
- megdnn_assert_contiguous(rois);
- megdnn_assert_contiguous(dst);
- megdnn_assert_contiguous(index);
- auto errmsg = [&]() {
- return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(rois) + ", " +
- megdnn_layout_msg(dst) + ", " + megdnn_layout_msg(index);
- };
- MEGDNN_MARK_USED_VAR(errmsg);
- using Format = ROIAlignBase::Param::Format;
- megdnn_assert(param().format == Format::NCHW);
- auto src_dtype = src.dtype, rois_dtype = rois.dtype;
- megdnn_assert(src_dtype == rois_dtype &&
- src_dtype.category() == DTypeCategory::FLOAT);
- megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
- size_t channels = src.shape[1];
- megdnn_assert(rois.ndim == 2_z, "%s", errmsg().c_str());
- // rois shape: bid, x0, y0, x1, y1
- megdnn_assert(rois[1] == 5_z, "%s", errmsg().c_str());
- size_t M = rois[0];
- size_t pooled_height = param().pooled_height;
- size_t pooled_width = param().pooled_width;
- dst = TensorLayout{{M, channels, pooled_height, pooled_width}, src.dtype};
- index = dst;
- index.dtype = dtype::Int32();
- }
-
- void ROIAlignBase::check_layout_fwd(const TensorLayout& src,
- const TensorLayout& rois,
- const TensorLayout& dst,
- const TensorLayout& index) {
- TensorLayout dst_expected, index_expected;
- megdnn_assert_eq_dtype(src, dst);
- deduce_layout_fwd(src, rois, dst_expected, index_expected);
- megdnn_assert_eq_shape(dst_expected, dst);
- megdnn_assert_eq_shape(index_expected, index);
- megdnn_assert(index.dtype == dtype::Int32());
- }
-
- void ROIAlignForward::deduce_layout(const TensorLayout& src,
- const TensorLayout& rois, TensorLayout& dst,
- TensorLayout& index) {
- deduce_layout_fwd(src, rois, dst, index);
- }
-
- void ROIAlignForward::check_exec(const TensorLayout& src,
- const TensorLayout& rois,
- const TensorLayout& dst,
- const TensorLayout& index,
- size_t workspace_in_bytes) {
- check_layout_fwd(src, rois, dst, index);
- auto required_workspace_in_bytes =
- get_workspace_in_bytes(src, rois, dst, index);
- megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
- }
-
- void ROIAlignBackward::check_exec(const TensorLayout& diff,
- const TensorLayout& rois,
- const TensorLayout& index,
- const TensorLayout& grad,
- size_t workspace_in_bytes) {
- check_layout_fwd(grad, rois, diff, index);
- auto required_workspace_in_bytes =
- get_workspace_in_bytes(diff, rois, index, grad);
- megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
- }
-
- } // namespace megdnn
-
- // vim: syntax=cpp.doxygen
|