#include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/dnn/roi_align.h" #include "megbrain/opr/dnn/roi_pooling.h" #include "megbrain/opr/imgproc.h" #include "../blob_manager_impl.h" #include "../dnn_op_helper.h" #include "../op_trait.h" namespace mgb { namespace imperative { namespace { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); OperatorNodeConfig config{op.make_name()}; return opr::CvtColor::make(inputs[0], op.param(), config); } OP_TRAIT_REG(CvtColor, CvtColor).apply_on_var_node(apply_on_var_node).fallback(); } // namespace namespace { namespace roi_align { VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); OperatorNodeConfig config{op.make_name()}; auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config) .node() ->owner_opr(); return {opr->output(0), opr->output(1)}; } std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { auto&& op = def.cast_final_safe(); DnnOprHelper dnn_opr(op.param()); auto cn = inputs[0].comp_node; auto&& [out_layout, ind_layout] = dnn_opr.deduce_layouts<2>(inputs[0].layout, inputs[1].layout); bool validated = out_layout.ndim == 0 && ind_layout.ndim == 0; return {{{out_layout, cn}, {ind_layout, cn}}, validated}; } SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs, SmallVector& output_descs, const bool& validated) { auto&& op = def.cast_final_safe(); auto cn = inputs[0]->comp_node(); DnnOprCaller dnn_opr(cn, op.param()); auto&& [out_layout, ind_layout] = [&]() -> std::array { if (validated) { return {output_descs[0].layout, output_descs[1].layout}; } else { return dnn_opr.deduce_layouts<2>(inputs[0]->layout(), inputs[1]->layout()); } }(); auto out = Tensor::make(out_layout, cn); auto ind = Tensor::make(ind_layout, cn); if (out_layout.is_empty() || ind_layout.is_empty()) { return {out, ind}; } dnn_opr.exec_with_ws(inputs[0], inputs[1], out, ind); return {out, ind}; } SmallVector get_input_layout_constraint( const OpDef& def, const SmallVector& inputs) { SmallVector layout_checker(inputs.size()); layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) { return layout.is_contiguous(); }; return layout_checker; } OP_TRAIT_REG(ROIAlign, ROIAlign) .apply_on_var_node(apply_on_var_node) .apply_on_physical_tensor(apply_on_physical_tensor) .infer_output_attrs_fallible(infer_output_attrs_fallible) .get_input_layout_constraint(get_input_layout_constraint) .fallback(); } // namespace roi_align } // namespace namespace { namespace roi_pooling { VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 3); OperatorNodeConfig config{op.make_name()}; auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param(), config) .node() ->owner_opr(); return {opr->output(0), opr->output(1)}; } OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback(); } // namespace roi_pooling } // namespace } // namespace imperative } // namespace mgb