| @@ -0,0 +1,161 @@ | |||
| /** | |||
| * \file src/core/impl/graph/seq_modifier_base.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./seq_modifier_base.h" | |||
| #if MGB_ENABLE_SUBLINEAR | |||
| using namespace mgb; | |||
| using namespace cg; | |||
| void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_seq) { | |||
| m_orig_opr_seq = &opr_seq; | |||
| m_var_storage.clear(); | |||
| m_seq.clear(); | |||
| m_var_mempool.reorder_free(); | |||
| m_opr_mempool.reorder_free(); | |||
| m_nr_endpoint_oprs = 0; | |||
| ThinHashMap<VarNode*, Var*> varmap; | |||
| for (auto orig_opr : *m_orig_opr_seq) { | |||
| auto time = m_seq.size(); | |||
| m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time)); | |||
| auto opr = m_seq.back().get(); | |||
| m_nr_endpoint_oprs += opr->is_endpoint; | |||
| for (auto&& dep : orig_opr->node_prop().dep_map()) { | |||
| if (!OperatorNodeBase::NodeProp::is_device_value_dep(dep.second)) | |||
| continue; | |||
| auto iter = varmap.find(dep.first); | |||
| if (iter == varmap.end()) { | |||
| // input var needs not to be considered | |||
| continue; | |||
| } | |||
| auto ivar = iter->second; | |||
| bool exist = false; | |||
| for (auto i : opr->input) { | |||
| if (i == ivar) { | |||
| exist = true; | |||
| break; | |||
| } | |||
| } | |||
| if (exist) { | |||
| // same var for different inputs | |||
| continue; | |||
| } | |||
| opr->input.push_back(ivar); | |||
| auto&& prev_rec = ivar->access_rec.back(); | |||
| prev_rec.stride = time - prev_rec.opr->time; | |||
| ivar->access_rec.emplace_back(opr); | |||
| } | |||
| for (auto i : orig_opr->output()) { | |||
| auto var2memsize = m_par_modifier->m_mem_opt.var2memsize(); | |||
| auto iter = var2memsize->find(i); | |||
| if (iter == var2memsize->end()) { | |||
| // some vars are ignored; see split_into_cn2oprseq() | |||
| continue; | |||
| } | |||
| m_var_storage.emplace_back( | |||
| m_var_mempool.alloc_unique(i, iter->second, opr)); | |||
| auto ovar = m_var_storage.back().get(); | |||
| varmap[i] = ovar; | |||
| opr->output.push_back(ovar); | |||
| } | |||
| mgb_assert(!opr->output.empty()); | |||
| } | |||
| // remove unused output | |||
| for (auto&& i : m_seq) { | |||
| auto&& oarr = i->output; | |||
| for (size_t j = 0; j < oarr.size();) { | |||
| if (oarr[j]->access_rec.size() == 1) { | |||
| std::swap(oarr[j], oarr.back()); | |||
| oarr.pop_back(); | |||
| } else | |||
| ++j; | |||
| } | |||
| } | |||
| } | |||
| bool SeqModifierBase::replace_vars(const VarNodeArray& inputs) { | |||
| m_new_inputs.assign(inputs.begin(), inputs.end()); | |||
| bool changed = false; | |||
| for (auto&& i : m_new_inputs) { | |||
| auto iter = m_var_map.find(i); | |||
| if (iter != m_var_map.end()) { | |||
| i = iter->second; | |||
| changed = true; | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| OperatorNodeBase* SeqModifierBase::copy_opr_from_new_inputs( | |||
| OperatorNodeBase* opr, bool recomp, size_t recomp_cnt) { | |||
| auto config = opr->config(); | |||
| // update operator instance id to bybass the shallow copy's cache if | |||
| // it's a dup-opr-copying due to discarding. | |||
| // Don't update instance id by `this` pointer if it's a recomp-opr-copying | |||
| // because: | |||
| // 0) recomp-opr would be copied iff its input vars is changed | |||
| // 1) some pair of recomp-opr and dup-opr have the same inputs, params | |||
| // and config, we use instance id to differentiate them. | |||
| config.name(opr->name() + (recomp ? ":recomp" : ":dup") + std::to_string(recomp_cnt)); | |||
| config.update_instance_id(reinterpret_cast<void*>( | |||
| reinterpret_cast<size_t>(this) + | |||
| ((static_cast<size_t>(recomp) + 1) << 10) * recomp_cnt)); | |||
| // Note: if all outputs of op were placed on the same comp_node, since its | |||
| // stream maybe changed during seq_comp_node_opt, output's comp_node has | |||
| // higher priority than opr->config() | |||
| auto out_cn = opr->output(0)->comp_node(); | |||
| for (auto i : opr->output()) { | |||
| auto cn = i->comp_node(); | |||
| if (out_cn != cn) { | |||
| out_cn = {}; | |||
| break; | |||
| } | |||
| } | |||
| if (out_cn.valid()) | |||
| config.comp_node(out_cn); | |||
| auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config); | |||
| mgb_assert(opr_new != opr); | |||
| auto&& out0 = opr->output(); | |||
| auto&& out1 = opr_new->output(); | |||
| mgb_assert(out0.size() == out1.size()); | |||
| bool stream_changed = false; | |||
| for (size_t i = 0; i < out0.size(); ++i) { | |||
| auto &&cn0 = out0[i]->comp_node(), | |||
| &&cn1 = out1[i]->comp_node(); | |||
| if (cn0 != cn1) { | |||
| mgb_assert(recomp); | |||
| mgb_assert(cn0.locator().type == cn1.locator().type && | |||
| cn0.locator().device == cn1.locator().device); | |||
| out1[i]->comp_node(cn0); | |||
| stream_changed = true; | |||
| } | |||
| m_var_map[out0[i]] = out1[i]; | |||
| } | |||
| if (stream_changed) { | |||
| opr_new->on_output_comp_node_stream_changed(); | |||
| } | |||
| return opr_new; | |||
| } | |||
| #endif // MGB_ENABLE_SUBLINEAR | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,237 @@ | |||
| /** | |||
| * \file src/core/impl/graph/seq_modifier_base.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "./memory_optimizer.h" | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "megbrain/graph/cg.h" | |||
| #include "megbrain/plugin/opr_footprint.h" | |||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||
| #include "megbrain/system.h" | |||
| #include "megbrain/utils/async_worker.h" | |||
| #include "megbrain/utils/arith_helper.h" | |||
| #include "megbrain/utils/mempool.h" | |||
| #include "megbrain/utils/timer.h" | |||
| #if MGB_ENABLE_SUBLINEAR | |||
| namespace mgb { | |||
| namespace cg { | |||
| /*! | |||
| * \brief modifying computing sequence, with basically the same idea of Training | |||
| * Deep Nets with Sublinear Memory Cost and Dynamic Tensor Rematerialization | |||
| */ | |||
| class SeqModifierBase { | |||
| public: | |||
| /*! | |||
| * describes modifications that should be applied to an operator sequnce: | |||
| * maps from an opr to the oprs that should be duplicated and inserted | |||
| * before it. | |||
| */ | |||
| using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>; | |||
| struct Var; | |||
| struct Opr; | |||
| class ModifyActionPlannerBase { | |||
| const SeqModifierBase* const m_par_modifier; | |||
| const OprNodeArray* m_orig_opr_seq; | |||
| MemPool<Var> m_var_mempool; | |||
| MemPool<Opr> m_opr_mempool; | |||
| std::vector<MemPool<Var>::UniquePtr> m_var_storage; | |||
| std::vector<MemPool<Opr>::UniquePtr> m_seq; | |||
| size_t m_nr_endpoint_oprs = 0; | |||
| public: | |||
| //! special creation time used for oprs duplicated from others | |||
| static constexpr size_t DUPOPR_TIME = | |||
| std::numeric_limits<size_t>::max() - 1; | |||
| const SeqModifierBase* const par_modifier() { | |||
| return m_par_modifier; | |||
| } | |||
| const OprNodeArray* const orig_opr_seq() { | |||
| return m_orig_opr_seq; | |||
| } | |||
| MemPool<Var>& var_mempool() { | |||
| return m_var_mempool; | |||
| } | |||
| MemPool<Opr>& opr_mempool() { | |||
| return m_opr_mempool; | |||
| } | |||
| std::vector<MemPool<Var>::UniquePtr>& var_storage() { | |||
| return m_var_storage; | |||
| } | |||
| std::vector<MemPool<Opr>::UniquePtr>& seq() { | |||
| return m_seq; | |||
| } | |||
| size_t& nr_endpoint_oprs() { | |||
| return m_nr_endpoint_oprs; | |||
| } | |||
| ModifyActionPlannerBase(SeqModifierBase* par) | |||
| : m_par_modifier{par} {} | |||
| ~ModifyActionPlannerBase() noexcept { | |||
| m_opr_mempool.disable_freelist(); | |||
| m_var_mempool.disable_freelist(); | |||
| } | |||
| //! init m_orig_opr_seq from opr_seq, should be called first. | |||
| void init_seq(const OprNodeArray& opr_seq); | |||
| }; | |||
| SeqModifierBase(ComputingGraphImpl* owner) : m_mem_opt(owner), m_owner_graph(owner) {} | |||
| MemoryOptimizerHelper& mem_opt() { | |||
| return m_mem_opt; | |||
| } | |||
| ComputingGraphImpl* const owner_graph() { | |||
| return m_owner_graph; | |||
| } | |||
| ThinHashMap<VarNode*, VarNode*>& var_map() { | |||
| return m_var_map; | |||
| } | |||
| /*! | |||
| * \brief copy opr and set inputs to m_new_inputs, and add outputs in | |||
| * m_var_map | |||
| * \return new operator | |||
| */ | |||
| OperatorNodeBase* copy_opr_from_new_inputs(OperatorNodeBase* opr, bool recomp, size_t recomp_cnt=0); | |||
| /*! | |||
| * \brief replace input vars according to m_var_map, and store results in | |||
| * m_new_inputs; | |||
| * \return whether any var is changed | |||
| */ | |||
| bool replace_vars(const VarNodeArray& inputs); | |||
| //! see memory_optimizer set_priority_before_opt | |||
| void set_priority_before_opt(const VarNodeArray& endpoints) { | |||
| m_mem_opt.set_priority_before_opt(endpoints); | |||
| } | |||
| //! see memory_optimizer restore_graph_option | |||
| void restore_graph_option() { | |||
| m_mem_opt.restore_graph_option(); | |||
| } | |||
| private: | |||
| MemoryOptimizerHelper m_mem_opt; | |||
| ComputingGraphImpl* const m_owner_graph = nullptr; | |||
| //! map from original var to replaced var | |||
| ThinHashMap<VarNode*, VarNode*> m_var_map; | |||
| VarNodeArray m_new_inputs; //!< setup by replace_vars | |||
| }; | |||
| struct SeqModifierBase::Opr { | |||
| OperatorNodeBase* const orig_opr; | |||
| std::vector<Var*> input, output; | |||
| const size_t time; //!< index in opr sequence | |||
| const bool is_endpoint; | |||
| double estimate_compute_time = 1; | |||
| //! input vars that have been discarded and need to be recomputed before | |||
| //! this opr; for internal use by apply_discard_plan() | |||
| std::vector<Var*> inputs_to_recompute; | |||
| //! new oprs to be inserted before this opr; setup by apply_discard_plan() | |||
| std::vector<MemPool<Opr>::UniquePtr> oprs_insert_before; | |||
| //! [begin, end) interval of *time* for oprs belonging to this block; setup | |||
| //! by make_discard_plan() | |||
| size_t block_begin_time = 0, block_end_time = 0; | |||
| Opr(OperatorNodeBase* opr, size_t t) | |||
| : orig_opr{opr}, | |||
| time{t}, | |||
| is_endpoint{opr->owner_graph() | |||
| ->options() | |||
| .opr_attribute.get_sublinear_memory_endpoint( | |||
| opr)} {} | |||
| }; | |||
| struct SeqModifierBase::Var { | |||
| VarNode* const orig_var; | |||
| size_t size; //!< memory usage in bytes of this var | |||
| size_t recomp_id = 0; | |||
| double last_access_time = 0; | |||
| //! write or read access of a var | |||
| struct AccessRecord { | |||
| Opr* const opr; | |||
| const size_t time; | |||
| size_t stride; | |||
| explicit AccessRecord(Opr* o = nullptr) | |||
| : opr{o}, time{o->time}, stride{0} {} | |||
| }; | |||
| //! access_rec[0] is the creation opr, and others are reader oprs | |||
| std::vector<AccessRecord> access_rec; | |||
| /*! | |||
| * An index in access_rec | |||
| * | |||
| * if valid, then the var should be discarded after | |||
| * discard_tailing_access->opr finishes | |||
| * | |||
| * setup by make_discard_plan | |||
| */ | |||
| Maybe<size_t> discard_tailing_access; | |||
| /*! | |||
| * An index in access_rec | |||
| * maintained during make_discard_plan(), for the next access relative to | |||
| * current operator | |||
| */ | |||
| Maybe<size_t> next_access; | |||
| AccessRecord* visit_discard_tailing_access() { | |||
| return discard_tailing_access.valid() | |||
| ? &access_rec.at(discard_tailing_access.val()) | |||
| : nullptr; | |||
| } | |||
| AccessRecord* visit_next_access() { | |||
| return next_access.valid() ? &access_rec.at(next_access.val()) | |||
| : nullptr; | |||
| } | |||
| auto owner_opr() const { return access_rec[0].opr; } | |||
| auto last_access_opr() const { return access_rec.back().opr; } | |||
| Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} { | |||
| access_rec.emplace_back(opr); | |||
| } | |||
| }; | |||
| } // namespace cg | |||
| } // namespace mgb | |||
| #endif // MGB_ENABLE_SUBLINEAR | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -61,108 +61,15 @@ bool is_bad_opr(OperatorNodeBase* opr) { | |||
| } | |||
| } // namespace | |||
| /* ====================== Abstract Opr & Var ====================== */ | |||
| struct SeqModifierForSublinearMemory::Opr { | |||
| OperatorNodeBase* const orig_opr; | |||
| std::vector<Var*> input, output; | |||
| const size_t time; //!< index in opr sequence | |||
| const bool is_endpoint; | |||
| //! input vars that have been discarded and need to be recomputed before | |||
| //! this opr; for internal use by apply_discard_plan() | |||
| std::vector<Var*> inputs_to_recompute; | |||
| //! new oprs to be inserted before this opr; setup by apply_discard_plan() | |||
| std::vector<MemPool<Opr>::UniquePtr> oprs_insert_before; | |||
| //! [begin, end) interval of *time* for oprs belonging to this block; setup | |||
| //! by make_discard_plan() | |||
| size_t block_begin_time = 0, block_end_time = 0; | |||
| Opr(OperatorNodeBase* opr, size_t t) | |||
| : orig_opr{opr}, | |||
| time{t}, | |||
| is_endpoint{opr->owner_graph() | |||
| ->options() | |||
| .opr_attribute.get_sublinear_memory_endpoint( | |||
| opr)} {} | |||
| }; | |||
| struct SeqModifierForSublinearMemory::Var { | |||
| //! write or read access of a var | |||
| struct AccessRecord { | |||
| Opr* const opr; | |||
| const size_t time; | |||
| size_t stride; //!< time distance until next read; 0 for last access | |||
| explicit AccessRecord(Opr* o = nullptr) | |||
| : opr{o}, time{o->time}, stride{0} {} | |||
| }; | |||
| VarNode* const orig_var; | |||
| const size_t size; //!< memory usage in bytes of this var | |||
| //! access_rec[0] is the creation opr, and others are reader oprs | |||
| std::vector<AccessRecord> access_rec; | |||
| /*! | |||
| * An index in access_rec | |||
| * | |||
| * if valid, then the var should be discarded after | |||
| * discard_tailing_access->opr finishes | |||
| * | |||
| * setup by make_discard_plan | |||
| */ | |||
| Maybe<size_t> discard_tailing_access; | |||
| /*! | |||
| * An index in access_rec | |||
| * maintained during make_discard_plan(), for the next access relative to | |||
| * current operator | |||
| */ | |||
| Maybe<size_t> next_access; | |||
| AccessRecord* visit_discard_tailing_access() { | |||
| return discard_tailing_access.valid() | |||
| ? &access_rec.at(discard_tailing_access.val()) | |||
| : nullptr; | |||
| } | |||
| AccessRecord* visit_next_access() { | |||
| return next_access.valid() ? &access_rec.at(next_access.val()) | |||
| : nullptr; | |||
| } | |||
| auto owner_opr() const { return access_rec[0].opr; } | |||
| auto last_access_opr() const { return access_rec.back().opr; } | |||
| Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} { | |||
| access_rec.emplace_back(opr); | |||
| } | |||
| }; | |||
| /* ====================== ModifyActionPlanner ====================== */ | |||
| class SeqModifierForSublinearMemory::ModifyActionPlanner { | |||
| //! special creation time used for oprs duplicated from others | |||
| static constexpr size_t DUPOPR_TIME = | |||
| std::numeric_limits<size_t>::max() - 1; | |||
| class SeqModifierForSublinearMemory::ModifyActionPlanner : public ModifyActionPlannerBase { | |||
| using VarArray = std::vector<Var*>; | |||
| using VarSet = ThinHashSet<Var*>; | |||
| using OprArray = std::vector<Opr*>; | |||
| const SeqModifierForSublinearMemory* const m_par_modifier; | |||
| const OprNodeArray* m_orig_opr_seq; | |||
| MemPool<Var> m_var_mempool; | |||
| MemPool<Opr> m_opr_mempool; | |||
| std::vector<MemPool<Var>::UniquePtr> m_var_storage; | |||
| std::vector<MemPool<Opr>::UniquePtr> m_seq; | |||
| size_t m_nr_endpoint_oprs = 0; | |||
| VarSet m_prev_block_discard_vars; | |||
| std::vector<OprArray> m_blocks; | |||
| SeqModifyAction m_action; | |||
| //! split_point_set to block | |||
| void split_into_blocks(const SplitPointSet& split_point_set); | |||
| @@ -188,14 +95,7 @@ class SeqModifierForSublinearMemory::ModifyActionPlanner { | |||
| public: | |||
| ModifyActionPlanner(SeqModifierForSublinearMemory* par) | |||
| : m_par_modifier{par} {} | |||
| ~ModifyActionPlanner() noexcept { | |||
| m_opr_mempool.disable_freelist(); | |||
| m_var_mempool.disable_freelist(); | |||
| } | |||
| //! init m_orig_opr_seq from opr_seq, should be called first. | |||
| void init_seq(const OprNodeArray& opr_seq); | |||
| : ModifyActionPlannerBase{par} {} | |||
| //! generate split point set from thresh | |||
| SplitPointSet get_split_point_set(size_t block_size_thresh); | |||
| @@ -213,7 +113,7 @@ public: | |||
| void SeqModifierForSublinearMemory::ModifyActionPlanner::get_prev_action( | |||
| SeqModifyAction& action) { | |||
| action.clear(); | |||
| for (auto&& opr : m_seq) { | |||
| for (auto&& opr : seq()) { | |||
| auto&& arr = opr->oprs_insert_before; | |||
| if (arr.empty()) | |||
| continue; | |||
| @@ -261,8 +161,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( | |||
| cur_block_alive_vars.clear(); | |||
| }; | |||
| for (size_t i = 0; i < m_seq.size(); ++i) { | |||
| auto opr = m_seq[i].get(); | |||
| for (size_t i = 0; i < seq().size(); ++i) { | |||
| auto opr = seq()[i].get(); | |||
| for (auto i : opr->output) | |||
| add_alive(i); | |||
| @@ -272,8 +172,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( | |||
| remove_alive(i); | |||
| } | |||
| if (i + 1 < m_seq.size() && (cur_block_usage < block_size_thresh || | |||
| (m_nr_endpoint_oprs && !opr->is_endpoint))) | |||
| if (i + 1 < seq().size() && (cur_block_usage < block_size_thresh || | |||
| (nr_endpoint_oprs() && !opr->is_endpoint))) | |||
| continue; | |||
| flush_block_member(i); | |||
| @@ -281,81 +181,6 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( | |||
| return split_point_set; | |||
| } | |||
| void SeqModifierForSublinearMemory::ModifyActionPlanner::init_seq( | |||
| const OprNodeArray& opr_seq) { | |||
| m_orig_opr_seq = &opr_seq; | |||
| m_var_storage.clear(); | |||
| m_seq.clear(); | |||
| m_var_mempool.reorder_free(); | |||
| m_opr_mempool.reorder_free(); | |||
| m_nr_endpoint_oprs = 0; | |||
| ThinHashMap<VarNode*, Var*> varmap; | |||
| for (auto orig_opr : *m_orig_opr_seq) { | |||
| auto time = m_seq.size(); | |||
| m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time)); | |||
| auto opr = m_seq.back().get(); | |||
| m_nr_endpoint_oprs += opr->is_endpoint; | |||
| for (auto&& dep : orig_opr->node_prop().dep_map()) { | |||
| if (!OperatorNodeBase::NodeProp::is_device_value_dep(dep.second)) | |||
| continue; | |||
| auto iter = varmap.find(dep.first); | |||
| if (iter == varmap.end()) { | |||
| // input var needs not to be considered | |||
| continue; | |||
| } | |||
| auto ivar = iter->second; | |||
| bool exist = false; | |||
| for (auto i : opr->input) { | |||
| if (i == ivar) { | |||
| exist = true; | |||
| break; | |||
| } | |||
| } | |||
| if (exist) { | |||
| // same var for different inputs | |||
| continue; | |||
| } | |||
| opr->input.push_back(ivar); | |||
| auto&& prev_rec = ivar->access_rec.back(); | |||
| prev_rec.stride = time - prev_rec.opr->time; | |||
| ivar->access_rec.emplace_back(opr); | |||
| } | |||
| for (auto i : orig_opr->output()) { | |||
| auto var2memsize = m_par_modifier->m_mem_opt.var2memsize(); | |||
| auto iter = var2memsize->find(i); | |||
| if (iter == var2memsize->end()) { | |||
| // some vars are ignored; see split_into_cn2oprseq() | |||
| continue; | |||
| } | |||
| m_var_storage.emplace_back( | |||
| m_var_mempool.alloc_unique(i, iter->second, opr)); | |||
| auto ovar = m_var_storage.back().get(); | |||
| varmap[i] = ovar; | |||
| opr->output.push_back(ovar); | |||
| } | |||
| mgb_assert(!opr->output.empty()); | |||
| } | |||
| // remove unused output | |||
| for (auto&& i : m_seq) { | |||
| auto&& oarr = i->output; | |||
| for (size_t j = 0; j < oarr.size();) { | |||
| if (oarr[j]->access_rec.size() == 1) { | |||
| std::swap(oarr[j], oarr.back()); | |||
| oarr.pop_back(); | |||
| } else | |||
| ++j; | |||
| } | |||
| } | |||
| } | |||
| size_t SeqModifierForSublinearMemory::ModifyActionPlanner:: | |||
| calc_bottleneck_from_discard_plan() { | |||
| size_t cur_usage = 0, max_usage = 0; | |||
| @@ -394,7 +219,7 @@ size_t SeqModifierForSublinearMemory::ModifyActionPlanner:: | |||
| ++time; | |||
| }; | |||
| for (auto&& opr : m_seq) { | |||
| for (auto&& opr : seq()) { | |||
| for (auto&& i : opr->oprs_insert_before) | |||
| process_opr(i.get()); | |||
| process_opr(opr.get()); | |||
| @@ -480,7 +305,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||
| mgb_assert(opr->time < block_end); | |||
| auto new_opr_storage = m_opr_mempool.alloc_unique( | |||
| auto new_opr_storage = opr_mempool().alloc_unique( | |||
| opr->orig_opr, static_cast<size_t>(DUPOPR_TIME)); | |||
| auto new_opr = new_opr_storage.get(); | |||
| @@ -497,7 +322,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||
| Var* new_var = nullptr; | |||
| for (auto i : opr->output) { | |||
| auto&& ovar = m_var_mempool.alloc_unique(i->orig_var, i->size, | |||
| auto&& ovar = var_mempool().alloc_unique(i->orig_var, i->size, | |||
| new_opr); | |||
| new_opr->output.push_back(ovar.get()); | |||
| if (i == var) | |||
| @@ -507,7 +332,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||
| auto ins = var_map.insert({i, ovar.get()}); | |||
| mgb_assert(ins.second); | |||
| m_var_storage.emplace_back(std::move(ovar)); | |||
| var_storage().emplace_back(std::move(ovar)); | |||
| } | |||
| mgb_assert(new_var); | |||
| return new_var; | |||
| @@ -515,7 +340,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||
| add_dep(var); | |||
| }; | |||
| for (auto&& _raw_opr : m_seq) { | |||
| for (auto&& _raw_opr : seq()) { | |||
| auto opr = _raw_opr.get(); | |||
| for (auto i : opr->inputs_to_recompute) | |||
| @@ -640,8 +465,8 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks( | |||
| m_blocks.clear(); | |||
| std::vector<Opr*> cur_block_member; | |||
| size_t i, j; | |||
| for (i = j = 0; i < m_seq.size() && j < split_point_set->size(); ++i) { | |||
| auto opr = m_seq[i].get(); | |||
| for (i = j = 0; i < seq().size() && j < split_point_set->size(); ++i) { | |||
| auto opr = seq()[i].get(); | |||
| cur_block_member.push_back(opr); | |||
| if (i != split_point_set->at(j)) | |||
| continue; | |||
| @@ -649,7 +474,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks( | |||
| cur_block_member.clear(); | |||
| j++; | |||
| } | |||
| mgb_assert(i >= m_seq.size()); | |||
| mgb_assert(i >= seq().size()); | |||
| mgb_assert(j >= split_point_set->size()); | |||
| } | |||
| @@ -1081,7 +906,7 @@ void SeqModifierForSublinearMemory::InternalDeleter::operator()( | |||
| } | |||
| void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) { | |||
| m_var_map.clear(); | |||
| var_map().clear(); | |||
| m_opr2replace_info.clear(); | |||
| auto config = | |||
| MemoryOptimizerHelper::SubGraphConfig() | |||
| @@ -1099,7 +924,7 @@ void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) { | |||
| .add_bad_var_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) | |||
| .add_bad_var_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE); | |||
| auto cn2oprseq = m_mem_opt.split_into_cn2oprseq(oprseq, config); | |||
| auto cn2oprseq = mem_opt().split_into_cn2oprseq(oprseq, config); | |||
| if (cn2oprseq->empty()) { | |||
| // empty graph | |||
| @@ -1175,7 +1000,7 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action, | |||
| // each operator should be set no more than once | |||
| auto set_priority = [&](OperatorNodeBase* opr) { | |||
| mgb_assert(modified_opr.insert(opr).second); | |||
| m_mem_opt.set_priority(opr, cur_priority++); | |||
| mem_opt().set_priority(opr, cur_priority++); | |||
| }; | |||
| auto on_opr_visited = [&](OperatorNodeBase* opr) { | |||
| @@ -1218,80 +1043,13 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action, | |||
| mgb_assert(action.empty()); | |||
| } | |||
| bool SeqModifierForSublinearMemory::replace_vars(const VarNodeArray& inputs) { | |||
| m_new_inputs.assign(inputs.begin(), inputs.end()); | |||
| bool changed = false; | |||
| for (auto&& i : m_new_inputs) { | |||
| auto iter = m_var_map.find(i); | |||
| if (iter != m_var_map.end()) { | |||
| i = iter->second; | |||
| changed = true; | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs( | |||
| OperatorNodeBase* opr, bool recomp) { | |||
| auto config = opr->config(); | |||
| // update operator instance id to bybass the shallow copy's cache if | |||
| // it's a dup-opr-copying due to discarding. | |||
| // Don't update instance id by `this` pointer if it's a recomp-opr-copying | |||
| // because: | |||
| // 0) recomp-opr would be copied iff its input vars is changed | |||
| // 1) some pair of recomp-opr and dup-opr have the same inputs, params | |||
| // and config, we use instance id to differentiate them. | |||
| config.name(opr->name() + (recomp ? ":recomp" : ":dup")); | |||
| if (!recomp) { | |||
| config.update_instance_id(this); | |||
| } | |||
| // Note: if all outputs of op were placed on the same comp_node, since its | |||
| // stream maybe changed during seq_comp_node_opt, output's comp_node has | |||
| // higher priority than opr->config() | |||
| auto out_cn = opr->output(0)->comp_node(); | |||
| for (auto i : opr->output()) { | |||
| auto cn = i->comp_node(); | |||
| if (out_cn != cn) { | |||
| out_cn = {}; | |||
| break; | |||
| } | |||
| } | |||
| if (out_cn.valid()) | |||
| config.comp_node(out_cn); | |||
| auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config); | |||
| mgb_assert(opr_new != opr); | |||
| auto&& out0 = opr->output(); | |||
| auto&& out1 = opr_new->output(); | |||
| mgb_assert(out0.size() == out1.size()); | |||
| bool stream_changed = false; | |||
| for (size_t i = 0; i < out0.size(); ++i) { | |||
| auto &&cn0 = out0[i]->comp_node(), | |||
| &&cn1 = out1[i]->comp_node(); | |||
| if (cn0 != cn1) { | |||
| mgb_assert(recomp); | |||
| mgb_assert(cn0.locator().type == cn1.locator().type && | |||
| cn0.locator().device == cn1.locator().device); | |||
| out1[i]->comp_node(cn0); | |||
| stream_changed = true; | |||
| } | |||
| m_var_map[out0[i]] = out1[i]; | |||
| } | |||
| if (stream_changed) { | |||
| opr_new->on_output_comp_node_stream_changed(); | |||
| } | |||
| return opr_new; | |||
| } | |||
| void SeqModifierForSublinearMemory::modify_endpoint_vars( | |||
| VarNodeArray& endpoints) { | |||
| auto comp_seq = MemoryOptimizerHelper::CompSeq(m_owner_graph, endpoints); | |||
| auto comp_seq = MemoryOptimizerHelper::CompSeq(owner_graph(), endpoints); | |||
| reset_opr_seq(*comp_seq.m_seq); | |||
| for (auto&& i : endpoints) { | |||
| auto iter = m_var_map.find(i); | |||
| if (iter != m_var_map.end()) { | |||
| auto iter = var_map().find(i); | |||
| if (iter != var_map().end()) { | |||
| i = iter->second; | |||
| } | |||
| } | |||
| @@ -1357,8 +1115,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() { | |||
| SeqModifierForSublinearMemory::SeqModifierForSublinearMemory( | |||
| ComputingGraphImpl* owner, Config* config_p) | |||
| : m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {} | |||
| : SeqModifierBase(owner), m_config(config_p) {} | |||
| #endif // !MGB_ENABLE_SUBLINEAR | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -12,6 +12,7 @@ | |||
| #pragma once | |||
| #include "./memory_optimizer.h" | |||
| #include "./seq_modifier_base.h" | |||
| #include "megbrain/graph/cg.h" | |||
| #include "megbrain/utils/async_worker.h" | |||
| @@ -23,28 +24,31 @@ namespace cg { | |||
| * \brief modifying computing sequence, with basically the same idea of Training | |||
| * Deep Nets with Sublinear Memory Cost | |||
| */ | |||
| class SeqModifierForSublinearMemory { | |||
| /*! | |||
| * describes modifications that should be applied to an operator sequnce: | |||
| * maps from an opr to the oprs that should be duplicated and inserted | |||
| * before it. | |||
| */ | |||
| using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>; | |||
| using SplitPointSet = std::shared_ptr<std::vector<size_t>>; | |||
| class SeqModifierForSublinearMemory : public SeqModifierBase { | |||
| //! Config options | |||
| using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig; | |||
| Config* m_config; | |||
| public: | |||
| SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g); | |||
| //! replace endpoint vars by the ones that require more computing | |||
| void modify_endpoint_vars(VarNodeArray& endpoints); | |||
| //! check whether actual opr_seq is what we expect; throw InternalError | |||
| void sanity_check(const OprNodeArray& opr_seq); | |||
| const CompNode::UnorderedMap<size_t>& prev_min_bottleneck(); | |||
| private: | |||
| using SplitPointSet = std::shared_ptr<std::vector<size_t>>; | |||
| //! get modifications to be taken under some specific constraints | |||
| class ModifyActionPlanner; | |||
| //! search best modify action for opr seq on a single comp node | |||
| class ActionSearcherSingleCN; | |||
| struct Opr; | |||
| struct Var; | |||
| struct InternalDeleter { | |||
| void operator()(ActionSearcherSingleCN*) const; | |||
| void operator()(ModifyActionPlanner*) const; | |||
| @@ -67,32 +71,8 @@ class SeqModifierForSublinearMemory { | |||
| //! thread pool to run ModifyActionPlanner | |||
| FutureThreadPool<void> m_planner_thread_pool; | |||
| //! map from original var to replaced var | |||
| ThinHashMap<VarNode*, VarNode*> m_var_map; | |||
| VarNodeArray m_new_inputs; //!< setup by replace_vars | |||
| MemoryOptimizerHelper m_mem_opt; | |||
| ComputingGraphImpl* const m_owner_graph = nullptr; | |||
| CompNode::UnorderedMap<size_t> m_prev_min_bottleneck; | |||
| /*! | |||
| * \brief replace input vars according to m_var_map, and store results in | |||
| * m_new_inputs; | |||
| * \return whether any var is changed | |||
| */ | |||
| bool replace_vars(const VarNodeArray& inputs); | |||
| /*! | |||
| * \brief copy opr and set inputs to m_new_inputs, and add outputs in | |||
| * m_var_map | |||
| * \return new operator | |||
| */ | |||
| OperatorNodeBase* copy_opr_from_new_inputs(OperatorNodeBase* opr, | |||
| bool recomp); | |||
| //! restore computing sequence and modify operator priority | |||
| void reset_opr_seq(const OprNodeArray& oprseq); | |||
| @@ -107,27 +87,6 @@ class SeqModifierForSublinearMemory { | |||
| return std::make_shared<SplitPointSet::element_type>( | |||
| std::forward<Args>(args)...); | |||
| } | |||
| public: | |||
| SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g); | |||
| //! see memory_optimizer set_priority_before_opt | |||
| void set_priority_before_opt(const VarNodeArray& endpoints) { | |||
| m_mem_opt.set_priority_before_opt(endpoints); | |||
| } | |||
| //! see memory_optimizer restore_graph_option | |||
| void restore_graph_option() { | |||
| m_mem_opt.restore_graph_option(); | |||
| } | |||
| //! replace endpoint vars by the ones that require more computing | |||
| void modify_endpoint_vars(VarNodeArray& endpoints); | |||
| //! check whether actual opr_seq is what we expect; throw InternalError | |||
| void sanity_check(const OprNodeArray& opr_seq); | |||
| const CompNode::UnorderedMap<size_t>& prev_min_bottleneck(); | |||
| }; | |||
| } // namespace cg | |||