GitOrigin-RevId: 05223deca7
tags/v1.11.0
| @@ -1406,6 +1406,7 @@ public: | |||||
| protected: | protected: | ||||
| SmallVector<size_t> get_offsets(); | SmallVector<size_t> get_offsets(); | ||||
| MGE_WIN_DECLSPEC_FUC static SmallVector<size_t> get_offsets_impl(const Param& p); | |||||
| void check_exec(const TensorLayout& src, const TensorLayout& dst); | void check_exec(const TensorLayout& src, const TensorLayout& dst); | ||||
| }; | }; | ||||
| @@ -1421,6 +1422,9 @@ public: | |||||
| const TensorLayout& src, const TensorLayout& dst) = 0; | const TensorLayout& src, const TensorLayout& dst) = 0; | ||||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
| MGE_WIN_DECLSPEC_FUC static void deduce_layout_impl( | |||||
| const TensorLayout& src, TensorLayout& dst, const Param& p); | |||||
| protected: | protected: | ||||
| void forward_check_exec(const TensorLayout& src, const TensorLayout& dst); | void forward_check_exec(const TensorLayout& src, const TensorLayout& dst); | ||||
| }; | }; | ||||
| @@ -7,6 +7,7 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| using padding_param = megdnn::param_enumv::Padding; | using padding_param = megdnn::param_enumv::Padding; | ||||
| using Param = PaddingBase::Param; | |||||
| void PaddingForward::forward_check_exec( | void PaddingForward::forward_check_exec( | ||||
| const TensorLayout& src, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& dst) { | ||||
| @@ -19,8 +20,9 @@ void PaddingForward::forward_check_exec( | |||||
| "unsupported %s dtype for forward padding opr", src.dtype.name()); | "unsupported %s dtype for forward padding opr", src.dtype.name()); | ||||
| } | } | ||||
| void PaddingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { | |||||
| SmallVector<size_t> offsets(get_offsets()); | |||||
| void PaddingForward::deduce_layout_impl( | |||||
| const TensorLayout& src, TensorLayout& dst, const Param& p) { | |||||
| SmallVector<size_t> offsets(get_offsets_impl(p)); | |||||
| TensorShape dst_shape; | TensorShape dst_shape; | ||||
| switch (src.ndim) { | switch (src.ndim) { | ||||
| case 1: | case 1: | ||||
| @@ -76,6 +78,10 @@ void PaddingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { | |||||
| dst = TensorLayout(dst_shape, src.dtype); | dst = TensorLayout(dst_shape, src.dtype); | ||||
| } | } | ||||
| void PaddingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { | |||||
| return deduce_layout_impl(src, dst, param()); | |||||
| } | |||||
| void PaddingBackward::backward_check_exec( | void PaddingBackward::backward_check_exec( | ||||
| const TensorLayout& src, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& dst) { | ||||
| check_exec(dst, src); | check_exec(dst, src); | ||||
| @@ -86,17 +92,20 @@ void PaddingBackward::backward_check_exec( | |||||
| "unsupported %s dtype for forward padding opr", src.dtype.name()); | "unsupported %s dtype for forward padding opr", src.dtype.name()); | ||||
| } | } | ||||
| SmallVector<size_t> PaddingBase::get_offsets() { | |||||
| SmallVector<size_t> offsets = {param().front_offset_dim0, param().back_offset_dim0, | |||||
| param().front_offset_dim1, param().back_offset_dim1, | |||||
| param().front_offset_dim2, param().back_offset_dim2, | |||||
| param().front_offset_dim3, param().back_offset_dim3, | |||||
| param().front_offset_dim4, param().back_offset_dim4, | |||||
| param().front_offset_dim5, param().back_offset_dim5, | |||||
| param().front_offset_dim6, param().back_offset_dim6}; | |||||
| SmallVector<size_t> PaddingBase::get_offsets_impl(const Param& p) { | |||||
| SmallVector<size_t> offsets = { | |||||
| p.front_offset_dim0, p.back_offset_dim0, p.front_offset_dim1, | |||||
| p.back_offset_dim1, p.front_offset_dim2, p.back_offset_dim2, | |||||
| p.front_offset_dim3, p.back_offset_dim3, p.front_offset_dim4, | |||||
| p.back_offset_dim4, p.front_offset_dim5, p.back_offset_dim5, | |||||
| p.front_offset_dim6, p.back_offset_dim6}; | |||||
| return offsets; | return offsets; | ||||
| } | } | ||||
| SmallVector<size_t> PaddingBase::get_offsets() { | |||||
| return get_offsets_impl(param()); | |||||
| } | |||||
| void PaddingBase::check_exec(const TensorLayout& src, const TensorLayout& dst) { | void PaddingBase::check_exec(const TensorLayout& src, const TensorLayout& dst) { | ||||
| SmallVector<size_t> offsets(get_offsets()); | SmallVector<size_t> offsets(get_offsets()); | ||||
| // make sure the src and dst tensor not empty | // make sure the src and dst tensor not empty | ||||
| @@ -0,0 +1,77 @@ | |||||
| #include "megbrain/graph/symbol_var.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "megbrain/imperative/physical_tensor.h" | |||||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||||
| #include "megbrain/opr/basic_arith.h" | |||||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
| #include "megbrain/opr/io.h" | |||||
| #include "megbrain/opr/tensor_manip.h" | |||||
| #include "megdnn/dtype.h" | |||||
| #include "../blob_manager_impl.h" | |||||
| #include "../dnn_op_helper.h" | |||||
| #include "../op_trait.h" | |||||
| namespace mgb { | |||||
| namespace imperative { | |||||
| namespace { | |||||
| namespace padding { | |||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& op = static_cast<const Padding&>(def); | |||||
| mgb_assert(inputs.size() == 1); | |||||
| return opr::Padding::make(inputs[0], op.param()); | |||||
| } | |||||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
| auto comp_node = inputs[0]->comp_node(); | |||||
| auto&& op_def = def.cast_final_safe<Padding>(); | |||||
| DnnOprCaller<megdnn::Padding> dnn_op(comp_node); | |||||
| dnn_op.op->param() = op_def.param(); | |||||
| TensorLayout dst = output_descs[0].layout; | |||||
| if (!validated) { | |||||
| megdnn::Padding::deduce_layout_impl( | |||||
| inputs[0]->dnn_tensor().layout, dst, op_def.param()); | |||||
| } | |||||
| DeviceTensorND out = | |||||
| BlobManager::inst()->alloc_workspace_with_defrag(comp_node, dst); | |||||
| dnn_op.op->exec(inputs[0]->dnn_tensor(), out.as_megdnn()); | |||||
| return {Tensor::make(out)}; | |||||
| } | |||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
| auto&& op_def = def.cast_final_safe<Padding>(); | |||||
| size_t nr_inp = inputs.size(); | |||||
| auto p = op_def.param(); | |||||
| auto&& inp = inputs[0]; | |||||
| auto& inp_cn = inp.comp_node; | |||||
| if (inp.layout.ndim == 0) { | |||||
| return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}}, false}; | |||||
| } | |||||
| TensorLayout oup_layout; | |||||
| megdnn::Padding::deduce_layout_impl(inp.layout, oup_layout, p); | |||||
| return {{{oup_layout, inp_cn, {}}}, true}; | |||||
| } | |||||
| OP_TRAIT_REG(Padding, Padding, opr::Padding) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
| .fallback(); | |||||
| } // namespace padding | |||||
| } // namespace | |||||
| } // namespace imperative | |||||
| } // namespace mgb | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -664,15 +664,6 @@ OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback(); | |||||
| } // namespace cumsum | } // namespace cumsum | ||||
| } // namespace | } // namespace | ||||
| namespace padding { | |||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& op = static_cast<const Padding&>(def); | |||||
| mgb_assert(inputs.size() == 1); | |||||
| return opr::Padding::make(inputs[0], op.param()); | |||||
| } | |||||
| OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback(); | |||||
| } // namespace padding | |||||
| namespace lrn { | namespace lrn { | ||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
| auto&& op = static_cast<const LRN&>(def); | auto&& op = static_cast<const LRN&>(def); | ||||