GitOrigin-RevId: 81e32da034
tags/v1.2.0
| @@ -19,6 +19,7 @@ | |||
| #include "megbrain/graph/extern_copr_api.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "megbrain/plugin/cpu_dispatch_checker.h" | |||
| #include "megbrain/plugin/num_range_checker.h" | |||
| @@ -691,7 +692,7 @@ void run_test_st(Args &env) { | |||
| } | |||
| mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit); | |||
| using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; | |||
| using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
| S strategy = S::HEURISTIC; | |||
| #if MGB_ENABLE_FASTRUN | |||
| if (env.use_fast_run) { | |||
| @@ -15,6 +15,7 @@ | |||
| #include "megbrain/graph/event.h" | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| #include "megbrain/utils/shared_set.h" | |||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| @@ -116,8 +117,8 @@ SymbolVarArray gopt::optimize_for_inference( | |||
| namespace { | |||
| void modify_conv_strategy( | |||
| opr::mixin::Convolution& conv, | |||
| opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||
| opr::mixin::AlgoChooserHelper& conv, | |||
| opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | |||
| auto policy = conv.execution_policy_transient(); | |||
| policy.strategy = strategy; | |||
| conv.set_execution_policy(policy); | |||
| @@ -126,13 +127,13 @@ void modify_conv_strategy( | |||
| template <typename Opr> | |||
| void inplace_conv_opr_modifier( | |||
| OperatorNodeBase& opr, | |||
| opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||
| opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | |||
| modify_conv_strategy( | |||
| opr.cast_final_safe<Opr>(), | |||
| strategy); | |||
| } | |||
| void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv, | |||
| void modify_conv_policy_workspace_limit(opr::mixin::AlgoChooserHelper& conv, | |||
| size_t workspace_limit) { | |||
| auto policy = conv.execution_policy_transient(); | |||
| policy.workspace_limit = workspace_limit; | |||
| @@ -159,9 +160,9 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr, | |||
| void gopt::modify_opr_algo_strategy_inplace( | |||
| const VarNodeArrayView& dest_vars, | |||
| opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||
| opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | |||
| #if !MGB_ENABLE_FASTRUN | |||
| using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; | |||
| using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
| if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) { | |||
| mgb_throw(MegBrainError, "fastrun is disabled at compile time"); | |||
| } | |||
| @@ -190,16 +191,16 @@ void gopt::modify_opr_algo_strategy_inplace( | |||
| void gopt::enable_opr_algo_profiling_inplace( | |||
| const VarNodeArrayView& dest_vars) { | |||
| modify_opr_algo_strategy_inplace(dest_vars, | |||
| opr::mixin::Convolution::ExecutionPolicy:: | |||
| Strategy::PROFILE); | |||
| modify_opr_algo_strategy_inplace( | |||
| dest_vars, | |||
| opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy::PROFILE); | |||
| } | |||
| void gopt::enable_opr_use_profiling_cache_inplace( | |||
| const VarNodeArrayView& dest_vars) { | |||
| modify_opr_algo_strategy_inplace(dest_vars, | |||
| opr::mixin::Convolution::ExecutionPolicy:: | |||
| Strategy::PROFILE_HEURISTIC); | |||
| modify_opr_algo_strategy_inplace( | |||
| dest_vars, opr::mixin::AlgoChooserHelper::ExecutionPolicy:: | |||
| Strategy::PROFILE_HEURISTIC); | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| #include "megbrain/gopt/framework.h" | |||
| #include "megbrain/graph/cg.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| namespace mgb { | |||
| namespace gopt { | |||
| @@ -342,7 +343,7 @@ namespace gopt { | |||
| */ | |||
| void modify_opr_algo_strategy_inplace( | |||
| const VarNodeArrayView& dest_vars, | |||
| opr::mixin::Convolution::ExecutionPolicy::Strategy strategy); | |||
| opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy); | |||
| /*! | |||
| * \brief enable PROFILE execution strategy for oprs with multiple | |||
| @@ -13,7 +13,7 @@ | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser.h" | |||
| #include "megbrain/opr/search_policy/profiler.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| #include "megbrain/graph/grad_impl.h" | |||
| #include "megbrain/system.h" | |||
| @@ -38,18 +38,9 @@ using intl::WorkspaceLimitGetter; | |||
| /* ==================== misc impl ==================== */ | |||
| mixin::Convolution::~Convolution() = default; | |||
| void mixin::Convolution::set_execution_policy(const ExecutionPolicy& policy) { | |||
| mgb_throw_if( | |||
| m_policy_accessed, InternalError, | |||
| "attempt to modify ExecutionPolicy after it has been accessed"); | |||
| m_policy = policy; | |||
| } | |||
| template <class MgbOpr, class MegDNNOpr> | |||
| void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( | |||
| cg::OperatorNodeBase* self) { | |||
| void mixin::ConvolutionBackwardDataMixin:: | |||
| init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self) { | |||
| using namespace cg::static_infer; | |||
| auto&& mgr = self->owner_graph()->static_infer_manager(); | |||
| @@ -93,7 +84,7 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( | |||
| }; | |||
| inp_deps.push_back({self->output(0), DepType::SHAPE}); | |||
| auto workspace_dep_var = | |||
| WorkspaceLimitGetter::register_to_graph(self->owner_graph()); | |||
| intl::WorkspaceLimitGetter::register_to_graph(self->owner_graph()); | |||
| if (workspace_dep_var) { | |||
| inp_deps.push_back({workspace_dep_var, DepType::VALUE}); | |||
| } | |||
| @@ -101,11 +92,7 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( | |||
| {SourceType::DEP, inp_deps, infer_wk}); | |||
| } | |||
| #define IMPL_CONV(_cls) \ | |||
| std::pair<const void*, size_t> _cls::param_blob() const { \ | |||
| return {¶m(), sizeof(Param)}; \ | |||
| } \ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) | |||
| #define IMPL_CONV(_cls) MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) | |||
| class mixin::WeightPreprocessExecutor::PreprocessedFilterExecDep final | |||
| : public cg::GraphExecutable::ExecDependency { | |||
| @@ -11,6 +11,7 @@ | |||
| */ | |||
| #include "megbrain/opr/search_policy/algo_chooser.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| #include "megbrain/opr/search_policy/profiler.h" | |||
| #include "../internal/invoke.h" | |||
| @@ -200,7 +201,7 @@ size_t AlgoChooser<Opr>::setup_algo(const TensorLayoutArray& layouts, | |||
| template <typename Opr> | |||
| typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::get_algo( | |||
| ExeContext& ctx) { | |||
| using S = mixin::Convolution::ExecutionPolicy::Strategy; | |||
| using S = mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
| MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE); | |||
| switch (ctx.mgb_opr()->execution_policy().strategy) { | |||
| case S::HEURISTIC: | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * \file src/opr/impl/search_policy/algo_chooser_helper.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser.h" | |||
| #include "megbrain/graph/cg.h" | |||
| #include "../internal/megdnn_opr_wrapper.inl" | |||
| using namespace mgb; | |||
| using namespace opr; | |||
| using namespace mixin; | |||
| /* ==================== misc impl ==================== */ | |||
| AlgoChooserHelper::~AlgoChooserHelper() = default; | |||
| void AlgoChooserHelper::set_execution_policy(const ExecutionPolicy& policy) { | |||
| mgb_throw_if( | |||
| m_policy_accessed, InternalError, | |||
| "attempt to modify ExecutionPolicy after it has been accessed"); | |||
| m_policy = policy; | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -13,6 +13,7 @@ | |||
| #include "megbrain/opr/search_policy/profiler.h" | |||
| #include "../internal/invoke.h" | |||
| #include "../internal/megdnn_opr_wrapper.inl" | |||
| #if MGB_ROCM | |||
| #include "hcc_detail/hcc_defs_prologue.h" | |||
| @@ -11,6 +11,7 @@ | |||
| #pragma once | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| #include "megbrain/utils/persistent_cache.h" | |||
| #include "megbrain/opr/param_defs.h" | |||
| #include "megdnn/oprs/nn.h" | |||
| @@ -19,68 +20,14 @@ namespace mgb { | |||
| namespace opr { | |||
| namespace mixin { | |||
| /*! | |||
| * \brief Convolution base class | |||
| */ | |||
| class Convolution { | |||
| public: | |||
| using ExecutionPolicy = megdnn::param::ExecutionPolicy; | |||
| using AlgorithmInfo = megdnn::detail::Algorithm::Info; | |||
| using AlgoChooserHook = | |||
| std::function<AlgorithmInfo(const OperatorNodeBase*)>; | |||
| const ExecutionPolicy& execution_policy() const { | |||
| if (!m_policy_accessed) { | |||
| m_policy_accessed = true; | |||
| } | |||
| return m_policy; | |||
| } | |||
| /*! | |||
| * \brief get current policy without marking it as having been accessed | |||
| * | |||
| * This is primarily used for getting current policy before calling | |||
| * set_execution_policy(). | |||
| */ | |||
| const ExecutionPolicy& execution_policy_transient() const { | |||
| return m_policy; | |||
| } | |||
| /*! | |||
| * \brief modify execution policy | |||
| * | |||
| * Exception would be thrown if execution_policy() has been accessed, | |||
| * since it would influence cache and many other decisions. | |||
| */ | |||
| void set_execution_policy(const ExecutionPolicy& policy); | |||
| AlgoChooserProfileCache& profile_cache() const; | |||
| virtual std::pair<const void*, size_t> param_blob() const = 0; | |||
| /*! | |||
| * \brief register a hook to implement custom algo chooser | |||
| */ | |||
| void setup_algo_chooser(AlgoChooserHook&& func) { | |||
| m_algo_chooser = func; | |||
| } | |||
| AlgoChooserHook algo_chooser() const { | |||
| return m_algo_chooser; | |||
| } | |||
| protected: | |||
| ~Convolution(); | |||
| mutable bool m_policy_accessed = false; | |||
| ExecutionPolicy m_policy; | |||
| AlgoChooserHook m_algo_chooser; | |||
| class ConvolutionBackwardDataMixin : public cg::OperatorNodeMixinBase { | |||
| protected: | |||
| //! init output desc for conv backward data oprs; it handles both grad | |||
| //! usage and deconv usage | |||
| template <class MgbOpr, class MegDNNOpr> | |||
| static void init_output_static_infer_desc_for_bwd_data( | |||
| cg::OperatorNodeBase* self); | |||
| //! init output desc for conv backward data oprs; it handles both grad | |||
| //! usage and deconv usage | |||
| template <class MgbOpr, class MegDNNOpr> | |||
| static void init_output_static_infer_desc_for_bwd_data( | |||
| cg::OperatorNodeBase* self); | |||
| }; | |||
| class WeightPreprocessExecutor : public cg::OperatorNodeMixinBase { | |||
| @@ -153,7 +100,7 @@ class ConvolutionTestingPeer; | |||
| } // namespace testing | |||
| MGB_DEFINE_OPR_CLASS(ConvolutionForward, | |||
| intl::ConvolutionForwardBase, public mixin::Convolution) // { | |||
| intl::ConvolutionForwardBase, public mixin::AlgoChooserHelper) // { | |||
| void init_output_dtype() override; | |||
| size_t get_workspace_size_bytes( | |||
| @@ -183,12 +130,11 @@ MGB_DEFINE_OPR_CLASS(ConvolutionForward, | |||
| const ExecutionPolicy &policy = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| }; | |||
| using Convolution = ConvolutionForward; | |||
| MGB_DEFINE_OPR_CLASS(ConvBiasForward, intl::ConvBiasForwardBase, | |||
| public mixin::Convolution) // { | |||
| public mixin::AlgoChooserHelper) // { | |||
| void init_output_dtype() override; | |||
| size_t get_workspace_size_bytes( | |||
| @@ -240,7 +186,6 @@ public: | |||
| const ExecutionPolicy& policy = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| static void check_winograd_param_valid( | |||
| const megdnn::ConvBias::WinogradParam& param, | |||
| @@ -253,10 +198,12 @@ using ConvBias = ConvBiasForward; | |||
| /*! | |||
| * \brief Can be used in two ways: compute gradient of conv, or deconv | |||
| */ | |||
| MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData, | |||
| MGB_DEFINE_OPR_CLASS( | |||
| ConvolutionBackwardData, | |||
| cg::SingleCNOperatorNodeBaseT< | |||
| mixin::MegDNNOprHolderImpl<megdnn::ConvolutionBackwardData>>, | |||
| public mixin::Convolution) // { | |||
| mixin::MegDNNOprHolderImpl<megdnn::ConvolutionBackwardData>>, | |||
| public mixin::AlgoChooserHelper, | |||
| public mixin::ConvolutionBackwardDataMixin) // { | |||
| void init_output_static_infer_desc() override; | |||
| void init_output_dtype() override; | |||
| void init_output_format() override; | |||
| @@ -296,12 +243,11 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData, | |||
| return make(filter, data, param, policy, config); | |||
| } | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter, | |||
| intl::MegDNNOprWrapperBwd<megdnn::ConvolutionBackwardFilter>, | |||
| public mixin::Convolution ) // { | |||
| public mixin::AlgoChooserHelper ) // { | |||
| size_t get_workspace_size_bytes( | |||
| @@ -318,7 +264,6 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter, | |||
| const ExecutionPolicy &policy = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS(MaskConvolution, | |||
| @@ -350,7 +295,7 @@ public: | |||
| MGB_DEFINE_OPR_CLASS(Convolution3DForward, | |||
| intl::MegDNNOprWrapperFwd<megdnn::Convolution3DForward>, | |||
| public mixin::Convolution) // { | |||
| public mixin::AlgoChooserHelper) // { | |||
| void init_output_dtype() override; | |||
| size_t get_workspace_size_bytes( | |||
| @@ -368,17 +313,18 @@ MGB_DEFINE_OPR_CLASS(Convolution3DForward, | |||
| const ExecutionPolicy &policy = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| }; | |||
| using Convolution3D = Convolution3DForward; | |||
| /*! | |||
| * \brief Can be used in two ways: compute gradient of conv, or deconv | |||
| */ | |||
| MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData, | |||
| MGB_DEFINE_OPR_CLASS( | |||
| Convolution3DBackwardData, | |||
| cg::SingleCNOperatorNodeBaseT< | |||
| mixin::MegDNNOprHolderImpl<megdnn::Convolution3DBackwardData>>, | |||
| public mixin::Convolution) // { | |||
| mixin::MegDNNOprHolderImpl<megdnn::Convolution3DBackwardData>>, | |||
| public mixin::AlgoChooserHelper, | |||
| public mixin::ConvolutionBackwardDataMixin) // { | |||
| void init_output_static_infer_desc() override; | |||
| void add_input_layout_constraint() override; | |||
| @@ -416,12 +362,11 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData, | |||
| return make(filter, data, param, policy, config); | |||
| } | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter, | |||
| intl::MegDNNOprWrapperBwd<megdnn::Convolution3DBackwardFilter>, | |||
| public mixin::Convolution) // { | |||
| public mixin::AlgoChooserHelper) // { | |||
| size_t get_workspace_size_bytes( | |||
| const TensorShapeArray &input_shapes, | |||
| @@ -437,12 +382,11 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter, | |||
| const ExecutionPolicy &policy = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS(LocalShareForward, | |||
| intl::MegDNNOprWrapperFwd<megdnn::LocalShareForward>, | |||
| public mixin::Convolution) // { | |||
| public mixin::AlgoChooserHelper) // { | |||
| void init_output_dtype() override; | |||
| void init_output_format() override; | |||
| @@ -457,7 +401,6 @@ public: | |||
| static SymbolVar make(SymbolVar src, SymbolVar filter, const Param& param = {}, | |||
| const ExecutionPolicy& policy = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| }; | |||
| using LocalShare = LocalShareForward; | |||
| @@ -465,7 +408,8 @@ MGB_DEFINE_OPR_CLASS( | |||
| LocalShareBackwardData, | |||
| cg::SingleCNOperatorNodeBaseT< | |||
| mixin::MegDNNOprHolderImpl<megdnn::LocalShareBackwardData>>, | |||
| public mixin::Convolution) // { | |||
| public mixin::AlgoChooserHelper, | |||
| public mixin::ConvolutionBackwardDataMixin) // { | |||
| void init_output_static_infer_desc() override; | |||
| void init_output_dtype() override; | |||
| @@ -485,13 +429,12 @@ public: | |||
| const ExecutionPolicy& policy = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS( | |||
| LocalShareBackwardFilter, | |||
| intl::MegDNNOprWrapperBwd<megdnn::LocalShareBackwardFilter>, | |||
| public mixin::Convolution) // { | |||
| public mixin::AlgoChooserHelper) // { | |||
| size_t get_workspace_size_bytes( | |||
| const TensorShapeArray& input_shapes, | |||
| @@ -506,12 +449,11 @@ public: | |||
| const ExecutionPolicy& policy = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS(DeformableConvForward, | |||
| intl::MegDNNOprWrapperFwd<megdnn::DeformableConvForward>, | |||
| public mixin::Convolution) // { | |||
| public mixin::AlgoChooserHelper) // { | |||
| public: | |||
| DeformableConvForward( | |||
| VarNode *src, VarNode *filter, VarNode *offset, VarNode *mask, | |||
| @@ -525,7 +467,6 @@ MGB_DEFINE_OPR_CLASS(DeformableConvForward, | |||
| const ExecutionPolicy &policy = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| private: | |||
| void init_output_dtype() override; | |||
| void init_output_format() override; | |||
| @@ -537,7 +478,8 @@ using DeformableConv = DeformableConvForward; | |||
| MGB_DEFINE_OPR_CLASS(DeformableConvBackwardData, | |||
| intl::DeformableConvBackwardDataBase, | |||
| public mixin::Convolution) // { | |||
| public mixin::AlgoChooserHelper, | |||
| public mixin::ConvolutionBackwardDataMixin) // { | |||
| public: | |||
| DeformableConvBackwardData( | |||
| VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask, | |||
| @@ -557,7 +499,6 @@ public: | |||
| const OperatorNodeConfig& config = {}); | |||
| void scn_do_execute() override; | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| private: | |||
| void get_output_var_shape(const TensorShapeArray& inp_shape, | |||
| @@ -578,7 +519,7 @@ private: | |||
| MGB_DEFINE_OPR_CLASS( | |||
| DeformableConvBackwardFilter, | |||
| intl::MegDNNOprWrapperBwd<megdnn::DeformableConvBackwardFilter>, | |||
| public mixin::Convolution) // { | |||
| public mixin::AlgoChooserHelper) // { | |||
| public: | |||
| DeformableConvBackwardFilter( | |||
| VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask, | |||
| @@ -592,7 +533,6 @@ public: | |||
| const OperatorNodeConfig& config = {}); | |||
| void scn_do_execute() override; | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| private: | |||
| size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | |||
| @@ -601,7 +541,7 @@ private: | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS(BatchConvBiasForward, intl::BatchConvBiasForwardBase, | |||
| public mixin::Convolution) // { | |||
| public mixin::AlgoChooserHelper) // { | |||
| void init_output_dtype() override; | |||
| size_t get_workspace_size_bytes( | |||
| @@ -650,7 +590,6 @@ public: | |||
| const ExecutionPolicy& policy = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| std::pair<const void*, size_t> param_blob() const override; | |||
| }; | |||
| using BatchConvBias = BatchConvBiasForward; | |||
| @@ -13,6 +13,7 @@ | |||
| #pragma once | |||
| #include "megbrain/opr/search_policy/profiler.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| template <class MegDNNOpr> | |||
| struct MegDNNOpr2MGBOpr; | |||
| @@ -0,0 +1,80 @@ | |||
| /** | |||
| * \file src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/graph/operator_node.h" | |||
| #include "megbrain/opr/param_defs.h" | |||
| #include "megdnn/oprs/base.h" | |||
| #include "megdnn/oprs/nn.h" | |||
| namespace mgb { | |||
| namespace opr { | |||
| namespace mixin { | |||
| /*! | |||
| * \brief base class for the opr which can be tuning | |||
| */ | |||
| class AlgoChooserHelper : cg::OperatorNodeMixinBase { | |||
| public: | |||
| using ExecutionPolicy = megdnn::param::ExecutionPolicy; | |||
| using AlgorithmInfo = megdnn::detail::Algorithm::Info; | |||
| using AlgoChooserHook = | |||
| std::function<AlgorithmInfo(const cg::OperatorNodeBase*)>; | |||
| const ExecutionPolicy& execution_policy() const { | |||
| if (!m_policy_accessed) { | |||
| m_policy_accessed = true; | |||
| } | |||
| return m_policy; | |||
| } | |||
| /*! | |||
| * \brief get current policy without marking it as having been accessed | |||
| * | |||
| * This is primarily used for getting current policy before calling | |||
| * set_execution_policy(). | |||
| */ | |||
| const ExecutionPolicy& execution_policy_transient() const { | |||
| return m_policy; | |||
| } | |||
| /*! | |||
| * \brief modify execution policy | |||
| * | |||
| * Exception would be thrown if execution_policy() has been accessed, | |||
| * since it would influence cache and many other decisions. | |||
| */ | |||
| void set_execution_policy(const ExecutionPolicy& policy); | |||
| /*! | |||
| * \brief register a hook to implement custom algo chooser | |||
| */ | |||
| void setup_algo_chooser(AlgoChooserHook&& func) { m_algo_chooser = func; } | |||
| AlgoChooserHook algo_chooser() const { return m_algo_chooser; } | |||
| protected: | |||
| ~AlgoChooserHelper(); | |||
| mutable bool m_policy_accessed = false; | |||
| ExecutionPolicy m_policy; | |||
| AlgoChooserHook m_algo_chooser; | |||
| }; | |||
| } // namespace mixin | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -12,9 +12,10 @@ | |||
| #pragma once | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/utils/hash_ct.h" | |||
| #include "megbrain/utils/timer.h" | |||
| #include "megbrain/system.h" | |||
| #include "megbrain/comp_node.h" | |||
| #include "megdnn/basic_types.h" | |||
| #include "megdnn/oprs/nn.h" | |||
| @@ -127,15 +128,15 @@ class TimedProfiler { | |||
| static constexpr int arity_out = OprArityTrait<Opr>::arity_out; | |||
| static constexpr int arity = OprArityTrait<Opr>::arity; | |||
| using ConvTensorShapes = std::array<TensorShape, arity>; | |||
| using TensorShapeArray = std::array<megdnn::TensorShape, arity>; | |||
| public: | |||
| struct Param { | |||
| char algo_name[128]; | |||
| size_t workspace; | |||
| DTypeEnum dtypes[arity]; | |||
| megdnn::DTypeEnum dtypes[arity]; | |||
| CompNode::Locator comp_node_loc; | |||
| ConvTensorShapes shapes; | |||
| TensorShapeArray shapes; | |||
| typename Opr::Param opr_param; | |||
| bool allow_weight_preprocess; | |||