GitOrigin-RevId: a48ea9bff6
tags/v1.3.0
| @@ -16,6 +16,7 @@ | |||||
| #include "megbrain/opr/dnn/batch_norm.h" | #include "megbrain/opr/dnn/batch_norm.h" | ||||
| #include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | #include "megbrain/opr/search_policy/algo_chooser_helper.h" | ||||
| #include "megbrain/opr/search_policy/profiler.h" | |||||
| #include "megbrain/utils/shared_set.h" | #include "megbrain/utils/shared_set.h" | ||||
| #include "megbrain/serialization/opr_shallow_copy.h" | #include "megbrain/serialization/opr_shallow_copy.h" | ||||
| #include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
| @@ -149,15 +150,6 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr, | |||||
| } // anonymous namespace | } // anonymous namespace | ||||
| #define MGB_FOREACH_FASTRUN_OPR(cb) \ | |||||
| cb(ConvolutionForward), cb(ConvBiasForward), cb(ConvolutionBackwardData), \ | |||||
| cb(ConvolutionBackwardFilter), cb(Convolution3DForward), \ | |||||
| cb(Convolution3DBackwardData), cb(Convolution3DBackwardFilter), \ | |||||
| cb(LocalShareForward), cb(LocalShareBackwardData), \ | |||||
| cb(LocalShareBackwardFilter), cb(DeformableConvForward), \ | |||||
| cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \ | |||||
| cb(BatchConvBiasForward), | |||||
| void gopt::modify_opr_algo_strategy_inplace( | void gopt::modify_opr_algo_strategy_inplace( | ||||
| const VarNodeArrayView& dest_vars, | const VarNodeArrayView& dest_vars, | ||||
| opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | ||||
| @@ -171,7 +163,7 @@ void gopt::modify_opr_algo_strategy_inplace( | |||||
| modifiers = { | modifiers = { | ||||
| #define CONV(t) \ | #define CONV(t) \ | ||||
| {opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \ | {opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \ | ||||
| std::placeholders::_1, strategy)} | |||||
| std::placeholders::_1, strategy)}, | |||||
| MGB_FOREACH_FASTRUN_OPR(CONV) | MGB_FOREACH_FASTRUN_OPR(CONV) | ||||
| #undef CONV | #undef CONV | ||||
| }; | }; | ||||
| @@ -209,7 +201,7 @@ void gopt::set_opr_algo_workspace_limit_inplace( | |||||
| static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> | static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> | ||||
| modifiers = { | modifiers = { | ||||
| #define CONV(t) \ | #define CONV(t) \ | ||||
| {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>} | |||||
| {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>}, | |||||
| MGB_FOREACH_FASTRUN_OPR(CONV) | MGB_FOREACH_FASTRUN_OPR(CONV) | ||||
| #undef CONV | #undef CONV | ||||
| }; | }; | ||||
| @@ -226,7 +218,6 @@ void gopt::set_opr_algo_workspace_limit_inplace( | |||||
| dep_iter.add(i); | dep_iter.add(i); | ||||
| } | } | ||||
| } | } | ||||
| #undef MGB_FOREACH_FASTRUN_OPR | |||||
| /* ================ ParamRedistributePass ================ */ | /* ================ ParamRedistributePass ================ */ | ||||
| const char* ParamRedistributePass::name() const { | const char* ParamRedistributePass::name() const { | ||||
| @@ -790,8 +781,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||||
| new_inp[1]->name().c_str(), | new_inp[1]->name().c_str(), | ||||
| new_inp[1]->owner_opr()->name().c_str()); | new_inp[1]->owner_opr()->name().c_str()); | ||||
| auto new_deconv_opr = opr::ConvolutionBackwardData::make( | auto new_deconv_opr = opr::ConvolutionBackwardData::make( | ||||
| new_inp[0], new_inp[1], new_param, deconv_opr.execution_policy(), | |||||
| deconv_opr.config()); | |||||
| new_inp[0], new_inp[1], new_param, | |||||
| deconv_opr.execution_policy(), deconv_opr.config()); | |||||
| return new_deconv_opr.node()->owner_opr(); | return new_deconv_opr.node()->owner_opr(); | ||||
| }; | }; | ||||
| @@ -813,20 +804,20 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||||
| new_inp[1]->owner_opr()->name().c_str()); | new_inp[1]->owner_opr()->name().c_str()); | ||||
| if(opr->input().size() == 2) { | if(opr->input().size() == 2) { | ||||
| auto new_conv_opr = opr::ConvBias::make( | auto new_conv_opr = opr::ConvBias::make( | ||||
| new_inp[0], new_inp[1], new_param, convbias_opr.execution_policy(), | |||||
| convbias_opr.config()); | |||||
| new_inp[0], new_inp[1], new_param, | |||||
| convbias_opr.execution_policy(), convbias_opr.config()); | |||||
| return new_conv_opr.node()->owner_opr(); | return new_conv_opr.node()->owner_opr(); | ||||
| } else if(opr->input().size() == 3) { | } else if(opr->input().size() == 3) { | ||||
| auto new_conv_opr = opr::ConvBias::make( | auto new_conv_opr = opr::ConvBias::make( | ||||
| new_inp[0], new_inp[1], new_inp[2], new_param, convbias_opr.execution_policy(), | |||||
| convbias_opr.config()); | |||||
| new_inp[0], new_inp[1], new_inp[2], new_param, | |||||
| convbias_opr.execution_policy(), convbias_opr.config()); | |||||
| return new_conv_opr.node()->owner_opr(); | return new_conv_opr.node()->owner_opr(); | ||||
| } else { | } else { | ||||
| mgb_assert(opr->input().size() == 4, "invalid input size %zu", | mgb_assert(opr->input().size() == 4, "invalid input size %zu", | ||||
| opr->input().size()); | opr->input().size()); | ||||
| auto new_conv_opr = opr::ConvBias::make( | auto new_conv_opr = opr::ConvBias::make( | ||||
| new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param, convbias_opr.execution_policy(), | |||||
| convbias_opr.config()); | |||||
| new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param, | |||||
| convbias_opr.execution_policy(), convbias_opr.config()); | |||||
| return new_conv_opr.node()->owner_opr(); | return new_conv_opr.node()->owner_opr(); | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -841,7 +832,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||||
| megdnn::param::MatrixMul::ComputeMode::FLOAT32; | megdnn::param::MatrixMul::ComputeMode::FLOAT32; | ||||
| } | } | ||||
| auto new_matmul_opr = opr::MatrixMul::make( | auto new_matmul_opr = opr::MatrixMul::make( | ||||
| new_inp[0], new_inp[1], new_param, matmul_opr.config()); | |||||
| new_inp[0], new_inp[1], new_param, | |||||
| matmul_opr.execution_policy(), matmul_opr.config()); | |||||
| return new_matmul_opr.node()->owner_opr(); | return new_matmul_opr.node()->owner_opr(); | ||||
| }; | }; | ||||
| @@ -864,7 +856,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||||
| new_inp[1]->name().c_str(), | new_inp[1]->name().c_str(), | ||||
| new_inp[1]->owner_opr()->name().c_str()); | new_inp[1]->owner_opr()->name().c_str()); | ||||
| auto new_matmul_opr = opr::BatchedMatrixMul::make( | auto new_matmul_opr = opr::BatchedMatrixMul::make( | ||||
| new_inp[0], new_inp[1], new_param, matmul_opr.config()); | |||||
| new_inp[0], new_inp[1], new_param, | |||||
| matmul_opr.execution_policy(), matmul_opr.config()); | |||||
| return new_matmul_opr.node()->owner_opr(); | return new_matmul_opr.node()->owner_opr(); | ||||
| }; | }; | ||||
| @@ -915,8 +908,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||||
| new_mat->owner_opr()->input(0)->dtype() == dtype::Float32()) | new_mat->owner_opr()->input(0)->dtype() == dtype::Float32()) | ||||
| new_mat = new_mat->owner_opr()->input(0); | new_mat = new_mat->owner_opr()->input(0); | ||||
| else | else | ||||
| new_mat = | |||||
| opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node(); | |||||
| new_mat = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}) | |||||
| .node(); | |||||
| } | } | ||||
| SymbolVar new_warp; | SymbolVar new_warp; | ||||
| if (new_inp.size() == 3) { | if (new_inp.size() == 3) { | ||||
| @@ -944,8 +937,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||||
| new_map->owner_opr()->input(0)->dtype() == dtype::Float32()) | new_map->owner_opr()->input(0)->dtype() == dtype::Float32()) | ||||
| new_map = new_map->owner_opr()->input(0); | new_map = new_map->owner_opr()->input(0); | ||||
| else | else | ||||
| new_map = | |||||
| opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node(); | |||||
| new_map = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}) | |||||
| .node(); | |||||
| } | } | ||||
| SymbolVar new_remap; | SymbolVar new_remap; | ||||
| @@ -18,15 +18,41 @@ | |||||
| #include "megbrain/opr/tensor_gen.h" | #include "megbrain/opr/tensor_gen.h" | ||||
| #include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
| #include "megbrain/opr/search_policy/algo_chooser.h" | |||||
| #include "megbrain/opr/search_policy/profiler.h" | |||||
| #include "./internal/megdnn_opr_wrapper.inl" | #include "./internal/megdnn_opr_wrapper.inl" | ||||
| #include "./search_policy/workspace_need_limit_getter.inl" | |||||
| #include "megdnn/oprs/linalg.h" | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace opr; | using namespace opr; | ||||
| namespace { | |||||
| //! Fold the two transpose flags of \p param into a single index: | |||||
| //! transposeA contributes 1 and transposeB contributes 2. | |||||
| int get_mask_from_matmul(const megdnn::param::MatrixMul& param) { | |||||
| const int a_bit = static_cast<int>(param.transposeA); | |||||
| const int b_bit = static_cast<int>(param.transposeB); | |||||
| return a_bit + 2 * b_bit; | |||||
| } | |||||
| } | |||||
| /* ================= MatrixMul ================= */ | /* ================= MatrixMul ================= */ | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixMul); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixMul); | ||||
| MEGDNN_OPR_INIT2(MatrixMul, "matrix_mul") | |||||
//! Construct a MatrixMul operator node.
//! The base class receives the graph that owns \p a, the node \p config, the
//! name "matrix_mul" and the input list {a, b}.
//! \param a first input var
//! \param b second input var
//! \param param matmul parameters handed to init_megdnn_opr
//! \param policy execution policy, copied into m_policy
//! \param config operator node configuration
| MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param, | |||||
| const ExecutionPolicy& policy, | |||||
| const OperatorNodeConfig& config) | |||||
| : Super{a->owner_graph(), config, "matrix_mul", {a, b}} { | |||||
| init_megdnn_opr(*this, param); | |||||
//! remember the caller-provided policy on this node
| m_policy = policy; | |||||
| add_input({a, b}); | |||||
| } | |||||
| //! Insert a MatrixMul operator taking \p a and \p b into the graph and | |||||
| //! return its single output; \p param, \p policy and \p config are | |||||
| //! forwarded to the constructor. | |||||
| SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param, | |||||
| const ExecutionPolicy& policy, | |||||
| const OperatorNodeConfig& config) { | |||||
| auto lhs = a.node(); | |||||
| auto rhs = b.node(); | |||||
| return a.insert_single_output_opr<MatrixMul>(lhs, rhs, param, policy, | |||||
| config); | |||||
| } | |||||
| void MatrixMul::init_output_dtype() { | void MatrixMul::init_output_dtype() { | ||||
| DType output_dtype = config().output_dtype(); | DType output_dtype = config().output_dtype(); | ||||
| @@ -72,13 +98,32 @@ size_t MatrixMul::get_workspace_size_bytes( | |||||
| param ^= 1; | param ^= 1; | ||||
| }; | }; | ||||
| MGB_TRY { | MGB_TRY { | ||||
| a = mo->get_workspace_in_bytes(i0, i1, out); | |||||
| a = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, | |||||
| megdnn_opr(), this); | |||||
| //! Here we only want to save the execution policy obtained from setup_algo; | |||||
| //! changing the declaration of get_workspace_in_bytes instead would cause | |||||
| //! many changes, hence the const_cast below. | |||||
| const_cast<MatrixMul*>(this) | |||||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||||
| megdnn_opr()->execution_policy(); | |||||
| transpose(i0, tparam.transposeA); | transpose(i0, tparam.transposeA); | ||||
| b = mo->get_workspace_in_bytes(i0, i1, out); | |||||
| b = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, | |||||
| megdnn_opr(), this); | |||||
| const_cast<MatrixMul*>(this) | |||||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||||
| megdnn_opr()->execution_policy(); | |||||
| transpose(i1, tparam.transposeB); | transpose(i1, tparam.transposeB); | ||||
| c = mo->get_workspace_in_bytes(i0, i1, out); | |||||
| c = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, | |||||
| megdnn_opr(), this); | |||||
| const_cast<MatrixMul*>(this) | |||||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||||
| megdnn_opr()->execution_policy(); | |||||
| transpose(i0, tparam.transposeA); | transpose(i0, tparam.transposeA); | ||||
| d = mo->get_workspace_in_bytes(i0, i1, out); | |||||
| d = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, | |||||
| megdnn_opr(), this); | |||||
| const_cast<MatrixMul*>(this) | |||||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||||
| megdnn_opr()->execution_policy(); | |||||
| } | } | ||||
| MGB_FINALLY({ tparam = this->param(); }); | MGB_FINALLY({ tparam = this->param(); }); | ||||
| return std::max(std::max(a, b), std::max(c, d)); | return std::max(std::max(a, b), std::max(c, d)); | ||||
| @@ -100,6 +145,8 @@ void MatrixMul::scn_do_execute() { | |||||
| MGB_TRY { | MGB_TRY { | ||||
| transpose(inp0.layout, tparam.transposeA); | transpose(inp0.layout, tparam.transposeA); | ||||
| transpose(inp1.layout, tparam.transposeB); | transpose(inp1.layout, tparam.transposeB); | ||||
| megdnn_opr()->execution_policy() = | |||||
| m_cadidate_execution_policies[get_mask_from_matmul(tparam)]; | |||||
| megdnn_opr()->exec(inp0, inp1, out, | megdnn_opr()->exec(inp0, inp1, out, | ||||
| intl::get_megdnn_workspace_from_var(output(1))); | intl::get_megdnn_workspace_from_var(output(1))); | ||||
| } | } | ||||
| @@ -134,7 +181,21 @@ MGB_IMPL_OPR_GRAD(MatrixMul) { | |||||
| /* ================= BatchedMatrixMul ================= */ | /* ================= BatchedMatrixMul ================= */ | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchedMatrixMul); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchedMatrixMul); | ||||
| MEGDNN_OPR_INIT2(BatchedMatrixMul, "batched_matrix_mul") | |||||
//! Construct a BatchedMatrixMul operator node.
//! The base class receives the graph that owns \p a, the node \p config, the
//! name "batched_matrix_mul" and the input list {a, b}.
//! \param a first input var
//! \param b second input var
//! \param param matmul parameters handed to init_megdnn_opr
//! \param policy execution policy, copied into m_policy
//! \param config operator node configuration
| BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param, | |||||
| const ExecutionPolicy& policy, | |||||
| const OperatorNodeConfig& config) | |||||
| : Super{a->owner_graph(), config, "batched_matrix_mul", {a, b}} { | |||||
| init_megdnn_opr(*this, param); | |||||
//! remember the caller-provided policy on this node
| m_policy = policy; | |||||
| add_input({a, b}); | |||||
| } | |||||
| //! Insert a BatchedMatrixMul operator taking \p a and \p b into the graph | |||||
| //! and return its single output; \p param, \p policy and \p config are | |||||
| //! forwarded to the constructor. | |||||
| SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param, | |||||
| const ExecutionPolicy& policy, | |||||
| const OperatorNodeConfig& config) { | |||||
| auto lhs = a.node(); | |||||
| auto rhs = b.node(); | |||||
| return a.insert_single_output_opr<BatchedMatrixMul>(lhs, rhs, param, | |||||
| policy, config); | |||||
| } | |||||
| void BatchedMatrixMul::add_input_layout_constraint() { | void BatchedMatrixMul::add_input_layout_constraint() { | ||||
| auto check = [](const TensorLayout& ly) { | auto check = [](const TensorLayout& ly) { | ||||
| @@ -191,13 +252,29 @@ size_t BatchedMatrixMul::get_workspace_size_bytes( | |||||
| param ^= 1; | param ^= 1; | ||||
| }; | }; | ||||
| MGB_TRY { | MGB_TRY { | ||||
| a = mo->get_workspace_in_bytes(i0, i1, out); | |||||
| a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( | |||||
| {i0, i1, out}, megdnn_opr(), this); | |||||
| const_cast<BatchedMatrixMul*>(this) | |||||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||||
| megdnn_opr()->execution_policy(); | |||||
| transpose(i0, tparam.transposeA); | transpose(i0, tparam.transposeA); | ||||
| b = mo->get_workspace_in_bytes(i0, i1, out); | |||||
| b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( | |||||
| {i0, i1, out}, megdnn_opr(), this); | |||||
| const_cast<BatchedMatrixMul*>(this) | |||||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||||
| megdnn_opr()->execution_policy(); | |||||
| transpose(i1, tparam.transposeB); | transpose(i1, tparam.transposeB); | ||||
| c = mo->get_workspace_in_bytes(i0, i1, out); | |||||
| c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( | |||||
| {i0, i1, out}, megdnn_opr(), this); | |||||
| const_cast<BatchedMatrixMul*>(this) | |||||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||||
| megdnn_opr()->execution_policy(); | |||||
| transpose(i0, tparam.transposeA); | transpose(i0, tparam.transposeA); | ||||
| d = mo->get_workspace_in_bytes(i0, i1, out); | |||||
| d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( | |||||
| {i0, i1, out}, megdnn_opr(), this); | |||||
| const_cast<BatchedMatrixMul*>(this) | |||||
| ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = | |||||
| megdnn_opr()->execution_policy(); | |||||
| } | } | ||||
| MGB_FINALLY({ tparam = this->param(); }); | MGB_FINALLY({ tparam = this->param(); }); | ||||
| return std::max(std::max(a, b), std::max(c, d)); | return std::max(std::max(a, b), std::max(c, d)); | ||||
| @@ -220,6 +297,8 @@ void BatchedMatrixMul::scn_do_execute() { | |||||
| MGB_TRY { | MGB_TRY { | ||||
| transpose(inp0.layout, tparam.transposeA); | transpose(inp0.layout, tparam.transposeA); | ||||
| transpose(inp1.layout, tparam.transposeB); | transpose(inp1.layout, tparam.transposeB); | ||||
| megdnn_opr()->execution_policy() = | |||||
| m_cadidate_execution_policies[get_mask_from_matmul(tparam)]; | |||||
| megdnn_opr()->exec(inp0, inp1, out, | megdnn_opr()->exec(inp0, inp1, out, | ||||
| intl::get_megdnn_workspace_from_var(output(1))); | intl::get_megdnn_workspace_from_var(output(1))); | ||||
| } | } | ||||
| @@ -14,12 +14,14 @@ decl_opr('BatchedMatrixMul', | |||||
| 'performed and output shape is (n, a, c)') | 'performed and output shape is (n, a, c)') | ||||
| decl_opr('MatrixMul', | decl_opr('MatrixMul', | ||||
| pyname='matrix_mul_v2', | |||||
| inputs=['opr0', 'opr1'], | inputs=['opr0', 'opr1'], | ||||
| params='MatrixMul', | params='MatrixMul', | ||||
| desc='matrix multiplication', | desc='matrix multiplication', | ||||
| version=2, has_out_dtype=True) | version=2, has_out_dtype=True) | ||||
| decl_opr('BatchedMatrixMul', | decl_opr('BatchedMatrixMul', | ||||
| pyname='batched_matrix_mul_v2', | |||||
| inputs=['opr0', 'opr1'], | inputs=['opr0', 'opr1'], | ||||
| params='MatrixMul', | params='MatrixMul', | ||||
| desc='batched matrix multiplication: input shapes should be ' | desc='batched matrix multiplication: input shapes should be ' | ||||
| @@ -28,6 +30,23 @@ decl_opr('BatchedMatrixMul', | |||||
| 'performed and output shape is (n, a, c)', | 'performed and output shape is (n, a, c)', | ||||
| version=2, has_out_dtype=True) | version=2, has_out_dtype=True) | ||||
| decl_opr('MatrixMul', | |||||
| inputs=['opr0', 'opr1'], | |||||
| params=[('param', 'MatrixMul'), | |||||
| ('execution_polity', 'ExecutionPolicy')], | |||||
| desc='matrix multiplication', | |||||
| version=3, has_out_dtype=True) | |||||
| decl_opr('BatchedMatrixMul', | |||||
| inputs=['opr0', 'opr1'], | |||||
| params=[('param', 'MatrixMul'), | |||||
| ('execution_polity', 'ExecutionPolicy')], | |||||
| desc='batched matrix multiplication: input shapes should be ' | |||||
| '(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are ' | |||||
| 'False); then :math:`n` independent matrix multiplications would be ' | |||||
| 'performed and output shape is (n, a, c)', | |||||
| version=3, has_out_dtype=True) | |||||
| decl_opr('Dot', | decl_opr('Dot', | ||||
| inputs=['opr0', 'opr1'], | inputs=['opr0', 'opr1'], | ||||
| params='Empty', | params='Empty', | ||||
| @@ -10,7 +10,10 @@ | |||||
| */ | */ | ||||
| #include "megbrain/opr/blas.h" | #include "megbrain/opr/blas.h" | ||||
| #include "megbrain/opr/param_defs.h" | |||||
| #include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
| #include "megdnn/opr_param_defs.h" | |||||
| #include "megdnn/oprs/linalg.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace serialization { | namespace serialization { | ||||
| @@ -27,14 +30,70 @@ struct OprMaker<opr::SVD, 1> { | |||||
| } | } | ||||
| }; | }; | ||||
| //! Maker that builds \p Opr from exactly two input vars. | |||||
| template <class MegDNNConv = megdnn::MatrixMul> | |||||
| struct MakeMatrixMulCaller { | |||||
| //! \return output var of the newly made operator, or nullptr when the | |||||
| //! number of inputs is not 2 | |||||
| template <typename Opr> | |||||
| static VarNode* make(const cg::VarNodeArray& inputs, | |||||
| const typename MegDNNConv::Param& param, | |||||
| const megdnn::param::ExecutionPolicy& execution_policy, | |||||
| const OperatorNodeConfig& config) { | |||||
| if (inputs.size() != 2) { | |||||
| return nullptr; | |||||
| } | |||||
| auto made = Opr::make(inputs[0], inputs[1], param, execution_policy, | |||||
| config); | |||||
| return made.node(); | |||||
| } | |||||
| }; | |||||
| template <class Opr, class Maker, class MegDNNMatrixMul> | |||||
| struct MatrixMulLoadDumpImpl { | |||||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||||
| auto&& opr = opr_.cast_final_safe<Opr>(); | |||||
| ctx.write_param<megdnn::param::MatrixMul>(opr.param()); | |||||
| ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy()); | |||||
| } | |||||
| static VarNode* make(const cg::VarNodeArray& inputs, | |||||
| const megdnn::param::MatrixMul& param, | |||||
| const megdnn::param::ExecutionPolicy& execution_policy, | |||||
| const OperatorNodeConfig& config) { | |||||
| VarNode* ret = Maker::template make<Opr>(inputs, param, | |||||
| execution_policy, config); | |||||
| mgb_assert(ret); | |||||
| return ret; | |||||
| } | |||||
| static cg::OperatorNodeBase* load(OprLoadContext& ctx, | |||||
| const cg::VarNodeArray& inputs, | |||||
| const OperatorNodeConfig& config) { | |||||
| auto param = ctx.read_param<megdnn::param::MatrixMul>(); | |||||
| auto execution_policy = | |||||
| ctx.read_param<megdnn::param::ExecutionPolicy>(); | |||||
| return make(inputs, param, execution_policy, config)->owner_opr(); | |||||
| } | |||||
| }; | |||||
//! Load/dump specialization for opr::MatrixMul: inherits dump(), make() and
//! load() from MatrixMulLoadDumpImpl, using MakeMatrixMulCaller as the maker.
| template <> | |||||
| struct OprLoadDumpImpl<opr::MatrixMul, 2> | |||||
| : public MatrixMulLoadDumpImpl<opr::MatrixMul, | |||||
| MakeMatrixMulCaller<megdnn::MatrixMul>, | |||||
| megdnn::MatrixMul> {}; | |||||
//! Load/dump specialization for opr::BatchedMatrixMul: inherits dump(), make()
//! and load() from MatrixMulLoadDumpImpl, using MakeMatrixMulCaller as the maker.
| template <> | |||||
| struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2> | |||||
| : public MatrixMulLoadDumpImpl< | |||||
| opr::BatchedMatrixMul, | |||||
| MakeMatrixMulCaller<megdnn::BatchedMatrixMul>, | |||||
| megdnn::BatchedMatrixMul> {}; | |||||
| } // namespace serialization | } // namespace serialization | ||||
| namespace opr { | namespace opr { | ||||
| using MatrixMulV2 = MatrixMul; | |||||
| using BatchedMatrixMulV2 = BatchedMatrixMul; | |||||
| MGB_SEREG_OPR(MatrixMulV2, 2); | |||||
| MGB_SEREG_OPR(BatchedMatrixMulV2, 2); | |||||
| using MatrixMulV3 = MatrixMul; | |||||
| using BatchedMatrixMulV3 = BatchedMatrixMul; | |||||
| MGB_SEREG_OPR(MatrixMulV3, 2); | |||||
| MGB_SEREG_OPR(BatchedMatrixMulV3, 2); | |||||
| MGB_SEREG_OPR(Dot, 2); | MGB_SEREG_OPR(Dot, 2); | ||||
| MGB_SEREG_OPR(MatrixInverse, 1); | MGB_SEREG_OPR(MatrixInverse, 1); | ||||
| MGB_SEREG_OPR(SVD, 1); | MGB_SEREG_OPR(SVD, 1); | ||||
| @@ -1636,6 +1636,5 @@ void BatchConvBiasForward::init_output_format() { | |||||
| } | } | ||||
| #undef IMPL_CONV | #undef IMPL_CONV | ||||
| #undef MGB_FOREACH_FASTRUN_OPR | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -98,7 +98,7 @@ namespace serialization { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| }; | }; | ||||
| template<class Opr, class Maker0, class MegDNNConv, | template<class Opr, class Maker0, class MegDNNConv, | ||||
| class Maker1=MakeConvCallerEmpty<MegDNNConv>, | class Maker1=MakeConvCallerEmpty<MegDNNConv>, | ||||
| @@ -292,7 +292,7 @@ namespace serialization { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| }; | }; | ||||
| template<class Opr, class Maker0, class MegDNNConv, | template<class Opr, class Maker0, class MegDNNConv, | ||||
| class Maker1=MakeLocalShareCallerEmpty<MegDNNConv>, | class Maker1=MakeLocalShareCallerEmpty<MegDNNConv>, | ||||
| class Maker2=MakeLocalShareCallerEmpty<MegDNNConv>, | class Maker2=MakeLocalShareCallerEmpty<MegDNNConv>, | ||||
| @@ -251,6 +251,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { | |||||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
| opr->owner_graph(), opr->comp_node(), | opr->owner_graph(), opr->comp_node(), | ||||
| opr->execution_policy().workspace_limit); | opr->execution_policy().workspace_limit); | ||||
| m_megdnn_opr->execution_policy() = {}; | |||||
| return APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | return APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | ||||
| args..., workspace_limit, reproducible), | args..., workspace_limit, reproducible), | ||||
| m_layouts); | m_layouts); | ||||
| @@ -12,6 +12,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megbrain/exception.h" | #include "megbrain/exception.h" | ||||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||||
| #include "megbrain/tensor.h" | #include "megbrain/tensor.h" | ||||
| #include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | ||||
| @@ -24,51 +25,58 @@ namespace opr { | |||||
| /*! | /*! | ||||
| * \brief matrix_mul(trans0(opr0), trans1(opr1)) | * \brief matrix_mul(trans0(opr0), trans1(opr1)) | ||||
| */ | */ | ||||
| MGB_DEFINE_OPR_CLASS(MatrixMul, | |||||
| intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>) // { | |||||
| public: | |||||
| MatrixMul(VarNode *opr0, VarNode *opr1, | |||||
| const Param ¶m, const OperatorNodeConfig &config); | |||||
| static SymbolVar make(SymbolVar opr0, SymbolVar opr1, | |||||
| const Param ¶m = {}, | |||||
| const OperatorNodeConfig &config = {}); | |||||
| private: | |||||
| void add_input_layout_constraint() override; | |||||
| void scn_do_execute() override; | |||||
| void init_output_dtype() override; | |||||
| size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray &input_shapes, | |||||
| const TensorShapeArray &output_shapes) const override; | |||||
| static bool check_layout(const TensorLayout &layout, int transpose); | |||||
| MGB_DEFINE_OPR_CLASS(MatrixMul, intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>, | |||||
| public mixin::AlgoChooserHelper) // { | |||||
| public: | |||||
| using AlgorithmInfo = megdnn::detail::Algorithm::Info; | |||||
| MatrixMul(VarNode* opr0, VarNode* opr1, const Param& param, | |||||
| const ExecutionPolicy& policy, const OperatorNodeConfig& config); | |||||
| static SymbolVar make(SymbolVar opr0, SymbolVar opr1, | |||||
| const Param& param = {}, | |||||
| const ExecutionPolicy& policy = {}, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| private: | |||||
| void add_input_layout_constraint() override; | |||||
| void scn_do_execute() override; | |||||
| void init_output_dtype() override; | |||||
| size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) | |||||
| const override; | |||||
| static bool check_layout(const TensorLayout& layout, int transpose); | |||||
| //! store the policy of all transpose situations | |||||
| megdnn::MatrixMul::ExecutionPolicy m_cadidate_execution_policies[4]; | |||||
| }; | }; | ||||
| /*! | /*! | ||||
| * \brief batched matrix multiplication on 3D inputs | * \brief batched matrix multiplication on 3D inputs | ||||
| */ | */ | ||||
| MGB_DEFINE_OPR_CLASS(BatchedMatrixMul, | MGB_DEFINE_OPR_CLASS(BatchedMatrixMul, | ||||
| intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>) // { | |||||
| public: | |||||
| BatchedMatrixMul(VarNode *opr0, VarNode *opr1, | |||||
| const Param ¶m, const OperatorNodeConfig &config); | |||||
| static SymbolVar make(SymbolVar opr0, SymbolVar opr1, | |||||
| const Param ¶m = {}, | |||||
| const OperatorNodeConfig &config = {}); | |||||
| private: | |||||
| void add_input_layout_constraint() override; | |||||
| void init_output_dtype() override; | |||||
| void scn_do_execute() override; | |||||
| size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray &input_shapes, | |||||
| const TensorShapeArray &output_shapes) const override; | |||||
| static bool check_layout(const TensorLayout &layout, bool transpose); | |||||
| intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>, | |||||
| public mixin::AlgoChooserHelper) // { | |||||
| public: | |||||
| using AlgorithmInfo = megdnn::detail::Algorithm::Info; | |||||
| BatchedMatrixMul(VarNode* opr0, VarNode* opr1, const Param& param, | |||||
| const ExecutionPolicy& policy, | |||||
| const OperatorNodeConfig& config); | |||||
| static SymbolVar make(SymbolVar opr0, SymbolVar opr1, | |||||
| const Param& param = {}, | |||||
| const ExecutionPolicy& policy = {}, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| private: | |||||
| void add_input_layout_constraint() override; | |||||
| void init_output_dtype() override; | |||||
| void scn_do_execute() override; | |||||
| size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) | |||||
| const override; | |||||
| static bool check_layout(const TensorLayout& layout, bool transpose); | |||||
| //! store the policy of all transpose situations | |||||
| megdnn::BatchedMatrixMul::ExecutionPolicy m_cadidate_execution_policies[4]; | |||||
| }; | }; | ||||
| /*! | /*! | ||||
| @@ -109,4 +117,3 @@ MGB_DEFINE_OPR_CLASS(SVD, intl::MegDNNOprWrapperFwd<megdnn::SVD>) // { | |||||
| } // mgb | } // mgb | ||||
| // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -14,6 +14,7 @@ | |||||
| #include "megbrain/opr/search_policy/profiler.h" | #include "megbrain/opr/search_policy/profiler.h" | ||||
| #include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
| #include "megbrain/opr/blas.h" | |||||
| template <class MegDNNOpr> | template <class MegDNNOpr> | ||||
| struct MegDNNOpr2MGBOpr; | struct MegDNNOpr2MGBOpr; | ||||
| @@ -18,26 +18,31 @@ | |||||
| #include "megbrain/comp_node.h" | #include "megbrain/comp_node.h" | ||||
| #include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
| #include "megdnn/oprs/linalg.h" | |||||
| #include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
| namespace mgb { | namespace mgb { | ||||
| namespace opr { | namespace opr { | ||||
| #define MGB_FOREACH_FASTRUN_OPR(cb) \ | |||||
| cb(ConvolutionForward); \ | |||||
| cb(ConvBiasForward); \ | |||||
| cb(ConvolutionBackwardData); \ | |||||
| cb(ConvolutionBackwardFilter); \ | |||||
| cb(Convolution3DForward); \ | |||||
| cb(Convolution3DBackwardData); \ | |||||
| cb(Convolution3DBackwardFilter); \ | |||||
| cb(LocalShareForward); \ | |||||
| cb(LocalShareBackwardData); \ | |||||
| cb(LocalShareBackwardFilter); \ | |||||
| cb(DeformableConvForward); \ | |||||
| cb(DeformableConvBackwardFilter); \ | |||||
| cb(DeformableConvBackwardData); \ | |||||
| cb(BatchConvBiasForward); | |||||
| // clang-format off | |||||
| #define MGB_FOREACH_FASTRUN_OPR(cb) \ | |||||
| cb(ConvolutionForward) \ | |||||
| cb(ConvBiasForward) \ | |||||
| cb(ConvolutionBackwardData) \ | |||||
| cb(ConvolutionBackwardFilter) \ | |||||
| cb(Convolution3DForward) \ | |||||
| cb(Convolution3DBackwardData) \ | |||||
| cb(Convolution3DBackwardFilter) \ | |||||
| cb(LocalShareForward) \ | |||||
| cb(LocalShareBackwardData) \ | |||||
| cb(LocalShareBackwardFilter) \ | |||||
| cb(DeformableConvForward) \ | |||||
| cb(DeformableConvBackwardFilter) \ | |||||
| cb(DeformableConvBackwardData) \ | |||||
| cb(BatchConvBiasForward) \ | |||||
| cb(MatrixMul) \ | |||||
| cb(BatchedMatrixMul) | |||||
| // clang-format on | |||||
| template <typename Opr> | template <typename Opr> | ||||
| struct OprArityTrait; | struct OprArityTrait; | ||||
| @@ -67,6 +72,8 @@ INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1); | |||||
| INST_ARITY(megdnn::BatchConvBiasForward, 4, 1); | INST_ARITY(megdnn::BatchConvBiasForward, 4, 1); | ||||
| INST_ARITY(megdnn::ConvBias, 4, 1); | INST_ARITY(megdnn::ConvBias, 4, 1); | ||||
| INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3); | INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3); | ||||
| INST_ARITY(megdnn::MatrixMul, 2, 1); | |||||
| INST_ARITY(megdnn::BatchedMatrixMul, 2, 1); | |||||
| #undef INST_ARITY | #undef INST_ARITY | ||||
| @@ -269,7 +269,7 @@ void run_trans_inp_test_case(bool trans_a, bool trans_b) { | |||||
| if (DTypeTrait<dt_dst>::enumv == DTypeEnum::Int16) { | if (DTypeTrait<dt_dst>::enumv == DTypeEnum::Int16) { | ||||
| config.output_dtype(dtype::Int16()); | config.output_dtype(dtype::Int16()); | ||||
| } | } | ||||
| auto z = opr::MatrixMul::make(x, y, {}, config); | |||||
| auto z = opr::MatrixMul::make(x, y, {}, {}, config); | |||||
| HostTensorND host_z; | HostTensorND host_z; | ||||
| auto func = graph->compile({make_callback_copy(z, host_z)}); | auto func = graph->compile({make_callback_copy(z, host_z)}); | ||||
| @@ -359,7 +359,7 @@ void run_bgemm_trans_inp_test_case(bool trans_a, bool trans_b) { | |||||
| trans_a ? (x = opr::Dimshuffle::make(x, {0, 2, 1})) : 0; | trans_a ? (x = opr::Dimshuffle::make(x, {0, 2, 1})) : 0; | ||||
| trans_b ? (y = opr::Dimshuffle::make(y, {0, 2, 1})) : 0; | trans_b ? (y = opr::Dimshuffle::make(y, {0, 2, 1})) : 0; | ||||
| auto z = opr::BatchedMatrixMul::make(x, y, {}, OperatorNodeConfig{}); | |||||
| auto z = opr::BatchedMatrixMul::make(x, y, {}, {}, OperatorNodeConfig{}); | |||||
| HostTensorND host_z; | HostTensorND host_z; | ||||
| auto func = graph->compile({make_callback_copy(z, host_z)}); | auto func = graph->compile({make_callback_copy(z, host_z)}); | ||||
| auto run = [&](size_t B, size_t M, size_t K, size_t N) { | auto run = [&](size_t B, size_t M, size_t K, size_t N) { | ||||
| @@ -420,6 +420,43 @@ TEST(TestOprBlas, MatrixMul_TT) { | |||||
| run_sgemm_test(true, true); | run_sgemm_test(true, true); | ||||
| } | } | ||||
| //! Smoke test: for each execution-policy strategy, build a MatrixMul of a | |||||
| //! 32x64 and a 64x32 host tensor on "cpux", compile the graph and run it | |||||
| //! once. The output values are not checked. | |||||
| TEST(TestOprDNN, MatrixMulExePolicy) { | |||||
| using Param = opr::MatrixMul::Param; | |||||
| Param param; | |||||
| using Policy = opr::MatrixMul::ExecutionPolicy; | |||||
| using S = Policy::Strategy; | |||||
| auto cn = CompNode::load("cpux"); | |||||
| #if MGB_ENABLE_FASTRUN | |||||
| for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, | |||||
| S::PROFILE_HEURISTIC}) { | |||||
| #else | |||||
| // scope operator must be `::`; a single colon does not compile | |||||
| for (auto strategy : {S::HEURISTIC, S::PROFILE_HEURISTIC}) { | |||||
| #endif | |||||
| auto graph = ComputingGraph::make(); | |||||
| HostTensorGenerator<> gen; | |||||
| auto mkvar = [&](const char* name, const TensorShape& shp) { | |||||
| return opr::Host2DeviceCopy::make(*graph, gen(shp), cn) | |||||
| .rename(name); | |||||
| }; | |||||
| auto A = mkvar("A", {32, 64}); | |||||
| auto B = mkvar("B", {64, 32}); | |||||
| Policy policy; | |||||
| policy.strategy = strategy; | |||||
| auto C = opr::MatrixMul::make(A, B, param, policy); | |||||
| HostTensorND host_c; | |||||
| auto func = graph->compile({make_callback_copy(C, host_c)}); | |||||
| func->execute(); | |||||
| } | |||||
| } | |||||
| TEST(TestOprBlas, BatchedMatrixMulFp32_NN) { | TEST(TestOprBlas, BatchedMatrixMulFp32_NN) { | ||||
| run_batched_sgemm_test(false, false); | run_batched_sgemm_test(false, false); | ||||
| } | } | ||||