GitOrigin-RevId: 4973820e02
tags/v1.6.0
| @@ -30,7 +30,10 @@ | |||||
| #include "megbrain/tensorrt/opr_replace.h" | #include "megbrain/tensorrt/opr_replace.h" | ||||
| #endif | #endif | ||||
| #include "megbrain/gopt/global_layout_transform.h" | |||||
| #include "megbrain/gopt/layout_transform_context.h" | |||||
| #include "megbrain/gopt/layout_transform_pass.h" | |||||
| #include "megbrain/gopt/profiler.h" | |||||
| #include "megbrain/gopt/solver.h" | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace gopt; | using namespace gopt; | ||||
| @@ -12,7 +12,9 @@ | |||||
| #include <queue> | #include <queue> | ||||
| #include "./utils.h" | #include "./utils.h" | ||||
| #include "megbrain/gopt/global_layout_transform.h" | |||||
| #include "megbrain/gopt/layout_transform_context.h" | |||||
| #include "megbrain/gopt/profiler.h" | |||||
| #include "megbrain/gopt/solver.h" | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace gopt; | using namespace gopt; | ||||
| @@ -85,11 +87,11 @@ private: | |||||
| const SmallVector<TensorFormats>& available_tensor_formats; | const SmallVector<TensorFormats>& available_tensor_formats; | ||||
| }; | }; | ||||
| /*! | /*! | ||||
| * \brief get the tensor formats configuration for the operator with particular op format | |||||
| * \param[out] var2fmts hashmap that maps varnode to actual tensor formats of the op format configuration | |||||
| * \param[in] opr given operator | |||||
| * \param[in] opr_fmt given op format, an enum type argument which indicates the op format configuration. | |||||
| * \param[in] ctx context | |||||
| * \brief get the tensor formats configuration for the operator with | |||||
| * particular op format \param[out] var2fmts hashmap that maps varnode to | |||||
| * actual tensor formats of the op format configuration \param[in] opr given | |||||
| * operator \param[in] opr_fmt given op format, an enum type argument which | |||||
| * indicates the op format configuration. \param[in] ctx context | |||||
| */ | */ | ||||
| TensorFormats get_io_formats(ThinHashMap<VarNode*, TensorFormats>& var2fmts, | TensorFormats get_io_formats(ThinHashMap<VarNode*, TensorFormats>& var2fmts, | ||||
| const OperatorNodeBase* opr, OprFormat opr_fmt, | const OperatorNodeBase* opr, OprFormat opr_fmt, | ||||
| @@ -11,7 +11,7 @@ | |||||
| */ | */ | ||||
| #include "./utils.h" | #include "./utils.h" | ||||
| #include "megbrain/gopt/global_layout_transform.h" | |||||
| #include "megbrain/gopt/layout_transform_context.h" | |||||
| #include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
| #include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
| #include "megbrain/opr/nn_int.h" | #include "megbrain/opr/nn_int.h" | ||||
| @@ -10,9 +10,11 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "megbrain/gopt/layout_transform_pass.h" | |||||
| #include "./opr_format_modifier.h" | #include "./opr_format_modifier.h" | ||||
| #include "./utils.h" | #include "./utils.h" | ||||
| #include "megbrain/gopt/global_layout_transform.h" | |||||
| #include "megbrain/gopt/profiler.h" | |||||
| #include "megbrain/gopt/solver.h" | |||||
| #include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
| #include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
| #include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
| @@ -46,8 +48,7 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
| auto&& opr_configs = m_ctx->opr_configs(); | auto&& opr_configs = m_ctx->opr_configs(); | ||||
| auto&& base_fmt = m_ctx->attribute().base_tensor_formats; | auto&& base_fmt = m_ctx->attribute().base_tensor_formats; | ||||
| auto&& reformat_attribute = | |||||
| ReformatManager::ReformatKey::Attribute::DEFAULT; | |||||
| auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; | |||||
| ThinHashMap<VarNode*, TensorFormats> var2fmts; | ThinHashMap<VarNode*, TensorFormats> var2fmts; | ||||
| static ThinHashSet<Typeinfo*> format_aware_oprs = { | static ThinHashSet<Typeinfo*> format_aware_oprs = { | ||||
| #define cb(_Opr) opr::_Opr::typeinfo(), | #define cb(_Opr) opr::_Opr::typeinfo(), | ||||
| @@ -55,8 +56,8 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
| #undef cb | #undef cb | ||||
| }; | }; | ||||
| auto rewriter = opt.graph().make_rewriter(); | auto rewriter = opt.graph().make_rewriter(); | ||||
| auto on_opr = [this, &opr_configs, &base_fmt, &reformat_attribute, | |||||
| &rewriter, &solution, &var2fmts, | |||||
| auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter, | |||||
| &solution, &var2fmts, | |||||
| &endpoint_vars](OperatorNodeBase* opr) { | &endpoint_vars](OperatorNodeBase* opr) { | ||||
| auto it = solution.find(opr); | auto it = solution.find(opr); | ||||
| if (it != solution.end()) { | if (it != solution.end()) { | ||||
| @@ -122,19 +123,6 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
| opr->config()) | opr->config()) | ||||
| ->output(0); | ->output(0); | ||||
| } | } | ||||
| if (endpoint_vars.count(opr->output(0)) && out_fmt != base_fmt) { | |||||
| ReformatManager::ReformatKey key{ | |||||
| out_fmt, base_fmt, reformat_attribute, | |||||
| opr->output(0)->dtype().enumv(), | |||||
| opr->output(0)->dtype().enumv()}; | |||||
| auto reformat = ReformatManager::instance() | |||||
| .auto_aligned_reformat_featrue( | |||||
| opr->output(0), base_fmt, key); | |||||
| new_out = reformat({new_out}); | |||||
| var2fmts[new_out] = base_fmt; | |||||
| } else { | |||||
| var2fmts[new_out] = out_fmt; | |||||
| } | |||||
| auto &&out0 = opr->output(), | auto &&out0 = opr->output(), | ||||
| &&out1 = new_out->owner_opr()->output(); | &&out1 = new_out->owner_opr()->output(); | ||||
| mgb_assert(opr->usable_output().size() == | mgb_assert(opr->usable_output().size() == | ||||
| @@ -146,20 +134,29 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
| new_out->owner_opr()->cname(), | new_out->owner_opr()->cname(), | ||||
| new_out->owner_opr()->dyn_typeinfo()->name, out0.size(), | new_out->owner_opr()->dyn_typeinfo()->name, out0.size(), | ||||
| out1.size()); | out1.size()); | ||||
| for (size_t i = 0; i < out0.size(); ++i) { | |||||
| if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | |||||
| mgb_assert(!out1[i]->contain_flag( | |||||
| VarNode::Flag::VOLATILE_CONTENT)); | |||||
| auto src = out0[i]; | |||||
| auto dst = out1[i]; | |||||
| rewriter.replace_var( | |||||
| src, dst, | |||||
| mgb_cstr_log(ssprintf("replace opr(%s) to new opr " | |||||
| "format(%s)", | |||||
| opr->cname(), | |||||
| opr_format_to_string(opr_fmt)) | |||||
| .c_str())); | |||||
| size_t nr_outs = opr->usable_output().size(); | |||||
| for (size_t i = 0; i < nr_outs; ++i) { | |||||
| const auto& ovar = out0[i]; | |||||
| auto new_ovar = out1[i]; | |||||
| if (endpoint_vars.count(ovar) && out_fmt != base_fmt) { | |||||
| ReformatManager::ReformatKey key{ | |||||
| out_fmt, base_fmt, reformat_attribute, | |||||
| ovar->dtype().enumv(), ovar->dtype().enumv()}; | |||||
| auto reformat = ReformatManager::instance() | |||||
| .auto_aligned_reformat_featrue( | |||||
| ovar, base_fmt, key); | |||||
| new_ovar = reformat({new_ovar}); | |||||
| var2fmts[new_ovar] = base_fmt; | |||||
| } else { | |||||
| var2fmts[new_ovar] = out_fmt; | |||||
| } | } | ||||
| rewriter.replace_var( | |||||
| ovar, new_ovar, | |||||
| mgb_cstr_log(ssprintf("replace opr(%s) to new opr " | |||||
| "format(%s)", | |||||
| opr->cname(), | |||||
| opr_format_to_string(opr_fmt)) | |||||
| .c_str())); | |||||
| } | } | ||||
| } else { | } else { | ||||
| auto new_opr = rewriter.auto_replace_outputs(opr); | auto new_opr = rewriter.auto_replace_outputs(opr); | ||||
| @@ -11,7 +11,7 @@ | |||||
| */ | */ | ||||
| #include "./utils.h" | #include "./utils.h" | ||||
| #include "megbrain/gopt/global_layout_transform.h" | |||||
| #include "megbrain/gopt/layout_transform_context.h" | |||||
| #include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
| #include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
| @@ -13,7 +13,7 @@ | |||||
| #include "./opr_format_modifier.h" | #include "./opr_format_modifier.h" | ||||
| #include "./utils.h" | #include "./utils.h" | ||||
| #include "megbrain/gopt/framework.h" | #include "megbrain/gopt/framework.h" | ||||
| #include "megbrain/gopt/global_layout_transform.h" | |||||
| #include "megbrain/gopt/profiler.h" | |||||
| #include "megbrain/graph/event.h" | #include "megbrain/graph/event.h" | ||||
| #include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
| #include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
| @@ -10,7 +10,8 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "megbrain/gopt/global_layout_transform.h" | |||||
| #include "megbrain/gopt/profiler.h" | |||||
| #include "megbrain/gopt/solver.h" | |||||
| #include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
| #include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
| @@ -11,7 +11,7 @@ | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "megbrain/gopt/global_layout_transform.h" | |||||
| #include "megbrain/gopt/layout_transform_context.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace gopt { | namespace gopt { | ||||
| @@ -1,5 +1,6 @@ | |||||
| /** | /** | ||||
| * \file src/gopt/include/megbrain/gopt/global_layout_transformation.h | |||||
| * \file | |||||
| * src/gopt/include/megbrain/gopt/layout_transform_context.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| * | * | ||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
| @@ -12,11 +13,10 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megbrain/gopt/framework.h" | #include "megbrain/gopt/framework.h" | ||||
| #include "megbrain/gopt/inference.h" | |||||
| #include "megbrain/gopt/reformat_manager.h" | #include "megbrain/gopt/reformat_manager.h" | ||||
| #include "megbrain/gopt/subgraph_extractor.h" | #include "megbrain/gopt/subgraph_extractor.h" | ||||
| #include "megbrain/opr/dnn/convolution.h" | |||||
| #include "megbrain/plugin/opr_footprint.h" | #include "megbrain/plugin/opr_footprint.h" | ||||
| #include "megbrain/gopt/inference.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace gopt { | namespace gopt { | ||||
| @@ -118,7 +118,7 @@ public: | |||||
| TensorFormats base_tensor_format = TensorFormats::NCHW); | TensorFormats base_tensor_format = TensorFormats::NCHW); | ||||
| private: | private: | ||||
| OprList m_opr_list; /// supported operator list | |||||
| OprList m_opr_list; /// supported operator list | |||||
| SmallVector<TensorFormats> | SmallVector<TensorFormats> | ||||
| m_available_tensor_formats; /// the available tensor formats, used | m_available_tensor_formats; /// the available tensor formats, used | ||||
| /// for format agnostic operators (like | /// for format agnostic operators (like | ||||
| @@ -180,164 +180,6 @@ private: | |||||
| const GraphPartition& m_graph_partition; /// the graph partition | const GraphPartition& m_graph_partition; /// the graph partition | ||||
| const LayoutTransformContext& m_ctx; | const LayoutTransformContext& m_ctx; | ||||
| }; | }; | ||||
| /*! | |||||
| * \brief A profiler that collects all the performance data to describe the | |||||
| * global layout transform problem. | |||||
| */ | |||||
| class ProfilerBase { | |||||
| public: | |||||
| using OprFormat = Problem::OprFormat; | |||||
| struct OperatorNodeRecord { | |||||
| const cg::OperatorNodeBase* opr; ///< pointer to operator node | |||||
| ThinHashMap<OprFormat, float> | |||||
| costs; ///< costs of operator node, i.e. the elapsed device | |||||
| ///< time of the operator node on different opr format | |||||
| ///< (layout configuration). | |||||
| std::string to_string() const; | |||||
| }; | |||||
| struct VarNodeRecord { | |||||
| struct KeyHash { | |||||
| size_t operator()( | |||||
| const std::pair<TensorFormats, TensorFormats>& val) const { | |||||
| size_t h1 = | |||||
| std::hash<uint32_t>()(static_cast<uint32_t>(val.first)); | |||||
| size_t h2 = std::hash<uint32_t>()( | |||||
| static_cast<uint32_t>(val.second)); | |||||
| return mgb::hash_pair_combine(h1, h2); | |||||
| } | |||||
| }; | |||||
| const VarNode* var; ///< pointer to var node | |||||
| std::unordered_map<std::pair<TensorFormats, TensorFormats>, float, | |||||
| KeyHash> | |||||
| costs; ///< costs of var node, i.e. the elapsed | |||||
| ///< device time of the layout transform. | |||||
| ///< Key of the hashmap indicates the | |||||
| ///< source tensor format and the target | |||||
| ///< tensor format. | |||||
| std::string to_string() const; | |||||
| }; | |||||
| /*! | |||||
| * \note the profiler assumes all the input and output var node are stored | |||||
| * in contiguous layout in memory | |||||
| */ | |||||
| struct ProfilingResult { | |||||
| /// A hashmap, that maps the operator node to the costs (device elapsed | |||||
| /// time) of different layouts configuration | |||||
| ThinHashMap<cg::OperatorNodeBase*, OperatorNodeRecord> opr_record; | |||||
| /// A hashmap, that maps the var node to the costs of layout transform | |||||
| ThinHashMap<VarNode*, VarNodeRecord> var_record; | |||||
| }; | |||||
| using OprFilter = thin_function<bool(const cg::OperatorNodeBase*, | |||||
| cg::OperatorNodeBase*)>; | |||||
| using VarNodeFilter = | |||||
| thin_function<bool(const VarNode*, TensorShape, TensorShape, | |||||
| ReformatManager::ReformatKey)>; | |||||
| ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f); | |||||
| ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {}) | |||||
| : m_opr_filter{std::move(opr_filter)}, | |||||
| m_var_node_filter{std::move(var_node_filter)} {} | |||||
| virtual ~ProfilerBase() = default; | |||||
| virtual ProfilingResult profile(const Problem& problem) const = 0; | |||||
| static std::unique_ptr<ProfilerBase> make_profiler(); | |||||
| protected: | |||||
| OprFilter m_opr_filter; | |||||
| VarNodeFilter m_var_node_filter; | |||||
| float m_opr_threshold; | |||||
| float m_var_node_threshold; | |||||
| private: | |||||
| OprFootprint m_opr_footprint; | |||||
| }; | |||||
| /*! | |||||
| * \brief abstract solver | |||||
| */ | |||||
| class SolverBase { | |||||
| public: | |||||
| using OprFormat = Problem::OprFormat; | |||||
| using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>; | |||||
| SolverBase() = default; | |||||
| virtual ~SolverBase() = default; | |||||
| /*! | |||||
| * \brief solve the given problem | |||||
| */ | |||||
| virtual Solution solve(const Problem& problem) const = 0; | |||||
| /*! | |||||
| * \brief check whether the given problem can be solved by the | |||||
| * algorithm(i.e. solver). | |||||
| */ | |||||
| virtual bool can_solve(const Problem& problem) const = 0; | |||||
| }; | |||||
| /*! | |||||
| * \brief solvers that will first collect the costs of operators in different op | |||||
| * format and the costs of layout transform of varnode with a user provided | |||||
| * profiler on the target device. This will lead to time consuming. | |||||
| */ | |||||
| class ProfilingBasedSolver : public SolverBase { | |||||
| public: | |||||
| using GraphPartitionFilter = | |||||
| thin_function<bool(const GraphPartition& graph_partition)>; | |||||
| ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler); | |||||
| /*! | |||||
| * \note some graph partition (for example, graph partition without format | |||||
| * aware operators like conv, deconv, warp, resize etc.) will be filtered by | |||||
| * the GraphPartitionFilter, which can reduce the profiling time. */ | |||||
| ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler, | |||||
| GraphPartitionFilter graph_partition_filter) | |||||
| : m_profiler{std::move(profiler)}, | |||||
| m_graph_partition_filter{std::move(graph_partition_filter)} {} | |||||
| virtual ~ProfilingBasedSolver() = default; | |||||
| Solution solve(const Problem& problem) const override; | |||||
| virtual Solution do_solve(const Problem& problem) const = 0; | |||||
| protected: | |||||
| std::unique_ptr<ProfilerBase> m_profiler; | |||||
| private: | |||||
| GraphPartitionFilter m_graph_partition_filter; | |||||
| }; | |||||
| /*! | |||||
| * \brief A solver that solves the layout selection problem using dynamic | |||||
| * programming algorithm (Markov decision process). | |||||
| */ | |||||
| class DynamicProgrammingSolver final : public ProfilingBasedSolver { | |||||
| public: | |||||
| DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler) | |||||
| : ProfilingBasedSolver(std::move(profiler)){}; | |||||
| DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler, | |||||
| GraphPartitionFilter graph_partition_filter) | |||||
| : ProfilingBasedSolver(std::move(profiler), | |||||
| std::move(graph_partition_filter)){}; | |||||
| ~DynamicProgrammingSolver() noexcept = default; | |||||
| Solution do_solve(const Problem& problem) const override; | |||||
| bool can_solve(const Problem& problem) const override; | |||||
| private: | |||||
| class Impl; | |||||
| }; | |||||
| /*! | |||||
| * \brief A layout transform pass, which convert the operator's format to the | |||||
| * optimal format using the results of the solver. | |||||
| */ | |||||
| class LayoutTransformPass final : public Pass { | |||||
| public: | |||||
| const char* name() const override { return "layout assignment pass"; } | |||||
| void apply(OptState& opt) const override; | |||||
| LayoutTransformPass(std::unique_ptr<LayoutTransformContext> ctx, | |||||
| std::unique_ptr<SolverBase> solver) | |||||
| : m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {} | |||||
| private: | |||||
| std::unique_ptr<LayoutTransformContext> m_ctx; | |||||
| std::unique_ptr<SolverBase> m_solver; | |||||
| }; | |||||
| } // namespace gopt | } // namespace gopt | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * \file src/gopt/include/megbrain/gopt/global_layout_transformation.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megbrain/gopt/framework.h" | |||||
| namespace mgb { | |||||
| namespace gopt { | |||||
| class LayoutTransformContext; | |||||
| class SolverBase; | |||||
| /*! | |||||
| * \brief A layout transform pass, which convert the operator's format to the | |||||
| * optimal format using the results of the solver. | |||||
| */ | |||||
| class LayoutTransformPass final : public Pass { | |||||
| public: | |||||
| const char* name() const override { return "layout assignment pass"; } | |||||
| void apply(OptState& opt) const override; | |||||
| LayoutTransformPass(std::unique_ptr<LayoutTransformContext> ctx, | |||||
| std::unique_ptr<SolverBase> solver) | |||||
| : m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {} | |||||
| private: | |||||
| std::unique_ptr<LayoutTransformContext> m_ctx; | |||||
| std::unique_ptr<SolverBase> m_solver; | |||||
| }; | |||||
| } // namespace gopt | |||||
| } // namespace mgb | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -0,0 +1,101 @@ | |||||
| /** | |||||
| * \file src/gopt/include/megbrain/gopt/global_layout_transformation.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megbrain/gopt/framework.h" | |||||
| #include "megbrain/gopt/inference.h" | |||||
| #include "megbrain/gopt/layout_transform_context.h" | |||||
| #include "megbrain/gopt/reformat_manager.h" | |||||
| #include "megbrain/gopt/subgraph_extractor.h" | |||||
| #include "megbrain/opr/dnn/convolution.h" | |||||
| #include "megbrain/plugin/opr_footprint.h" | |||||
| namespace mgb { | |||||
| namespace gopt { | |||||
| class Problem; | |||||
| /*! | |||||
| * \brief A profiler that collects all the performance data to describe the | |||||
| * global layout transform problem. | |||||
| */ | |||||
| class ProfilerBase { | |||||
| public: | |||||
| using OprFormat = Problem::OprFormat; | |||||
| struct OperatorNodeRecord { | |||||
| const cg::OperatorNodeBase* opr; ///< pointer to operator node | |||||
| ThinHashMap<OprFormat, float> | |||||
| costs; ///< costs of operator node, i.e. the elapsed device | |||||
| ///< time of the operator node on different opr format | |||||
| ///< (layout configuration). | |||||
| std::string to_string() const; | |||||
| }; | |||||
| struct VarNodeRecord { | |||||
| struct KeyHash { | |||||
| size_t operator()( | |||||
| const std::pair<TensorFormats, TensorFormats>& val) const { | |||||
| size_t h1 = | |||||
| std::hash<uint32_t>()(static_cast<uint32_t>(val.first)); | |||||
| size_t h2 = std::hash<uint32_t>()( | |||||
| static_cast<uint32_t>(val.second)); | |||||
| return mgb::hash_pair_combine(h1, h2); | |||||
| } | |||||
| }; | |||||
| const VarNode* var; ///< pointer to var node | |||||
| std::unordered_map<std::pair<TensorFormats, TensorFormats>, float, | |||||
| KeyHash> | |||||
| costs; ///< costs of var node, i.e. the elapsed | |||||
| ///< device time of the layout transform. | |||||
| ///< Key of the hashmap indicates the | |||||
| ///< source tensor format and the target | |||||
| ///< tensor format. | |||||
| std::string to_string() const; | |||||
| }; | |||||
| /*! | |||||
| * \note the profiler assumes all the input and output var node are stored | |||||
| * in contiguous layout in memory | |||||
| */ | |||||
| struct ProfilingResult { | |||||
| /// A hashmap, that maps the operator node to the costs (device elapsed | |||||
| /// time) of different layouts configuration | |||||
| ThinHashMap<cg::OperatorNodeBase*, OperatorNodeRecord> opr_record; | |||||
| /// A hashmap, that maps the var node to the costs of layout transform | |||||
| ThinHashMap<VarNode*, VarNodeRecord> var_record; | |||||
| }; | |||||
| using OprFilter = thin_function<bool(const cg::OperatorNodeBase*, | |||||
| cg::OperatorNodeBase*)>; | |||||
| using VarNodeFilter = | |||||
| thin_function<bool(const VarNode*, TensorShape, TensorShape, | |||||
| ReformatManager::ReformatKey)>; | |||||
| ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f); | |||||
| ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {}) | |||||
| : m_opr_filter{std::move(opr_filter)}, | |||||
| m_var_node_filter{std::move(var_node_filter)} {} | |||||
| virtual ~ProfilerBase() = default; | |||||
| virtual ProfilingResult profile(const Problem& problem) const = 0; | |||||
| static std::unique_ptr<ProfilerBase> make_profiler(); | |||||
| protected: | |||||
| OprFilter m_opr_filter; | |||||
| VarNodeFilter m_var_node_filter; | |||||
| float m_opr_threshold; | |||||
| float m_var_node_threshold; | |||||
| private: | |||||
| OprFootprint m_opr_footprint; | |||||
| }; | |||||
| } // namespace gopt | |||||
| } // namespace mgb | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -0,0 +1,97 @@ | |||||
| /** | |||||
| * \file src/gopt/include/megbrain/gopt/solver.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megbrain/gopt/framework.h" | |||||
| #include "megbrain/gopt/layout_transform_context.h" | |||||
| #include "megbrain/opr/dnn/convolution.h" | |||||
| #include "megbrain/plugin/opr_footprint.h" | |||||
| #include "megbrain/gopt/inference.h" | |||||
| namespace mgb { | |||||
| namespace gopt { | |||||
| class ProfilerBase; | |||||
| /*! | |||||
| * \brief abstract solver | |||||
| */ | |||||
| class SolverBase { | |||||
| public: | |||||
| using OprFormat = Problem::OprFormat; | |||||
| using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>; | |||||
| SolverBase() = default; | |||||
| virtual ~SolverBase() = default; | |||||
| /*! | |||||
| * \brief solve the given problem | |||||
| */ | |||||
| virtual Solution solve(const Problem& problem) const = 0; | |||||
| /*! | |||||
| * \brief check whether the given problem can be solved by the | |||||
| * algorithm(i.e. solver). | |||||
| */ | |||||
| virtual bool can_solve(const Problem& problem) const = 0; | |||||
| }; | |||||
| /*! | |||||
| * \brief solvers that will first collect the costs of operators in different op | |||||
| * format and the costs of layout transform of varnode with a user provided | |||||
| * profiler on the target device. This will lead to time consuming. | |||||
| */ | |||||
| class ProfilingBasedSolver : public SolverBase { | |||||
| public: | |||||
| using GraphPartitionFilter = | |||||
| thin_function<bool(const GraphPartition& graph_partition)>; | |||||
| ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler); | |||||
| /*! | |||||
| * \note some graph partition (for example, graph partition without format | |||||
| * aware operators like conv, deconv, warp, resize etc.) will be filtered by | |||||
| * the GraphPartitionFilter, which can reduce the profiling time. */ | |||||
| ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler, | |||||
| GraphPartitionFilter graph_partition_filter) | |||||
| : m_profiler{std::move(profiler)}, | |||||
| m_graph_partition_filter{std::move(graph_partition_filter)} {} | |||||
| virtual ~ProfilingBasedSolver() = default; | |||||
| Solution solve(const Problem& problem) const override; | |||||
| virtual Solution do_solve(const Problem& problem) const = 0; | |||||
| protected: | |||||
| std::unique_ptr<ProfilerBase> m_profiler; | |||||
| private: | |||||
| GraphPartitionFilter m_graph_partition_filter; | |||||
| }; | |||||
| /*! | |||||
| * \brief A solver that solves the layout selection problem using dynamic | |||||
| * programming algorithm (Markov decision process). | |||||
| */ | |||||
| class DynamicProgrammingSolver final : public ProfilingBasedSolver { | |||||
| public: | |||||
| DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler) | |||||
| : ProfilingBasedSolver(std::move(profiler)){}; | |||||
| DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler, | |||||
| GraphPartitionFilter graph_partition_filter) | |||||
| : ProfilingBasedSolver(std::move(profiler), | |||||
| std::move(graph_partition_filter)){}; | |||||
| ~DynamicProgrammingSolver() noexcept = default; | |||||
| Solution do_solve(const Problem& problem) const override; | |||||
| bool can_solve(const Problem& problem) const override; | |||||
| private: | |||||
| class Impl; | |||||
| }; | |||||
| } // namespace gopt | |||||
| } // namespace mgb | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -10,10 +10,13 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "megbrain/gopt/layout_transform_pass.h" | |||||
| #include "./network.h" | #include "./network.h" | ||||
| #include "megbrain/comp_node_env.h" | #include "megbrain/comp_node_env.h" | ||||
| #include "megbrain/gopt/global_layout_transform.h" | |||||
| #include "megbrain/gopt/inference.h" | #include "megbrain/gopt/inference.h" | ||||
| #include "megbrain/gopt/layout_transform_context.h" | |||||
| #include "megbrain/gopt/profiler.h" | |||||
| #include "megbrain/gopt/solver.h" | |||||
| #include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
| #include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
| #include "megbrain/opr/nn_int.h" | #include "megbrain/opr/nn_int.h" | ||||
| @@ -12,7 +12,7 @@ | |||||
| #include "megbrain/plugin/profiler.h" | #include "megbrain/plugin/profiler.h" | ||||
| #include "./helper.h" | #include "./helper.h" | ||||
| #include "megbrain/gopt/global_layout_transform.h" | |||||
| #include "megbrain/gopt/profiler.h" | |||||
| #include "megbrain/gopt/inference.h" | #include "megbrain/gopt/inference.h" | ||||
| #include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
| #include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||