GitOrigin-RevId: a48ea9bff6
tags/v1.3.0
| @@ -16,6 +16,7 @@ | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| #include "megbrain/opr/search_policy/profiler.h" | |||
| #include "megbrain/utils/shared_set.h" | |||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| @@ -149,15 +150,6 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr, | |||
| } // anonymous namespace | |||
| #define MGB_FOREACH_FASTRUN_OPR(cb) \ | |||
| cb(ConvolutionForward), cb(ConvBiasForward), cb(ConvolutionBackwardData), \ | |||
| cb(ConvolutionBackwardFilter), cb(Convolution3DForward), \ | |||
| cb(Convolution3DBackwardData), cb(Convolution3DBackwardFilter), \ | |||
| cb(LocalShareForward), cb(LocalShareBackwardData), \ | |||
| cb(LocalShareBackwardFilter), cb(DeformableConvForward), \ | |||
| cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \ | |||
| cb(BatchConvBiasForward), | |||
| void gopt::modify_opr_algo_strategy_inplace( | |||
| const VarNodeArrayView& dest_vars, | |||
| opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | |||
| @@ -171,7 +163,7 @@ void gopt::modify_opr_algo_strategy_inplace( | |||
| modifiers = { | |||
| #define CONV(t) \ | |||
| {opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \ | |||
| std::placeholders::_1, strategy)} | |||
| std::placeholders::_1, strategy)}, | |||
| MGB_FOREACH_FASTRUN_OPR(CONV) | |||
| #undef CONV | |||
| }; | |||
| @@ -209,7 +201,7 @@ void gopt::set_opr_algo_workspace_limit_inplace( | |||
| static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> | |||
| modifiers = { | |||
| #define CONV(t) \ | |||
| {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>} | |||
| {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>}, | |||
| MGB_FOREACH_FASTRUN_OPR(CONV) | |||
| #undef CONV | |||
| }; | |||
| @@ -226,7 +218,6 @@ void gopt::set_opr_algo_workspace_limit_inplace( | |||
| dep_iter.add(i); | |||
| } | |||
| } | |||
| #undef MGB_FOREACH_FASTRUN_OPR | |||
| /* ================ ParamRedistributePass ================ */ | |||
| const char* ParamRedistributePass::name() const { | |||
| @@ -790,8 +781,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||
| new_inp[1]->name().c_str(), | |||
| new_inp[1]->owner_opr()->name().c_str()); | |||
| auto new_deconv_opr = opr::ConvolutionBackwardData::make( | |||
| new_inp[0], new_inp[1], new_param, deconv_opr.execution_policy(), | |||
| deconv_opr.config()); | |||
| new_inp[0], new_inp[1], new_param, | |||
| deconv_opr.execution_policy(), deconv_opr.config()); | |||
| return new_deconv_opr.node()->owner_opr(); | |||
| }; | |||
| @@ -813,20 +804,20 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||
| new_inp[1]->owner_opr()->name().c_str()); | |||
| if(opr->input().size() == 2) { | |||
| auto new_conv_opr = opr::ConvBias::make( | |||
| new_inp[0], new_inp[1], new_param, convbias_opr.execution_policy(), | |||
| convbias_opr.config()); | |||
| new_inp[0], new_inp[1], new_param, | |||
| convbias_opr.execution_policy(), convbias_opr.config()); | |||
| return new_conv_opr.node()->owner_opr(); | |||
| } else if(opr->input().size() == 3) { | |||
| auto new_conv_opr = opr::ConvBias::make( | |||
| new_inp[0], new_inp[1], new_inp[2], new_param, convbias_opr.execution_policy(), | |||
| convbias_opr.config()); | |||
| new_inp[0], new_inp[1], new_inp[2], new_param, | |||
| convbias_opr.execution_policy(), convbias_opr.config()); | |||
| return new_conv_opr.node()->owner_opr(); | |||
| } else { | |||
| mgb_assert(opr->input().size() == 4, "invalid input size %zu", | |||
| opr->input().size()); | |||
| auto new_conv_opr = opr::ConvBias::make( | |||
| new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param, convbias_opr.execution_policy(), | |||
| convbias_opr.config()); | |||
| new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param, | |||
| convbias_opr.execution_policy(), convbias_opr.config()); | |||
| return new_conv_opr.node()->owner_opr(); | |||
| } | |||
| }; | |||
| @@ -841,7 +832,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||
| megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||
| } | |||
| auto new_matmul_opr = opr::MatrixMul::make( | |||
| new_inp[0], new_inp[1], new_param, matmul_opr.config()); | |||
| new_inp[0], new_inp[1], new_param, | |||
| matmul_opr.execution_policy(), matmul_opr.config()); | |||
| return new_matmul_opr.node()->owner_opr(); | |||
| }; | |||
| @@ -864,7 +856,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||
| new_inp[1]->name().c_str(), | |||
| new_inp[1]->owner_opr()->name().c_str()); | |||
| auto new_matmul_opr = opr::BatchedMatrixMul::make( | |||
| new_inp[0], new_inp[1], new_param, matmul_opr.config()); | |||
| new_inp[0], new_inp[1], new_param, | |||
| matmul_opr.execution_policy(), matmul_opr.config()); | |||
| return new_matmul_opr.node()->owner_opr(); | |||
| }; | |||
| @@ -915,8 +908,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||
| new_mat->owner_opr()->input(0)->dtype() == dtype::Float32()) | |||
| new_mat = new_mat->owner_opr()->input(0); | |||
| else | |||
| new_mat = | |||
| opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node(); | |||
| new_mat = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}) | |||
| .node(); | |||
| } | |||
| SymbolVar new_warp; | |||
| if (new_inp.size() == 3) { | |||
| @@ -944,8 +937,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||
| new_map->owner_opr()->input(0)->dtype() == dtype::Float32()) | |||
| new_map = new_map->owner_opr()->input(0); | |||
| else | |||
| new_map = | |||
| opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node(); | |||
| new_map = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}) | |||
| .node(); | |||
| } | |||
| SymbolVar new_remap; | |||
| @@ -18,15 +18,41 @@ | |||
| #include "megbrain/opr/tensor_gen.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser.h" | |||
| #include "megbrain/opr/search_policy/profiler.h" | |||
| #include "./internal/megdnn_opr_wrapper.inl" | |||
| #include "./search_policy/workspace_need_limit_getter.inl" | |||
| #include "megdnn/oprs/linalg.h" | |||
| using namespace mgb; | |||
| using namespace opr; | |||
| namespace { | |||
| int get_mask_from_matmul(const megdnn::param::MatrixMul& param) { | |||
| return static_cast<int>(param.transposeA) + | |||
| (static_cast<int>(param.transposeB) * 2); | |||
| } | |||
| } | |||
| /* ================= MatrixMul ================= */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixMul); | |||
| MEGDNN_OPR_INIT2(MatrixMul, "matrix_mul") | |||
| MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param, | |||
| const ExecutionPolicy& policy, | |||
| const OperatorNodeConfig& config) | |||
| : Super{a->owner_graph(), config, "matrix_mul", {a, b}} { | |||
| init_megdnn_opr(*this, param); | |||
| m_policy = policy; | |||
| add_input({a, b}); | |||
| } | |||
| SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param, | |||
| const ExecutionPolicy& policy, | |||
| const OperatorNodeConfig& config) { | |||
| return a.insert_single_output_opr<MatrixMul>(a.node(), b.node(), param, | |||
| policy, config); | |||
| } | |||
| void MatrixMul::init_output_dtype() { | |||
| DType output_dtype = config().output_dtype(); | |||
| @@ -72,13 +98,32 @@ size_t MatrixMul::get_workspace_size_bytes( | |||
| param ^= 1; | |||
| }; | |||
| MGB_TRY { | |||
| a = mo->get_workspace_in_bytes(i0, i1, out); | |||
| a = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, | |||
| megdnn_opr(), this); | |||
| //! Here we just want to save the execution policy got from setup_algo, | |||
| //! while change the delaration of get_workspace_in_bytes may cause | |||
| //! many changes. | |||
| const_cast<MatrixMul*>(this) | |||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||
| megdnn_opr()->execution_policy(); | |||
| transpose(i0, tparam.transposeA); | |||
| b = mo->get_workspace_in_bytes(i0, i1, out); | |||
| b = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, | |||
| megdnn_opr(), this); | |||
| const_cast<MatrixMul*>(this) | |||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||
| megdnn_opr()->execution_policy(); | |||
| transpose(i1, tparam.transposeB); | |||
| c = mo->get_workspace_in_bytes(i0, i1, out); | |||
| c = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, | |||
| megdnn_opr(), this); | |||
| const_cast<MatrixMul*>(this) | |||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||
| megdnn_opr()->execution_policy(); | |||
| transpose(i0, tparam.transposeA); | |||
| d = mo->get_workspace_in_bytes(i0, i1, out); | |||
| d = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, | |||
| megdnn_opr(), this); | |||
| const_cast<MatrixMul*>(this) | |||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||
| megdnn_opr()->execution_policy(); | |||
| } | |||
| MGB_FINALLY({ tparam = this->param(); }); | |||
| return std::max(std::max(a, b), std::max(c, d)); | |||
| @@ -100,6 +145,8 @@ void MatrixMul::scn_do_execute() { | |||
| MGB_TRY { | |||
| transpose(inp0.layout, tparam.transposeA); | |||
| transpose(inp1.layout, tparam.transposeB); | |||
| megdnn_opr()->execution_policy() = | |||
| m_cadidate_execution_policies[get_mask_from_matmul(tparam)]; | |||
| megdnn_opr()->exec(inp0, inp1, out, | |||
| intl::get_megdnn_workspace_from_var(output(1))); | |||
| } | |||
| @@ -134,7 +181,21 @@ MGB_IMPL_OPR_GRAD(MatrixMul) { | |||
| /* ================= BatchedMatrixMul ================= */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchedMatrixMul); | |||
| MEGDNN_OPR_INIT2(BatchedMatrixMul, "batched_matrix_mul") | |||
| BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param, | |||
| const ExecutionPolicy& policy, | |||
| const OperatorNodeConfig& config) | |||
| : Super{a->owner_graph(), config, "batched_matrix_mul", {a, b}} { | |||
| init_megdnn_opr(*this, param); | |||
| m_policy = policy; | |||
| add_input({a, b}); | |||
| } | |||
| SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param, | |||
| const ExecutionPolicy& policy, | |||
| const OperatorNodeConfig& config) { | |||
| return a.insert_single_output_opr<BatchedMatrixMul>(a.node(), b.node(), | |||
| param, policy, config); | |||
| } | |||
| void BatchedMatrixMul::add_input_layout_constraint() { | |||
| auto check = [](const TensorLayout& ly) { | |||
| @@ -191,13 +252,29 @@ size_t BatchedMatrixMul::get_workspace_size_bytes( | |||
| param ^= 1; | |||
| }; | |||
| MGB_TRY { | |||
| a = mo->get_workspace_in_bytes(i0, i1, out); | |||
| a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( | |||
| {i0, i1, out}, megdnn_opr(), this); | |||
| const_cast<BatchedMatrixMul*>(this) | |||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||
| megdnn_opr()->execution_policy(); | |||
| transpose(i0, tparam.transposeA); | |||
| b = mo->get_workspace_in_bytes(i0, i1, out); | |||
| b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( | |||
| {i0, i1, out}, megdnn_opr(), this); | |||
| const_cast<BatchedMatrixMul*>(this) | |||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||
| megdnn_opr()->execution_policy(); | |||
| transpose(i1, tparam.transposeB); | |||
| c = mo->get_workspace_in_bytes(i0, i1, out); | |||
| c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( | |||
| {i0, i1, out}, megdnn_opr(), this); | |||
| const_cast<BatchedMatrixMul*>(this) | |||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||
| megdnn_opr()->execution_policy(); | |||
| transpose(i0, tparam.transposeA); | |||
| d = mo->get_workspace_in_bytes(i0, i1, out); | |||
| d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( | |||
| {i0, i1, out}, megdnn_opr(), this); | |||
| const_cast<BatchedMatrixMul*>(this) | |||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||
| megdnn_opr()->execution_policy(); | |||
| } | |||
| MGB_FINALLY({ tparam = this->param(); }); | |||
| return std::max(std::max(a, b), std::max(c, d)); | |||
| @@ -220,6 +297,8 @@ void BatchedMatrixMul::scn_do_execute() { | |||
| MGB_TRY { | |||
| transpose(inp0.layout, tparam.transposeA); | |||
| transpose(inp1.layout, tparam.transposeB); | |||
| megdnn_opr()->execution_policy() = | |||
| m_cadidate_execution_policies[get_mask_from_matmul(tparam)]; | |||
| megdnn_opr()->exec(inp0, inp1, out, | |||
| intl::get_megdnn_workspace_from_var(output(1))); | |||
| } | |||
| @@ -14,12 +14,14 @@ decl_opr('BatchedMatrixMul', | |||
| 'performed and output shape is (n, a, c)') | |||
| decl_opr('MatrixMul', | |||
| pyname='matrix_mul_v2', | |||
| inputs=['opr0', 'opr1'], | |||
| params='MatrixMul', | |||
| desc='matrix multiplication', | |||
| version=2, has_out_dtype=True) | |||
| decl_opr('BatchedMatrixMul', | |||
| pyname='batched_matrix_mul_v2', | |||
| inputs=['opr0', 'opr1'], | |||
| params='MatrixMul', | |||
| desc='batched matrix multiplication: input shapes should be ' | |||
| @@ -28,6 +30,23 @@ decl_opr('BatchedMatrixMul', | |||
| 'performed and output shape is (n, a, c)', | |||
| version=2, has_out_dtype=True) | |||
| decl_opr('MatrixMul', | |||
| inputs=['opr0', 'opr1'], | |||
| params=[('param', 'MatrixMul'), | |||
| ('execution_polity', 'ExecutionPolicy')], | |||
| desc='matrix multiplication', | |||
| version=3, has_out_dtype=True) | |||
| decl_opr('BatchedMatrixMul', | |||
| inputs=['opr0', 'opr1'], | |||
| params=[('param', 'MatrixMul'), | |||
| ('execution_polity', 'ExecutionPolicy')], | |||
| desc='batched matrix multiplication: input shapes should be ' | |||
| '(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are ' | |||
| 'False); then :math:`n` independent matrix multiplications would be ' | |||
| 'performed and output shape is (n, a, c)', | |||
| version=3, has_out_dtype=True) | |||
| decl_opr('Dot', | |||
| inputs=['opr0', 'opr1'], | |||
| params='Empty', | |||
| @@ -10,7 +10,10 @@ | |||
| */ | |||
| #include "megbrain/opr/blas.h" | |||
| #include "megbrain/opr/param_defs.h" | |||
| #include "megbrain/serialization/sereg.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megdnn/oprs/linalg.h" | |||
| namespace mgb { | |||
| namespace serialization { | |||
| @@ -27,14 +30,70 @@ struct OprMaker<opr::SVD, 1> { | |||
| } | |||
| }; | |||
| template <class MegDNNConv = megdnn::MatrixMul> | |||
| struct MakeMatrixMulCaller { | |||
| template <typename Opr> | |||
| static VarNode* make(const cg::VarNodeArray& inputs, | |||
| const typename MegDNNConv::Param& param, | |||
| const megdnn::param::ExecutionPolicy& execution_policy, | |||
| const OperatorNodeConfig& config) { | |||
| if (inputs.size() == 2) { | |||
| return Opr::make(inputs[0], inputs[1], param, execution_policy, | |||
| config) | |||
| .node(); | |||
| } | |||
| return nullptr; | |||
| } | |||
| }; | |||
| template <class Opr, class Maker, class MegDNNMatrixMul> | |||
| struct MatrixMulLoadDumpImpl { | |||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||
| auto&& opr = opr_.cast_final_safe<Opr>(); | |||
| ctx.write_param<megdnn::param::MatrixMul>(opr.param()); | |||
| ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy()); | |||
| } | |||
| static VarNode* make(const cg::VarNodeArray& inputs, | |||
| const megdnn::param::MatrixMul& param, | |||
| const megdnn::param::ExecutionPolicy& execution_policy, | |||
| const OperatorNodeConfig& config) { | |||
| VarNode* ret = Maker::template make<Opr>(inputs, param, | |||
| execution_policy, config); | |||
| mgb_assert(ret); | |||
| return ret; | |||
| } | |||
| static cg::OperatorNodeBase* load(OprLoadContext& ctx, | |||
| const cg::VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| auto param = ctx.read_param<megdnn::param::MatrixMul>(); | |||
| auto execution_policy = | |||
| ctx.read_param<megdnn::param::ExecutionPolicy>(); | |||
| return make(inputs, param, execution_policy, config)->owner_opr(); | |||
| } | |||
| }; | |||
| template <> | |||
| struct OprLoadDumpImpl<opr::MatrixMul, 2> | |||
| : public MatrixMulLoadDumpImpl<opr::MatrixMul, | |||
| MakeMatrixMulCaller<megdnn::MatrixMul>, | |||
| megdnn::MatrixMul> {}; | |||
| template <> | |||
| struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2> | |||
| : public MatrixMulLoadDumpImpl< | |||
| opr::BatchedMatrixMul, | |||
| MakeMatrixMulCaller<megdnn::BatchedMatrixMul>, | |||
| megdnn::BatchedMatrixMul> {}; | |||
| } // namespace serialization | |||
| namespace opr { | |||
| using MatrixMulV2 = MatrixMul; | |||
| using BatchedMatrixMulV2 = BatchedMatrixMul; | |||
| MGB_SEREG_OPR(MatrixMulV2, 2); | |||
| MGB_SEREG_OPR(BatchedMatrixMulV2, 2); | |||
| using MatrixMulV3 = MatrixMul; | |||
| using BatchedMatrixMulV3 = BatchedMatrixMul; | |||
| MGB_SEREG_OPR(MatrixMulV3, 2); | |||
| MGB_SEREG_OPR(BatchedMatrixMulV3, 2); | |||
| MGB_SEREG_OPR(Dot, 2); | |||
| MGB_SEREG_OPR(MatrixInverse, 1); | |||
| MGB_SEREG_OPR(SVD, 1); | |||
| @@ -1636,6 +1636,5 @@ void BatchConvBiasForward::init_output_format() { | |||
| } | |||
| #undef IMPL_CONV | |||
| #undef MGB_FOREACH_FASTRUN_OPR | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -98,7 +98,7 @@ namespace serialization { | |||
| return nullptr; | |||
| } | |||
| }; | |||
| template<class Opr, class Maker0, class MegDNNConv, | |||
| class Maker1=MakeConvCallerEmpty<MegDNNConv>, | |||
| @@ -292,7 +292,7 @@ namespace serialization { | |||
| return nullptr; | |||
| } | |||
| }; | |||
| template<class Opr, class Maker0, class MegDNNConv, | |||
| class Maker1=MakeLocalShareCallerEmpty<MegDNNConv>, | |||
| class Maker2=MakeLocalShareCallerEmpty<MegDNNConv>, | |||
| @@ -251,6 +251,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { | |||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
| opr->owner_graph(), opr->comp_node(), | |||
| opr->execution_policy().workspace_limit); | |||
| m_megdnn_opr->execution_policy() = {}; | |||
| return APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||
| args..., workspace_limit, reproducible), | |||
| m_layouts); | |||
| @@ -12,6 +12,7 @@ | |||
| #pragma once | |||
| #include "megbrain/exception.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| #include "megbrain/tensor.h" | |||
| #include "megbrain/graph.h" | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| @@ -24,51 +25,58 @@ namespace opr { | |||
| /*! | |||
| * \brief matrix_mul(trans0(opr0), trans1(opr1)) | |||
| */ | |||
| MGB_DEFINE_OPR_CLASS(MatrixMul, | |||
| intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>) // { | |||
| public: | |||
| MatrixMul(VarNode *opr0, VarNode *opr1, | |||
| const Param ¶m, const OperatorNodeConfig &config); | |||
| static SymbolVar make(SymbolVar opr0, SymbolVar opr1, | |||
| const Param ¶m = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| private: | |||
| void add_input_layout_constraint() override; | |||
| void scn_do_execute() override; | |||
| void init_output_dtype() override; | |||
| size_t get_workspace_size_bytes( | |||
| const TensorShapeArray &input_shapes, | |||
| const TensorShapeArray &output_shapes) const override; | |||
| static bool check_layout(const TensorLayout &layout, int transpose); | |||
| MGB_DEFINE_OPR_CLASS(MatrixMul, intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>, | |||
| public mixin::AlgoChooserHelper) // { | |||
| public: | |||
| using AlgorithmInfo = megdnn::detail::Algorithm::Info; | |||
| MatrixMul(VarNode* opr0, VarNode* opr1, const Param& param, | |||
| const ExecutionPolicy& policy, const OperatorNodeConfig& config); | |||
| static SymbolVar make(SymbolVar opr0, SymbolVar opr1, | |||
| const Param& param = {}, | |||
| const ExecutionPolicy& policy = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| private: | |||
| void add_input_layout_constraint() override; | |||
| void scn_do_execute() override; | |||
| void init_output_dtype() override; | |||
| size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | |||
| const TensorShapeArray& output_shapes) | |||
| const override; | |||
| static bool check_layout(const TensorLayout& layout, int transpose); | |||
| //! store the policy of all transpose situations | |||
| megdnn::MatrixMul::ExecutionPolicy m_cadidate_execution_policies[4]; | |||
| }; | |||
| /*! | |||
| * \brief batched matrix multiplication on 3D inputs | |||
| */ | |||
| MGB_DEFINE_OPR_CLASS(BatchedMatrixMul, | |||
| intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>) // { | |||
| public: | |||
| BatchedMatrixMul(VarNode *opr0, VarNode *opr1, | |||
| const Param ¶m, const OperatorNodeConfig &config); | |||
| static SymbolVar make(SymbolVar opr0, SymbolVar opr1, | |||
| const Param ¶m = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| private: | |||
| void add_input_layout_constraint() override; | |||
| void init_output_dtype() override; | |||
| void scn_do_execute() override; | |||
| size_t get_workspace_size_bytes( | |||
| const TensorShapeArray &input_shapes, | |||
| const TensorShapeArray &output_shapes) const override; | |||
| static bool check_layout(const TensorLayout &layout, bool transpose); | |||
| intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>, | |||
| public mixin::AlgoChooserHelper) // { | |||
| public: | |||
| using AlgorithmInfo = megdnn::detail::Algorithm::Info; | |||
| BatchedMatrixMul(VarNode* opr0, VarNode* opr1, const Param& param, | |||
| const ExecutionPolicy& policy, | |||
| const OperatorNodeConfig& config); | |||
| static SymbolVar make(SymbolVar opr0, SymbolVar opr1, | |||
| const Param& param = {}, | |||
| const ExecutionPolicy& policy = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| private: | |||
| void add_input_layout_constraint() override; | |||
| void init_output_dtype() override; | |||
| void scn_do_execute() override; | |||
| size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | |||
| const TensorShapeArray& output_shapes) | |||
| const override; | |||
| static bool check_layout(const TensorLayout& layout, bool transpose); | |||
| //! store the policy of all transpose situations | |||
| megdnn::BatchedMatrixMul::ExecutionPolicy m_cadidate_execution_policies[4]; | |||
| }; | |||
| /*! | |||
| @@ -109,4 +117,3 @@ MGB_DEFINE_OPR_CLASS(SVD, intl::MegDNNOprWrapperFwd<megdnn::SVD>) // { | |||
| } // mgb | |||
| // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -14,6 +14,7 @@ | |||
| #include "megbrain/opr/search_policy/profiler.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/blas.h" | |||
| template <class MegDNNOpr> | |||
| struct MegDNNOpr2MGBOpr; | |||
| @@ -18,26 +18,31 @@ | |||
| #include "megbrain/comp_node.h" | |||
| #include "megdnn/basic_types.h" | |||
| #include "megdnn/oprs/linalg.h" | |||
| #include "megdnn/oprs/nn.h" | |||
| namespace mgb { | |||
| namespace opr { | |||
| #define MGB_FOREACH_FASTRUN_OPR(cb) \ | |||
| cb(ConvolutionForward); \ | |||
| cb(ConvBiasForward); \ | |||
| cb(ConvolutionBackwardData); \ | |||
| cb(ConvolutionBackwardFilter); \ | |||
| cb(Convolution3DForward); \ | |||
| cb(Convolution3DBackwardData); \ | |||
| cb(Convolution3DBackwardFilter); \ | |||
| cb(LocalShareForward); \ | |||
| cb(LocalShareBackwardData); \ | |||
| cb(LocalShareBackwardFilter); \ | |||
| cb(DeformableConvForward); \ | |||
| cb(DeformableConvBackwardFilter); \ | |||
| cb(DeformableConvBackwardData); \ | |||
| cb(BatchConvBiasForward); | |||
| // clang-format off | |||
| #define MGB_FOREACH_FASTRUN_OPR(cb) \ | |||
| cb(ConvolutionForward) \ | |||
| cb(ConvBiasForward) \ | |||
| cb(ConvolutionBackwardData) \ | |||
| cb(ConvolutionBackwardFilter) \ | |||
| cb(Convolution3DForward) \ | |||
| cb(Convolution3DBackwardData) \ | |||
| cb(Convolution3DBackwardFilter) \ | |||
| cb(LocalShareForward) \ | |||
| cb(LocalShareBackwardData) \ | |||
| cb(LocalShareBackwardFilter) \ | |||
| cb(DeformableConvForward) \ | |||
| cb(DeformableConvBackwardFilter) \ | |||
| cb(DeformableConvBackwardData) \ | |||
| cb(BatchConvBiasForward) \ | |||
| cb(MatrixMul) \ | |||
| cb(BatchedMatrixMul) | |||
| // clang-format on | |||
| template <typename Opr> | |||
| struct OprArityTrait; | |||
| @@ -67,6 +72,8 @@ INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1); | |||
| INST_ARITY(megdnn::BatchConvBiasForward, 4, 1); | |||
| INST_ARITY(megdnn::ConvBias, 4, 1); | |||
| INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3); | |||
| INST_ARITY(megdnn::MatrixMul, 2, 1); | |||
| INST_ARITY(megdnn::BatchedMatrixMul, 2, 1); | |||
| #undef INST_ARITY | |||
| @@ -269,7 +269,7 @@ void run_trans_inp_test_case(bool trans_a, bool trans_b) { | |||
| if (DTypeTrait<dt_dst>::enumv == DTypeEnum::Int16) { | |||
| config.output_dtype(dtype::Int16()); | |||
| } | |||
| auto z = opr::MatrixMul::make(x, y, {}, config); | |||
| auto z = opr::MatrixMul::make(x, y, {}, {}, config); | |||
| HostTensorND host_z; | |||
| auto func = graph->compile({make_callback_copy(z, host_z)}); | |||
| @@ -359,7 +359,7 @@ void run_bgemm_trans_inp_test_case(bool trans_a, bool trans_b) { | |||
| trans_a ? (x = opr::Dimshuffle::make(x, {0, 2, 1})) : 0; | |||
| trans_b ? (y = opr::Dimshuffle::make(y, {0, 2, 1})) : 0; | |||
| auto z = opr::BatchedMatrixMul::make(x, y, {}, OperatorNodeConfig{}); | |||
| auto z = opr::BatchedMatrixMul::make(x, y, {}, {}, OperatorNodeConfig{}); | |||
| HostTensorND host_z; | |||
| auto func = graph->compile({make_callback_copy(z, host_z)}); | |||
| auto run = [&](size_t B, size_t M, size_t K, size_t N) { | |||
| @@ -420,6 +420,43 @@ TEST(TestOprBlas, MatrixMul_TT) { | |||
| run_sgemm_test(true, true); | |||
| } | |||
| TEST(TestOprDNN, MatrixMulExePolicy) { | |||
| using Param = opr::MatrixMul::Param; | |||
| Param param; | |||
| using Policy = opr::MatrixMul::ExecutionPolicy; | |||
| using S = Policy::Strategy; | |||
| auto cn = CompNode::load("cpux"); | |||
| #if MGB_ENABLE_FASTRUN | |||
| for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, | |||
| S::PROFILE_HEURISTIC}) { | |||
| #else | |||
| for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) { | |||
| #endif | |||
| auto graph = ComputingGraph::make(); | |||
| HostTensorGenerator<> gen; | |||
| auto mkvar = [&](const char* name, const TensorShape& shp) { | |||
| return opr::Host2DeviceCopy::make(*graph, gen(shp), cn) | |||
| .rename(name); | |||
| }; | |||
| auto A = mkvar("A", {32, 64}); | |||
| auto B = mkvar("B", {64, 32}); | |||
| Policy policy; | |||
| policy.strategy = strategy; | |||
| auto C = opr::MatrixMul::make(A, B, param, policy); | |||
| HostTensorND host_c; | |||
| auto func = graph->compile({make_callback_copy(C, host_c)}); | |||
| func->execute(); | |||
| } | |||
| } | |||
| TEST(TestOprBlas, BatchedMatrixMulFp32_NN) { | |||
| run_batched_sgemm_test(false, false); | |||
| } | |||