GitOrigin-RevId: f666f6d700
tags/v0.5.0
@@ -74,7 +74,7 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
         auto &&ins = m_varmap.insert({out0[i], {true, nullptr}});
         mgb_assert(ins.second || ins.first->second.first,
-                "opr output already replaced");
+                   "opr output already replaced");
         // handle repeated call on the same opr
         ins.first->second.second = out1[i];
         on_var_replaced(out0[i], out1[i], nullptr);
@@ -771,7 +771,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
 /* ================ ConstVarPropogateBase ================ */
-ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
+ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
         OperatorNodeBase *opr) {
     using ProfFlag = OperatorNodeBase::NodeProp::Flag;
     auto &&info = m_oprinfo[opr];
@@ -834,7 +834,6 @@ ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
 #endif
         info.max_size = max_input_size;
         info.is_const = true;
-        on_midconst_opr(opr, max_input_size);
     }
     return make_ret();
 }
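
For context on the hunk above: with on_midconst_opr() gone, ConstVarPropogate no longer calls back into a subclass while scanning; it only records per-operator const-ness and the maximum input size. Callers are now expected to scan the whole graph once and query the cached results afterwards. A minimal usage sketch reconstructed from this diff (some_opr is a placeholder for an operator of interest):

    // scan phase: classify every operator once
    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
    state.graph().iter([&cvprop](OperatorNodeBase *opr) {
        cvprop.add_opr(opr);
    });

    // query phase: cached results, no further traversal
    auto add_ret = cvprop.opr_rst(some_opr);
    if (add_ret.all_const_inp) {
        // largest input size (in bytes) recorded by add_opr
        size_t max_inp = cvprop.max_size(some_opr);
        // ... decide whether folding some_opr's outputs pays off
    }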
@@ -442,50 +442,6 @@ void ParamRedistributePass::apply(OptState &state) const {
 /* ================ ParamFusePass ================ */
-class ParamFusePass::ConstVarPropogateWithSizeCheck final:
-        public ConstVarPropogateBase
-{
-    public:
-        //! rewrite a var; reader == nullptr means needed by endpoint
-        using VarRewriter = std::function<
-            void(VarNode *var, OperatorNodeBase *reader)>;
-
-        ConstVarPropogateWithSizeCheck(
-                const ParamFusePass &pf, OptState &opt_state,
-                const VarRewriter &rewriter):
-            ConstVarPropogateBase{ConstVarType::IMMUTABLE_AND_PARAM},
-            m_owner{pf}, m_opt_state{opt_state}, m_rewriter{rewriter}
-        {
-        }
-
-    private:
-        const ParamFusePass &m_owner;
-        OptState &m_opt_state;
-        VarRewriter m_rewriter;
-
-        void on_midconst_opr(
-                OperatorNodeBase *opr, size_t max_src_size) override {
-            for (auto var: opr->output()) {
-                if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
-                    continue;
-
-                auto osize = var_mem_size(var);
-                if (osize >= max_src_size &&
-                        osize - max_src_size > m_owner.m_param_grow_limit) {
-                    return;
-                }
-
-                // const oprs should be evaluated when output is used by another
-                // non-const opr or output is needed by the user
-                if (m_opt_state.graph().endpoint_contain(var)) {
-                    m_rewriter(var, nullptr);
-                }
-            }
-        }
-};
-
 /*!
  * \brief get name for new param
  */
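
The size check deleted above is not lost: it moves, with the same condition, into ParamFusePass::apply() (see the hunks below). The rule it implements: a mid-const operator is folded into a precomputed constant only if its output does not outgrow its largest constant input by more than m_param_grow_limit bytes. A standalone restatement of the condition (exceeds_grow_limit is a hypothetical helper, not part of the source):

    #include <cstddef>

    // true if folding a var of osize bytes, derived from const inputs of at
    // most max_src_size bytes, would exceed the configured growth budget
    static bool exceeds_grow_limit(std::size_t osize, std::size_t max_src_size,
                                   std::size_t param_grow_limit) {
        return osize >= max_src_size && osize - max_src_size > param_grow_limit;
    }

For example, with osize = 4096, max_src_size = 1024 and param_grow_limit = 2048, the growth is 4096 - 1024 = 3072 > 2048, so the var is left unfused.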
@@ -565,9 +521,15 @@ const char* ParamFusePass::name() const {
 void ParamFusePass::apply(OptState &state) const {
     auto rewriter = state.graph().make_rewriter();
     auto cg = state.graph().comp_graph();
+
+    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
+    state.graph().iter([&cvprop](OperatorNodeBase *opr) {
+        cvprop.add_opr(opr);
+    });
+
     ThinHashSet<VarNode*> processed_var;
     VarNamer var_namer;

     // reader: null if used as endvar
     auto replace_single_var = [&](VarNode *var, OperatorNodeBase *reader) {
         if (!processed_var.insert(var).second)
@@ -619,9 +581,8 @@ void ParamFusePass::apply(OptState &state) const {
         rewriter.replace_var(var, new_var.node(), log.c_str());
     };

-    ConstVarPropogateWithSizeCheck cvprop{*this, state, replace_single_var};
-    auto on_opr = [&](OperatorNodeBase *opr) {
-        auto add_ret = cvprop.add_opr(opr);
+    auto replace_opr = [&](OperatorNodeBase* opr) {
+        auto add_ret = cvprop.opr_rst(opr);
         if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
             for (auto i: opr->input()) {
                 if (cvprop.is_midconst(i)) {
@@ -631,9 +592,33 @@ void ParamFusePass::apply(OptState &state) const {
             }
         }
         rewriter.auto_replace_outputs(opr);
+
+        //! we should deal with midconst var after auto_replace_outputs, as
+        //! on_midconst_opr will replace the endpoint output which may cause
+        //! double replace.
+        if (add_ret.all_const_inp) {
+            for (auto var : opr->output()) {
+                if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
+                    continue;
+
+                auto osize = ConstVarPropogate::var_mem_size(var);
+                if (osize >= cvprop.max_size(opr) &&
+                        osize - cvprop.max_size(opr) > m_param_grow_limit) {
+                    return;
+                }
+
+                // const oprs should be evaluated when output is used by another
+                // non-const opr or output is needed by the user
+                if (state.graph().endpoint_contain(var)) {
+                    replace_single_var(var, nullptr);
+                }
+            }
+        }
     };
-    state.graph().iter(on_opr);
+    state.graph().iter(replace_opr);

     rewriter.apply_inplace();
 }
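
Taken together, the three hunks above restructure ParamFusePass::apply() into two passes over the graph. A control-flow outline of the code above (comments are editorial; identifiers come from the diff):

    // pass 1: classify every operator (near the top of apply())
    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
    state.graph().iter([&cvprop](OperatorNodeBase *opr) { cvprop.add_opr(opr); });

    // pass 2: rewrite, consulting the cached classification
    state.graph().iter([&](OperatorNodeBase* opr) {
        // a) fold mid-const inputs of oprs that are not themselves all-const;
        // b) rewriter.auto_replace_outputs(opr);
        // c) only afterwards fold all-const endpoint outputs, so an endpoint
        //    var cannot be replaced twice (once in b, once here)
    });
    rewriter.apply_inplace();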
@@ -490,28 +490,17 @@ namespace gopt {
-     * Usually you would want to use ConstVarPropogate, and this base class
-     * exists to avoid virtual dtor while allowing polymorphism.
      */
-    class ConstVarPropogateBase {
-        protected:
-            ~ConstVarPropogateBase() = default;
-
-            //! memory usage of a var
-            static size_t var_mem_size(VarNode *var) {
-                return var->dtype().size(var->shape().total_nr_elems());
-            }
-
-            //! called after a const but non-source opr is visited
-            virtual void on_midconst_opr(
-                    OperatorNodeBase *opr, size_t max_src_size) {
-                MGB_MARK_USED_VAR(opr);
-                MGB_MARK_USED_VAR(max_src_size);
-            }
-
+    class ConstVarPropogate{
         public:
-            explicit ConstVarPropogateBase(ConstVarType const_var_type):
+            explicit ConstVarPropogate(ConstVarType const_var_type):
                 m_const_var_type{const_var_type}
             {
             }
+
+            ConstVarPropogate() = default;
+            ~ConstVarPropogate() = default;
+
             //! note that both attrs would be false if opr is impure or it is
             //! not allowed to be replaced
             struct AddOprResult {
@@ -527,12 +516,19 @@ namespace gopt {
             AddOprResult add_opr(OperatorNodeBase *opr);

+            const AddOprResult& opr_rst(OperatorNodeBase *opr) const {
+                return m_oprinfo.at(opr).result;
+            }
+
             bool is_const(OperatorNodeBase *opr) const {
                 return m_oprinfo.at(opr).is_const;
            }
             bool is_const(VarNode *var) const {
                 return is_const(var->owner_opr());
             }
+
+            size_t max_size(OperatorNodeBase *opr) const {
+                return m_oprinfo.at(opr).max_size;
+            }

             //! whether a var is produced by non-source const opr
             bool is_midconst(OperatorNodeBase *opr) const {
@@ -543,6 +539,11 @@ namespace gopt {
                 return is_midconst(var->owner_opr());
             }

+            //! memory usage of a var
+            static size_t var_mem_size(VarNode *var) {
+                return var->dtype().size(var->shape().total_nr_elems());
+            }
+
         private:
             struct OprInfo {
                 bool processed = false, is_const = false;
@@ -556,11 +557,6 @@ namespace gopt {
             };

-    class ConstVarPropogate final: public ConstVarPropogateBase {
-        public:
-            using ConstVarPropogateBase::ConstVarPropogateBase;
-    };
-
 } // namespace gopt
 } // namespace mgb
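
After the header hunks above, the public surface of gopt::ConstVarPropogate, reconstructed from this diff (private members and the AddOprResult body elided; comments condensed), is roughly:

    class ConstVarPropogate {
    public:
        ConstVarPropogate() = default;
        explicit ConstVarPropogate(ConstVarType const_var_type);

        AddOprResult add_opr(OperatorNodeBase *opr);              // scan one opr
        const AddOprResult& opr_rst(OperatorNodeBase *opr) const; // cached result
        bool is_const(OperatorNodeBase *opr) const;
        bool is_const(VarNode *var) const;
        bool is_midconst(OperatorNodeBase *opr) const;  // const but non-source
        bool is_midconst(VarNode *var) const;
        size_t max_size(OperatorNodeBase *opr) const;   // max input size, bytes
        static size_t var_mem_size(VarNode *var);       // dtype size * nr elems
    };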
@@ -112,6 +112,52 @@ void warp_perspective_mat_gen(HostTensorND& mat, size_t N, size_t INP_H,
 #endif
 } // namespace

+TEST(TestGoptInference, ParamFuseConstEndPoint) {
+    constexpr size_t SIZE = 23;
+    HostTensorGenerator<> gen;
+    auto host_x = gen({SIZE}), host_y = gen({1}), host_p = gen({1});
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto x = opr::SharedDeviceTensor::make(*graph, *host_x),
+         y = opr::SharedDeviceTensor::make(*graph, *host_y),
+         p = opr::Host2DeviceCopy::make(*graph, host_p),
+         q = p + x,
+         a = y + 3,
+         z0 = a + q,
+         z1 = a + 4;
+
+    HostTensorND host_z0, host_z1;
+    SymbolVar z0_1, z1_1;
+    unpack_vector(
+            gopt::GraphOptimizer{}.
+            add_pass<gopt::ParamFusePass>().
+            apply({{z1, z0}}).endpoint_vars(),
+            z1_1, z0_1);
+
+    auto func = graph->compile({make_callback_copy(z0_1, host_z0),
+                                make_callback_copy(z1_1, host_z1)});
+    func->to_json()->writeto_fpath(
+            output_file("TestGoptInference.ParamFuseEndPoint.json"));
+    func->execute();
+
+    int nr_opr = 0;
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {++ nr_opr; return true; });
+    ASSERT_EQ(8, nr_opr);
+
+    auto px = host_x->ptr<float>(), pz0 = host_z0.ptr<float>();
+    auto yv = host_y->ptr<float>()[0], pv = host_p->ptr<float>()[0],
+         pz1 = host_z1.ptr<float>()[0];
+    for (size_t i = 0; i < SIZE; ++ i) {
+        MGB_ASSERT_FLOAT_EQ(px[i] + yv + 3 + pv, pz0[i]);
+    }
+    MGB_ASSERT_FLOAT_EQ(yv + 7, pz1);
+}
+
 TEST(TestGoptInference, ParamFuse) {
     constexpr size_t SIZE = 23;
     HostTensorGenerator<> gen;
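
The new ParamFuseConstEndPoint test exercises exactly the endpoint case handled by the reordering above: a = y + 3 feeds both a runtime-dependent chain (z0 = a + p + x, where p is a Host2DeviceCopy) and a fully constant endpoint (z1 = a + 4, i.e. yv + 7). ParamFusePass must fold z1 into a single precomputed constant even though z1 is itself a requested output, without replacing that endpoint var twice; the MGB_ASSERT_FLOAT_EQ(yv + 7, pz1) check verifies the folded value, and the operator-count assertion guards against the fold being skipped.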
@@ -144,7 +190,7 @@ TEST(TestGoptInference, ParamFuse) {
     func->execute();

     int nr_opr = 0;
-    func->iter_opr_seq([&](cg::OperatorNodeBase*op) {++ nr_opr; return true; });
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {++ nr_opr; return true; });
     ASSERT_EQ(6, nr_opr);

     auto px = host_x->ptr<float>(), pz = host_z.ptr<float>(),