diff --git a/dnn/src/common/pooling.cpp b/dnn/src/common/pooling.cpp index b2fed318..877eb1f0 100644 --- a/dnn/src/common/pooling.cpp +++ b/dnn/src/common/pooling.cpp @@ -16,50 +16,55 @@ namespace megdnn { void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { - auto errmsg = - megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " + - "pad_h=" + std::to_string(param().pad_h) + ", " + - "pad_w=" + std::to_string(param().pad_w) + ", " + - "stride_h=" + std::to_string(param().stride_h) + ", " + - "stride_w=" + std::to_string(param().stride_w) + ", " + - "window_h=" + std::to_string(param().window_h) + ", " + - "window_w=" + std::to_string(param().window_w) + ", " + - "is_max=" + std::to_string(param().mode == Mode::MAX) + ", " + - "is_nhwc=" + std::to_string(param().format == Param::Format::NHWC) + ", " + - "is_nhwcd4=" + std::to_string(param().format == Param::Format::NHWCD4); - auto errmsg_c = errmsg.c_str(); - - MEGDNN_MARK_USED_VAR(errmsg_c); + auto& p = param(); + auto pformat = p.format; + + // the overhead of generating the error message is about 18x of the other part of this + // function, so we use a function to wrap the error message and get it only when needed. 
+ auto get_errmsg = [&](void) -> std::string { + std::string errmsg = + megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " + + "pad_h=" + std::to_string(param().pad_h) + ", " + + "pad_w=" + std::to_string(param().pad_w) + ", " + + "stride_h=" + std::to_string(param().stride_h) + ", " + + "stride_w=" + std::to_string(param().stride_w) + ", " + + "window_h=" + std::to_string(param().window_h) + ", " + + "window_w=" + std::to_string(param().window_w) + ", " + + "is_max=" + std::to_string(param().mode == Mode::MAX) + ", " + + "is_nhwc=" + std::to_string(pformat == Param::Format::NHWC) + ", " + + "is_nhwcd4=" + std::to_string(pformat == Param::Format::NHWCD4); + return errmsg; + }; + + MEGDNN_MARK_USED_VAR(get_errmsg); megdnn_assert_contiguous(src); size_t spatial_pos, c_pos, batch_pos = 0; - if (param().format == Param::Format::NCHW) { - megdnn_assert(src.ndim == 4_z, "%s", errmsg_c); + if (pformat == Param::Format::NCHW) { + megdnn_assert(src.ndim == 4_z, "%s", get_errmsg().c_str()); spatial_pos = 2; c_pos = 1; - } else if (param().format == Param::Format::NHWC) { - megdnn_assert(src.ndim == 4_z, "%s", errmsg_c); + } else if (pformat == Param::Format::NHWC) { + megdnn_assert(src.ndim == 4_z, "%s", get_errmsg().c_str()); spatial_pos = 1; c_pos = 3; } else if ( - param().format == Param::Format::NCHW4 || - param().format == Param::Format::NCHW44 || - param().format == Param::Format::NCHW88 || - param().format == Param::Format::NCHW32 || - param().format == Param::Format::NCHW64) { - megdnn_assert(src.ndim == 5_z, "%s", errmsg_c); + pformat == Param::Format::NCHW4 || pformat == Param::Format::NCHW44 || + pformat == Param::Format::NCHW88 || pformat == Param::Format::NCHW32 || + pformat == Param::Format::NCHW64) { + megdnn_assert(src.ndim == 5_z, "%s", get_errmsg().c_str()); spatial_pos = 2; c_pos = 1; - } else if (param().format == Param::Format::CHWN4) { + } else if (pformat == Param::Format::CHWN4) { spatial_pos = 1; c_pos = 0; batch_pos = 3; } else { 
megdnn_assert( - param().format == Param::Format::NHWCD4 && src.ndim == 5_z, "%s", - errmsg_c); + pformat == Param::Format::NHWCD4 && src.ndim == 5_z, "%s", + get_errmsg().c_str()); spatial_pos = 1; c_pos = 2; } @@ -67,31 +72,34 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) size_t c = src[c_pos]; size_t ih = src[spatial_pos]; size_t iw = src[spatial_pos + 1]; - if (param().format == Param::Format::NHWCD4) { + if (pformat == Param::Format::NHWCD4) { c *= 4; iw = src[spatial_pos + 2]; } - if (param().format == Param::Format::NCHW4 || - param().format == Param::Format::NCHW44 || - param().format == Param::Format::CHWN4) { + if (pformat == Param::Format::NCHW4 || pformat == Param::Format::NCHW44 || + pformat == Param::Format::CHWN4) { c *= 4; } - if (param().format == Param::Format::NCHW88) { + if (pformat == Param::Format::NCHW88) { c *= 8; } - if (param().format == Param::Format::NCHW32) { + if (pformat == Param::Format::NCHW32) { c *= 32; } - if (param().format == Param::Format::NCHW64) { + if (pformat == Param::Format::NCHW64) { c *= 64; } size_t oh, ow; - size_t fh = this->param().window_h; - size_t fw = this->param().window_w; - size_t sh = this->param().stride_h; - size_t sw = this->param().stride_w; - size_t ph = this->param().pad_h; - size_t pw = this->param().pad_w; + size_t fh = p.window_h; + size_t fw = p.window_w; + size_t sh = p.stride_h; + size_t sw = p.stride_w; + size_t ph = p.pad_h; + size_t pw = p.pad_w; + + // moving some python assert to here + // megdnn_assert() + if (ph >= fh || pw >= fw) { megdnn_log_warn( "pooling padding size (%zu %zu) should not be bigger than " @@ -99,26 +107,23 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) pw, ph, fw, fh); } infer_conv_shape2d(ih, iw, fh, fw, sh, sw, ph, pw, oh, ow); - if (param().format == Param::Format::NCHW) { + if (pformat == Param::Format::NCHW) { dst = TensorLayout(TensorShape({n, c, oh, ow}), src.dtype); - } else if 
(param().format == Param::Format::NHWC) { - megdnn_assert(param().format == Param::Format::NHWC, "invalid pooling format"); + } else if (pformat == Param::Format::NHWC) { + megdnn_assert(pformat == Param::Format::NHWC, "invalid pooling format"); dst = TensorLayout({n, oh, ow, c}, src.dtype, src.format); - } else if ( - param().format == Param::Format::NCHW4 || - param().format == Param::Format::NCHW44) { + } else if (pformat == Param::Format::NCHW4 || pformat == Param::Format::NCHW44) { dst = TensorLayout{{n, c / 4, oh, ow, 4}, src.dtype, src.format}; - } else if (param().format == Param::Format::NCHW88) { + } else if (pformat == Param::Format::NCHW88) { dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format}; - } else if (param().format == Param::Format::NCHW32) { + } else if (pformat == Param::Format::NCHW32) { dst = TensorLayout{{n, c / 32, oh, ow, 32}, src.dtype, src.format}; - } else if (param().format == Param::Format::NCHW64) { + } else if (pformat == Param::Format::NCHW64) { dst = TensorLayout{{n, c / 64, oh, ow, 64}, src.dtype, src.format}; - } else if (param().format == Param::Format::CHWN4) { + } else if (pformat == Param::Format::CHWN4) { dst = TensorLayout{{c / 4, oh, ow, n, 4}, src.dtype, src.format}; } else { - megdnn_assert( - param().format == Param::Format::NHWCD4, "invalid pooling format"); + megdnn_assert(pformat == Param::Format::NHWCD4, "invalid pooling format"); dst = TensorLayout{{n, oh, c / 4, ow, 4}, src.dtype, src.format}; } } diff --git a/imperative/src/impl/ops/pooling.cpp b/imperative/src/impl/ops/pooling.cpp new file mode 100644 index 00000000..6465ae14 --- /dev/null +++ b/imperative/src/impl/ops/pooling.cpp @@ -0,0 +1,105 @@ +/** + * \file imperative/src/impl/ops/pooling.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/opr/dnn/pooling.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/utility.h" + +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" + +#include "../algo_chooser.h" +#include "../blob_manager_impl.h" +#include "../dnn_op_helper.h" +#include "../op_trait.h" + +namespace mgb::imperative { + +namespace { +namespace pooling { + +// using OprHandle = opr::intl::UniqPtrWithCN; +// static ThinHashMap dnn_oprs; + +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& pool = static_cast(def); + OperatorNodeConfig config{pool.make_name()}; + return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config); +} + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + mgb_assert( + inputs.size() == 1, "num of inputs of pooling should be 1 but you give %zu", + inputs.size()); + + auto&& op_def = def.cast_final_safe(); + auto&& inp = inputs[0]; + auto& inp_cn = inp.comp_node; + + if (inp.layout.ndim == 0) { + return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}}, false}; + } + + DnnOprCaller caller(inp_cn); + auto&& dnn_opr = caller.op; + dnn_opr->param() = op_def.param(); + TensorLayout oup_layout; + dnn_opr->deduce_layout(inp.layout, oup_layout); + return {{{oup_layout, inp_cn, {}}}, true}; +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + mgb_assert( + inputs.size() == 1, "num of inputs of pooling should be 1 but you give %zu", + inputs.size()); + + auto&& op_def = def.cast_final_safe(); + auto cn = inputs[0]->comp_node(); + megdnn::TensorND inp_tensornd = inputs[0]->dnn_tensor(); + + DnnOprCaller caller(cn); + auto&& dnn_opr = 
caller.op; + dnn_opr->param() = op_def.param(); + + TensorLayout& oup_layout = output_descs[0].layout; + if (!validated) { + dnn_opr->deduce_layout(inp_tensornd.layout, oup_layout); + } + DeviceTensorND out_devtensor = + BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout); + + size_t wk_size = setup_algo( + {inp_tensornd.layout, oup_layout}, dnn_opr.get(), 0, false, false, cn, + op_def.policy(), false); + + megdnn::Workspace dnn_wk; + if (wk_size != 0) { + auto wk = Blob::make(cn, wk_size); + dnn_wk.raw_ptr = wk->storage().get(); + dnn_wk.size = wk_size; + } + + dnn_opr->exec(inp_tensornd, out_devtensor.as_megdnn(), {}); + return {Tensor::make(out_devtensor)}; +} + +OP_TRAIT_REG(Pooling, Pooling) + .apply_on_var_node(apply_on_var_node) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_physical_tensor(apply_on_physical_tensor) + .fallback(); + +} // namespace pooling +} // namespace + +} // namespace mgb::imperative diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index f1f66eb6..cb36e43a 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -333,17 +333,6 @@ OP_TRAIT_REG(BatchConvBias, BatchConvBias) } // namespace batch_conv_bias } // namespace -namespace { -namespace pooling { -auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& pool = static_cast(def); - OperatorNodeConfig config{pool.make_name()}; - return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config); -} -OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback(); -} // namespace pooling -} // namespace - namespace { namespace matrix_mul { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {