GitOrigin-RevId: 4973820e02
tags/v1.6.0
| @@ -30,7 +30,10 @@ | |||
| #include "megbrain/tensorrt/opr_replace.h" | |||
| #endif | |||
| #include "megbrain/gopt/global_layout_transform.h" | |||
| #include "megbrain/gopt/layout_transform_context.h" | |||
| #include "megbrain/gopt/layout_transform_pass.h" | |||
| #include "megbrain/gopt/profiler.h" | |||
| #include "megbrain/gopt/solver.h" | |||
| using namespace mgb; | |||
| using namespace gopt; | |||
| @@ -12,7 +12,9 @@ | |||
| #include <queue> | |||
| #include "./utils.h" | |||
| #include "megbrain/gopt/global_layout_transform.h" | |||
| #include "megbrain/gopt/layout_transform_context.h" | |||
| #include "megbrain/gopt/profiler.h" | |||
| #include "megbrain/gopt/solver.h" | |||
| using namespace mgb; | |||
| using namespace gopt; | |||
| @@ -85,11 +87,11 @@ private: | |||
| const SmallVector<TensorFormats>& available_tensor_formats; | |||
| }; | |||
| /*! | |||
| * \brief get the tensor formats configuration for the operator with particular op format | |||
| * \param[out] var2fmts hashmap that maps varnode to actual tensor formats of the op format configuration | |||
| * \param[in] opr given operator | |||
| * \param[in] opr_fmt given op format, an enum type argument which indicates the op format configuration. | |||
| * \param[in] ctx context | |||
| * \brief get the tensor formats configuration for the operator with | |||
| * particular op format \param[out] var2fmts hashmap that maps varnode to | |||
| * actual tensor formats of the op format configuration \param[in] opr given | |||
| * operator \param[in] opr_fmt given op format, an enum type argument which | |||
| * indicates the op format configuration. \param[in] ctx context | |||
| */ | |||
| TensorFormats get_io_formats(ThinHashMap<VarNode*, TensorFormats>& var2fmts, | |||
| const OperatorNodeBase* opr, OprFormat opr_fmt, | |||
| @@ -11,7 +11,7 @@ | |||
| */ | |||
| #include "./utils.h" | |||
| #include "megbrain/gopt/global_layout_transform.h" | |||
| #include "megbrain/gopt/layout_transform_context.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| #include "megbrain/opr/nn_int.h" | |||
| @@ -10,9 +10,11 @@ | |||
| * implied. | |||
| */ | |||
| #include "megbrain/gopt/layout_transform_pass.h" | |||
| #include "./opr_format_modifier.h" | |||
| #include "./utils.h" | |||
| #include "megbrain/gopt/global_layout_transform.h" | |||
| #include "megbrain/gopt/profiler.h" | |||
| #include "megbrain/gopt/solver.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| #include "megbrain/serialization/sereg.h" | |||
| @@ -46,8 +48,7 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||
| auto&& opr_configs = m_ctx->opr_configs(); | |||
| auto&& base_fmt = m_ctx->attribute().base_tensor_formats; | |||
| auto&& reformat_attribute = | |||
| ReformatManager::ReformatKey::Attribute::DEFAULT; | |||
| auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; | |||
| ThinHashMap<VarNode*, TensorFormats> var2fmts; | |||
| static ThinHashSet<Typeinfo*> format_aware_oprs = { | |||
| #define cb(_Opr) opr::_Opr::typeinfo(), | |||
| @@ -55,8 +56,8 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||
| #undef cb | |||
| }; | |||
| auto rewriter = opt.graph().make_rewriter(); | |||
| auto on_opr = [this, &opr_configs, &base_fmt, &reformat_attribute, | |||
| &rewriter, &solution, &var2fmts, | |||
| auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter, | |||
| &solution, &var2fmts, | |||
| &endpoint_vars](OperatorNodeBase* opr) { | |||
| auto it = solution.find(opr); | |||
| if (it != solution.end()) { | |||
| @@ -122,19 +123,6 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||
| opr->config()) | |||
| ->output(0); | |||
| } | |||
| if (endpoint_vars.count(opr->output(0)) && out_fmt != base_fmt) { | |||
| ReformatManager::ReformatKey key{ | |||
| out_fmt, base_fmt, reformat_attribute, | |||
| opr->output(0)->dtype().enumv(), | |||
| opr->output(0)->dtype().enumv()}; | |||
| auto reformat = ReformatManager::instance() | |||
| .auto_aligned_reformat_featrue( | |||
| opr->output(0), base_fmt, key); | |||
| new_out = reformat({new_out}); | |||
| var2fmts[new_out] = base_fmt; | |||
| } else { | |||
| var2fmts[new_out] = out_fmt; | |||
| } | |||
| auto &&out0 = opr->output(), | |||
| &&out1 = new_out->owner_opr()->output(); | |||
| mgb_assert(opr->usable_output().size() == | |||
| @@ -146,20 +134,29 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||
| new_out->owner_opr()->cname(), | |||
| new_out->owner_opr()->dyn_typeinfo()->name, out0.size(), | |||
| out1.size()); | |||
| for (size_t i = 0; i < out0.size(); ++i) { | |||
| if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | |||
| mgb_assert(!out1[i]->contain_flag( | |||
| VarNode::Flag::VOLATILE_CONTENT)); | |||
| auto src = out0[i]; | |||
| auto dst = out1[i]; | |||
| rewriter.replace_var( | |||
| src, dst, | |||
| mgb_cstr_log(ssprintf("replace opr(%s) to new opr " | |||
| "format(%s)", | |||
| opr->cname(), | |||
| opr_format_to_string(opr_fmt)) | |||
| .c_str())); | |||
| size_t nr_outs = opr->usable_output().size(); | |||
| for (size_t i = 0; i < nr_outs; ++i) { | |||
| const auto& ovar = out0[i]; | |||
| auto new_ovar = out1[i]; | |||
| if (endpoint_vars.count(ovar) && out_fmt != base_fmt) { | |||
| ReformatManager::ReformatKey key{ | |||
| out_fmt, base_fmt, reformat_attribute, | |||
| ovar->dtype().enumv(), ovar->dtype().enumv()}; | |||
| auto reformat = ReformatManager::instance() | |||
| .auto_aligned_reformat_featrue( | |||
| ovar, base_fmt, key); | |||
| new_ovar = reformat({new_ovar}); | |||
| var2fmts[new_ovar] = base_fmt; | |||
| } else { | |||
| var2fmts[new_ovar] = out_fmt; | |||
| } | |||
| rewriter.replace_var( | |||
| ovar, new_ovar, | |||
| mgb_cstr_log(ssprintf("replace opr(%s) to new opr " | |||
| "format(%s)", | |||
| opr->cname(), | |||
| opr_format_to_string(opr_fmt)) | |||
| .c_str())); | |||
| } | |||
| } else { | |||
| auto new_opr = rewriter.auto_replace_outputs(opr); | |||
| @@ -11,7 +11,7 @@ | |||
| */ | |||
| #include "./utils.h" | |||
| #include "megbrain/gopt/global_layout_transform.h" | |||
| #include "megbrain/gopt/layout_transform_context.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| @@ -13,7 +13,7 @@ | |||
| #include "./opr_format_modifier.h" | |||
| #include "./utils.h" | |||
| #include "megbrain/gopt/framework.h" | |||
| #include "megbrain/gopt/global_layout_transform.h" | |||
| #include "megbrain/gopt/profiler.h" | |||
| #include "megbrain/graph/event.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| @@ -10,7 +10,8 @@ | |||
| * implied. | |||
| */ | |||
| #include "megbrain/gopt/global_layout_transform.h" | |||
| #include "megbrain/gopt/profiler.h" | |||
| #include "megbrain/gopt/solver.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| @@ -11,7 +11,7 @@ | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/gopt/global_layout_transform.h" | |||
| #include "megbrain/gopt/layout_transform_context.h" | |||
| namespace mgb { | |||
| namespace gopt { | |||
| @@ -1,5 +1,6 @@ | |||
| /** | |||
| * \file src/gopt/include/megbrain/gopt/global_layout_transformation.h | |||
| * \file | |||
| * src/gopt/include/megbrain/gopt/layout_transform_context.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -12,11 +13,10 @@ | |||
| #pragma once | |||
| #include "megbrain/gopt/framework.h" | |||
| #include "megbrain/gopt/inference.h" | |||
| #include "megbrain/gopt/reformat_manager.h" | |||
| #include "megbrain/gopt/subgraph_extractor.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/plugin/opr_footprint.h" | |||
| #include "megbrain/gopt/inference.h" | |||
| namespace mgb { | |||
| namespace gopt { | |||
| @@ -118,7 +118,7 @@ public: | |||
| TensorFormats base_tensor_format = TensorFormats::NCHW); | |||
| private: | |||
| OprList m_opr_list; /// supported operator list | |||
| OprList m_opr_list; /// supported operator list | |||
| SmallVector<TensorFormats> | |||
| m_available_tensor_formats; /// the available tensor formats, used | |||
| /// for format agnostic operators (like | |||
| @@ -180,164 +180,6 @@ private: | |||
| const GraphPartition& m_graph_partition; /// the graph partition | |||
| const LayoutTransformContext& m_ctx; | |||
| }; | |||
| /*! | |||
| * \brief A profiler that collects all the performance data to describe the | |||
| * global layout transform problem. | |||
| */ | |||
| class ProfilerBase { | |||
| public: | |||
| using OprFormat = Problem::OprFormat; | |||
| struct OperatorNodeRecord { | |||
| const cg::OperatorNodeBase* opr; ///< pointer to operator node | |||
| ThinHashMap<OprFormat, float> | |||
| costs; ///< costs of operator node, i.e. the elapsed device | |||
| ///< time of the operator node on different opr format | |||
| ///< (layout configuration). | |||
| std::string to_string() const; | |||
| }; | |||
| struct VarNodeRecord { | |||
| struct KeyHash { | |||
| size_t operator()( | |||
| const std::pair<TensorFormats, TensorFormats>& val) const { | |||
| size_t h1 = | |||
| std::hash<uint32_t>()(static_cast<uint32_t>(val.first)); | |||
| size_t h2 = std::hash<uint32_t>()( | |||
| static_cast<uint32_t>(val.second)); | |||
| return mgb::hash_pair_combine(h1, h2); | |||
| } | |||
| }; | |||
| const VarNode* var; ///< pointer to var node | |||
| std::unordered_map<std::pair<TensorFormats, TensorFormats>, float, | |||
| KeyHash> | |||
| costs; ///< costs of var node, i.e. the elapsed | |||
| ///< device time of the layout transform. | |||
| ///< Key of the hashmap indicates the | |||
| ///< source tensor format and the target | |||
| ///< tensor format. | |||
| std::string to_string() const; | |||
| }; | |||
| /*! | |||
| * \note the profiler assumes all the input and output var node are stored | |||
| * in contiguous layout in memory | |||
| */ | |||
| struct ProfilingResult { | |||
| /// A hashmap, that maps the operator node to the costs (device elapsed | |||
| /// time) of different layouts configuration | |||
| ThinHashMap<cg::OperatorNodeBase*, OperatorNodeRecord> opr_record; | |||
| /// A hashmap, that maps the var node to the costs of layout transform | |||
| ThinHashMap<VarNode*, VarNodeRecord> var_record; | |||
| }; | |||
| using OprFilter = thin_function<bool(const cg::OperatorNodeBase*, | |||
| cg::OperatorNodeBase*)>; | |||
| using VarNodeFilter = | |||
| thin_function<bool(const VarNode*, TensorShape, TensorShape, | |||
| ReformatManager::ReformatKey)>; | |||
| ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f); | |||
| ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {}) | |||
| : m_opr_filter{std::move(opr_filter)}, | |||
| m_var_node_filter{std::move(var_node_filter)} {} | |||
| virtual ~ProfilerBase() = default; | |||
| virtual ProfilingResult profile(const Problem& problem) const = 0; | |||
| static std::unique_ptr<ProfilerBase> make_profiler(); | |||
| protected: | |||
| OprFilter m_opr_filter; | |||
| VarNodeFilter m_var_node_filter; | |||
| float m_opr_threshold; | |||
| float m_var_node_threshold; | |||
| private: | |||
| OprFootprint m_opr_footprint; | |||
| }; | |||
| /*! | |||
| * \brief abstract solver | |||
| */ | |||
| class SolverBase { | |||
| public: | |||
| using OprFormat = Problem::OprFormat; | |||
| using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>; | |||
| SolverBase() = default; | |||
| virtual ~SolverBase() = default; | |||
| /*! | |||
| * \brief solve the given problem | |||
| */ | |||
| virtual Solution solve(const Problem& problem) const = 0; | |||
| /*! | |||
| * \brief check whether the given problem can be solved by the | |||
| * algorithm(i.e. solver). | |||
| */ | |||
| virtual bool can_solve(const Problem& problem) const = 0; | |||
| }; | |||
| /*! | |||
| * \brief solvers that will first collect the costs of operators in different op | |||
| * format and the costs of layout transform of varnode with a user provided | |||
| * profiler on the target device. This will lead to time consuming. | |||
| */ | |||
| class ProfilingBasedSolver : public SolverBase { | |||
| public: | |||
| using GraphPartitionFilter = | |||
| thin_function<bool(const GraphPartition& graph_partition)>; | |||
| ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler); | |||
| /*! | |||
| * \note some graph partition (for example, graph partition without format | |||
| * aware operators like conv, deconv, warp, resize etc.) will be filtered by | |||
| * the GraphPartitionFilter, which can reduce the profiling time. */ | |||
| ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler, | |||
| GraphPartitionFilter graph_partition_filter) | |||
| : m_profiler{std::move(profiler)}, | |||
| m_graph_partition_filter{std::move(graph_partition_filter)} {} | |||
| virtual ~ProfilingBasedSolver() = default; | |||
| Solution solve(const Problem& problem) const override; | |||
| virtual Solution do_solve(const Problem& problem) const = 0; | |||
| protected: | |||
| std::unique_ptr<ProfilerBase> m_profiler; | |||
| private: | |||
| GraphPartitionFilter m_graph_partition_filter; | |||
| }; | |||
| /*! | |||
| * \brief A solver that solves the layout selection problem using dynamic | |||
| * programming algorithm (Markov decision process). | |||
| */ | |||
| class DynamicProgrammingSolver final : public ProfilingBasedSolver { | |||
| public: | |||
| DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler) | |||
| : ProfilingBasedSolver(std::move(profiler)){}; | |||
| DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler, | |||
| GraphPartitionFilter graph_partition_filter) | |||
| : ProfilingBasedSolver(std::move(profiler), | |||
| std::move(graph_partition_filter)){}; | |||
| ~DynamicProgrammingSolver() noexcept = default; | |||
| Solution do_solve(const Problem& problem) const override; | |||
| bool can_solve(const Problem& problem) const override; | |||
| private: | |||
| class Impl; | |||
| }; | |||
| /*! | |||
| * \brief A layout transform pass, which convert the operator's format to the | |||
| * optimal format using the results of the solver. | |||
| */ | |||
| class LayoutTransformPass final : public Pass { | |||
| public: | |||
| const char* name() const override { return "layout assignment pass"; } | |||
| void apply(OptState& opt) const override; | |||
| LayoutTransformPass(std::unique_ptr<LayoutTransformContext> ctx, | |||
| std::unique_ptr<SolverBase> solver) | |||
| : m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {} | |||
| private: | |||
| std::unique_ptr<LayoutTransformContext> m_ctx; | |||
| std::unique_ptr<SolverBase> m_solver; | |||
| }; | |||
| } // namespace gopt | |||
| } // namespace mgb | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * \file src/gopt/include/megbrain/gopt/global_layout_transformation.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/gopt/framework.h" | |||
| namespace mgb { | |||
| namespace gopt { | |||
| class LayoutTransformContext; | |||
| class SolverBase; | |||
| /*! | |||
| * \brief A layout transform pass, which convert the operator's format to the | |||
| * optimal format using the results of the solver. | |||
| */ | |||
| class LayoutTransformPass final : public Pass { | |||
| public: | |||
| const char* name() const override { return "layout assignment pass"; } | |||
| void apply(OptState& opt) const override; | |||
| LayoutTransformPass(std::unique_ptr<LayoutTransformContext> ctx, | |||
| std::unique_ptr<SolverBase> solver) | |||
| : m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {} | |||
| private: | |||
| std::unique_ptr<LayoutTransformContext> m_ctx; | |||
| std::unique_ptr<SolverBase> m_solver; | |||
| }; | |||
| } // namespace gopt | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,101 @@ | |||
| /** | |||
| * \file src/gopt/include/megbrain/gopt/global_layout_transformation.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/gopt/framework.h" | |||
| #include "megbrain/gopt/inference.h" | |||
| #include "megbrain/gopt/layout_transform_context.h" | |||
| #include "megbrain/gopt/reformat_manager.h" | |||
| #include "megbrain/gopt/subgraph_extractor.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/plugin/opr_footprint.h" | |||
| namespace mgb { | |||
| namespace gopt { | |||
| class Problem; | |||
| /*! | |||
| * \brief A profiler that collects all the performance data to describe the | |||
| * global layout transform problem. | |||
| */ | |||
| class ProfilerBase { | |||
| public: | |||
| using OprFormat = Problem::OprFormat; | |||
| struct OperatorNodeRecord { | |||
| const cg::OperatorNodeBase* opr; ///< pointer to operator node | |||
| ThinHashMap<OprFormat, float> | |||
| costs; ///< costs of operator node, i.e. the elapsed device | |||
| ///< time of the operator node on different opr format | |||
| ///< (layout configuration). | |||
| std::string to_string() const; | |||
| }; | |||
| struct VarNodeRecord { | |||
| struct KeyHash { | |||
| size_t operator()( | |||
| const std::pair<TensorFormats, TensorFormats>& val) const { | |||
| size_t h1 = | |||
| std::hash<uint32_t>()(static_cast<uint32_t>(val.first)); | |||
| size_t h2 = std::hash<uint32_t>()( | |||
| static_cast<uint32_t>(val.second)); | |||
| return mgb::hash_pair_combine(h1, h2); | |||
| } | |||
| }; | |||
| const VarNode* var; ///< pointer to var node | |||
| std::unordered_map<std::pair<TensorFormats, TensorFormats>, float, | |||
| KeyHash> | |||
| costs; ///< costs of var node, i.e. the elapsed | |||
| ///< device time of the layout transform. | |||
| ///< Key of the hashmap indicates the | |||
| ///< source tensor format and the target | |||
| ///< tensor format. | |||
| std::string to_string() const; | |||
| }; | |||
| /*! | |||
| * \note the profiler assumes all the input and output var node are stored | |||
| * in contiguous layout in memory | |||
| */ | |||
| struct ProfilingResult { | |||
| /// A hashmap, that maps the operator node to the costs (device elapsed | |||
| /// time) of different layouts configuration | |||
| ThinHashMap<cg::OperatorNodeBase*, OperatorNodeRecord> opr_record; | |||
| /// A hashmap, that maps the var node to the costs of layout transform | |||
| ThinHashMap<VarNode*, VarNodeRecord> var_record; | |||
| }; | |||
| using OprFilter = thin_function<bool(const cg::OperatorNodeBase*, | |||
| cg::OperatorNodeBase*)>; | |||
| using VarNodeFilter = | |||
| thin_function<bool(const VarNode*, TensorShape, TensorShape, | |||
| ReformatManager::ReformatKey)>; | |||
| ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f); | |||
| ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {}) | |||
| : m_opr_filter{std::move(opr_filter)}, | |||
| m_var_node_filter{std::move(var_node_filter)} {} | |||
| virtual ~ProfilerBase() = default; | |||
| virtual ProfilingResult profile(const Problem& problem) const = 0; | |||
| static std::unique_ptr<ProfilerBase> make_profiler(); | |||
| protected: | |||
| OprFilter m_opr_filter; | |||
| VarNodeFilter m_var_node_filter; | |||
| float m_opr_threshold; | |||
| float m_var_node_threshold; | |||
| private: | |||
| OprFootprint m_opr_footprint; | |||
| }; | |||
| } // namespace gopt | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * \file src/gopt/include/megbrain/gopt/solver.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/gopt/framework.h" | |||
| #include "megbrain/gopt/layout_transform_context.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/plugin/opr_footprint.h" | |||
| #include "megbrain/gopt/inference.h" | |||
| namespace mgb { | |||
| namespace gopt { | |||
| class ProfilerBase; | |||
| /*! | |||
| * \brief abstract solver | |||
| */ | |||
| class SolverBase { | |||
| public: | |||
| using OprFormat = Problem::OprFormat; | |||
| using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>; | |||
| SolverBase() = default; | |||
| virtual ~SolverBase() = default; | |||
| /*! | |||
| * \brief solve the given problem | |||
| */ | |||
| virtual Solution solve(const Problem& problem) const = 0; | |||
| /*! | |||
| * \brief check whether the given problem can be solved by the | |||
| * algorithm(i.e. solver). | |||
| */ | |||
| virtual bool can_solve(const Problem& problem) const = 0; | |||
| }; | |||
| /*! | |||
| * \brief solvers that will first collect the costs of operators in different op | |||
| * format and the costs of layout transform of varnode with a user provided | |||
| * profiler on the target device. This will lead to time consuming. | |||
| */ | |||
| class ProfilingBasedSolver : public SolverBase { | |||
| public: | |||
| using GraphPartitionFilter = | |||
| thin_function<bool(const GraphPartition& graph_partition)>; | |||
| ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler); | |||
| /*! | |||
| * \note some graph partition (for example, graph partition without format | |||
| * aware operators like conv, deconv, warp, resize etc.) will be filtered by | |||
| * the GraphPartitionFilter, which can reduce the profiling time. */ | |||
| ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler, | |||
| GraphPartitionFilter graph_partition_filter) | |||
| : m_profiler{std::move(profiler)}, | |||
| m_graph_partition_filter{std::move(graph_partition_filter)} {} | |||
| virtual ~ProfilingBasedSolver() = default; | |||
| Solution solve(const Problem& problem) const override; | |||
| virtual Solution do_solve(const Problem& problem) const = 0; | |||
| protected: | |||
| std::unique_ptr<ProfilerBase> m_profiler; | |||
| private: | |||
| GraphPartitionFilter m_graph_partition_filter; | |||
| }; | |||
| /*! | |||
| * \brief A solver that solves the layout selection problem using dynamic | |||
| * programming algorithm (Markov decision process). | |||
| */ | |||
| class DynamicProgrammingSolver final : public ProfilingBasedSolver { | |||
| public: | |||
| DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler) | |||
| : ProfilingBasedSolver(std::move(profiler)){}; | |||
| DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler, | |||
| GraphPartitionFilter graph_partition_filter) | |||
| : ProfilingBasedSolver(std::move(profiler), | |||
| std::move(graph_partition_filter)){}; | |||
| ~DynamicProgrammingSolver() noexcept = default; | |||
| Solution do_solve(const Problem& problem) const override; | |||
| bool can_solve(const Problem& problem) const override; | |||
| private: | |||
| class Impl; | |||
| }; | |||
| } // namespace gopt | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -10,10 +10,13 @@ | |||
| * implied. | |||
| */ | |||
| #include "megbrain/gopt/layout_transform_pass.h" | |||
| #include "./network.h" | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "megbrain/gopt/global_layout_transform.h" | |||
| #include "megbrain/gopt/inference.h" | |||
| #include "megbrain/gopt/layout_transform_context.h" | |||
| #include "megbrain/gopt/profiler.h" | |||
| #include "megbrain/gopt/solver.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| #include "megbrain/opr/nn_int.h" | |||
| @@ -12,7 +12,7 @@ | |||
| #include "megbrain/plugin/profiler.h" | |||
| #include "./helper.h" | |||
| #include "megbrain/gopt/global_layout_transform.h" | |||
| #include "megbrain/gopt/profiler.h" | |||
| #include "megbrain/gopt/inference.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/imgproc.h" | |||