| @@ -33,16 +33,16 @@ class JITFusionPass::Impl final { | |||||
| CompNode::UnorderedMap<size_t> m_cn2max_nr_input; | CompNode::UnorderedMap<size_t> m_cn2max_nr_input; | ||||
| SubGraph::Rewriter m_rewriter; | SubGraph::Rewriter m_rewriter; | ||||
| SmallVector<std::unique_ptr<InternalGraphGenrator>> m_igraph_gen_storage; | |||||
| ThinHashMap<VarNode*, InternalGraphGenrator*> m_var2igraph_gen; | |||||
| SmallVector<std::unique_ptr<InternalGraphGenerator>> m_igraph_gen_storage; | |||||
| ThinHashMap<VarNode*, InternalGraphGenerator*> m_var2igraph_gen; | |||||
| //! map from var to its reader oprs and the corresponding dependency types | //! map from var to its reader oprs and the corresponding dependency types | ||||
| ThinHashMap<VarNode*, SmallVector<std::pair<OperatorNodeBase*, DepType>>> | ThinHashMap<VarNode*, SmallVector<std::pair<OperatorNodeBase*, DepType>>> | ||||
| m_var_readers; | m_var_readers; | ||||
| ThinHashSet<VarNode*> m_endpoint_set; | ThinHashSet<VarNode*> m_endpoint_set; | ||||
| //! create a new InternalGraphGenrator rooted at given opr | |||||
| InternalGraphGenrator* create_new_igraph_gen(OperatorNodeBase* opr); | |||||
| //! create a new InternalGraphGenerator rooted at given opr | |||||
| InternalGraphGenerator* create_new_igraph_gen(OperatorNodeBase* opr); | |||||
| //! process a single operator, maintaining m_var2igraph_gen | //! process a single operator, maintaining m_var2igraph_gen | ||||
| void process_opr(OperatorNodeBase* opr); | void process_opr(OperatorNodeBase* opr); | ||||
| @@ -51,11 +51,11 @@ class JITFusionPass::Impl final { | |||||
| //! check whether all oprs which depend on the var are in i_graph | //! check whether all oprs which depend on the var are in i_graph | ||||
| bool test_all_readers_in_the_graph(VarNode* var, | bool test_all_readers_in_the_graph(VarNode* var, | ||||
| InternalGraphGenrator* i_graph); | |||||
| InternalGraphGenerator* i_graph); | |||||
| //! check shape to determine whether the opr should be added to the internal | //! check shape to determine whether the opr should be added to the internal | ||||
| //! graph | //! graph | ||||
| bool check_shape(cg::OperatorNodeBase* opr, InternalGraphGenrator* i_graph); | |||||
| bool check_shape(cg::OperatorNodeBase* opr, InternalGraphGenerator* i_graph); | |||||
| //! use m_rewriter to update graph | //! use m_rewriter to update graph | ||||
| void update_graph(); | void update_graph(); | ||||
| @@ -155,7 +155,7 @@ void JITFusionPass::Impl::update_graph() { | |||||
| } | } | ||||
| bool JITFusionPass::Impl::test_all_readers_in_the_graph( | bool JITFusionPass::Impl::test_all_readers_in_the_graph( | ||||
| VarNode* var, InternalGraphGenrator* ig_gen) { | |||||
| VarNode* var, InternalGraphGenerator* ig_gen) { | |||||
| for (auto&& reader : m_var_readers.at(var)) { | for (auto&& reader : m_var_readers.at(var)) { | ||||
| if (reader.second & DepType::DEV_VALUE) { | if (reader.second & DepType::DEV_VALUE) { | ||||
| if (ig_gen->opr_set().count(reader.first) == 0) { | if (ig_gen->opr_set().count(reader.first) == 0) { | ||||
| @@ -167,7 +167,7 @@ bool JITFusionPass::Impl::test_all_readers_in_the_graph( | |||||
| } | } | ||||
| bool JITFusionPass::Impl::check_shape(cg::OperatorNodeBase* opr, | bool JITFusionPass::Impl::check_shape(cg::OperatorNodeBase* opr, | ||||
| InternalGraphGenrator* ig_gen) { | |||||
| InternalGraphGenerator* ig_gen) { | |||||
| if (!cg::is_static_var_shape(opr->output(0))) { | if (!cg::is_static_var_shape(opr->output(0))) { | ||||
| // currently we do not handle dynamic shape in JIT | // currently we do not handle dynamic shape in JIT | ||||
| return false; | return false; | ||||
| @@ -249,9 +249,9 @@ bool JITFusionPass::Impl::check_shape(cg::OperatorNodeBase* opr, | |||||
| } | } | ||||
| } | } | ||||
| InternalGraphGenrator* JITFusionPass::Impl::create_new_igraph_gen( | |||||
| InternalGraphGenerator* JITFusionPass::Impl::create_new_igraph_gen( | |||||
| OperatorNodeBase* opr) { | OperatorNodeBase* opr) { | ||||
| auto uptr = std::make_unique<InternalGraphGenrator>(opr); | |||||
| auto uptr = std::make_unique<InternalGraphGenerator>(opr); | |||||
| auto ptr = uptr.get(); | auto ptr = uptr.get(); | ||||
| m_igraph_gen_storage.emplace_back(std::move(uptr)); | m_igraph_gen_storage.emplace_back(std::move(uptr)); | ||||
| m_var2igraph_gen[opr->output(0)] = ptr; | m_var2igraph_gen[opr->output(0)] = ptr; | ||||
| @@ -267,7 +267,7 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) { | |||||
| } | } | ||||
| // dimshuffle should not be an endpoint, because megbrain has lazy | // dimshuffle should not be an endpoint, because megbrain has lazy | ||||
| // dimshuffle mechanism | // dimshuffle mechanism | ||||
| InternalGraphGenrator* ig_gen = nullptr; | |||||
| InternalGraphGenerator* ig_gen = nullptr; | |||||
| if (m_var2igraph_gen.count(opr->output(0)) == 0) { | if (m_var2igraph_gen.count(opr->output(0)) == 0) { | ||||
| // because of the reverse traversal, when an operator is being | // because of the reverse traversal, when an operator is being | ||||
| // processed but not in m_var2igraph_gen, means it is an endpoint of a | // processed but not in m_var2igraph_gen, means it is an endpoint of a | ||||
| @@ -81,12 +81,12 @@ InternalGraphPtr expand_executor_opr(const InternalGraphPtr& prev_igraph) { | |||||
| } // namespace | } // namespace | ||||
| InternalGraphGenrator::InternalGraphGenrator(cg::OperatorNodeBase* opr) | |||||
| InternalGraphGenerator::InternalGraphGenerator(cg::OperatorNodeBase* opr) | |||||
| : m_output{opr->output(0)} { | : m_output{opr->output(0)} { | ||||
| add_opr(opr); | add_opr(opr); | ||||
| } | } | ||||
| VarNode* InternalGraphGenrator::replace_graph_by_placeholder() { | |||||
| VarNode* InternalGraphGenerator::replace_graph_by_placeholder() { | |||||
| ThinHashMap<VarNode*, VarNode*> old2new; | ThinHashMap<VarNode*, VarNode*> old2new; | ||||
| auto cpu_default = CompNode::default_cpu(); | auto cpu_default = CompNode::default_cpu(); | ||||
| auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr, | auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr, | ||||
| @@ -163,7 +163,7 @@ VarNode* InternalGraphGenrator::replace_graph_by_placeholder() { | |||||
| return old2new.at(m_output); | return old2new.at(m_output); | ||||
| } | } | ||||
| InternalGraphPtr InternalGraphGenrator::generate() { | |||||
| InternalGraphPtr InternalGraphGenerator::generate() { | |||||
| m_input_idx = 0; | m_input_idx = 0; | ||||
| auto new_nd = replace_graph_by_placeholder(); | auto new_nd = replace_graph_by_placeholder(); | ||||
| @@ -172,7 +172,7 @@ InternalGraphPtr InternalGraphGenrator::generate() { | |||||
| return expand_executor_opr(igraph); | return expand_executor_opr(igraph); | ||||
| } | } | ||||
| size_t InternalGraphGenrator::get_cnt_input_if_add( | |||||
| size_t InternalGraphGenerator::get_cnt_input_if_add( | |||||
| cg::OperatorNodeBase* opr) const { | cg::OperatorNodeBase* opr) const { | ||||
| // minus 1 first because this opr should be removed from subgraph's input | // minus 1 first because this opr should be removed from subgraph's input | ||||
| size_t new_cnt_input = m_graph_input_set.size() - 1; | size_t new_cnt_input = m_graph_input_set.size() - 1; | ||||
| @@ -183,7 +183,7 @@ size_t InternalGraphGenrator::get_cnt_input_if_add( | |||||
| return new_cnt_input; | return new_cnt_input; | ||||
| } | } | ||||
| void InternalGraphGenrator::add_opr(cg::OperatorNodeBase* opr) { | |||||
| void InternalGraphGenerator::add_opr(cg::OperatorNodeBase* opr) { | |||||
| if (m_opr_set.count(opr)) { | if (m_opr_set.count(opr)) { | ||||
| // ignore duplicated oprs (which occur in tests) | // ignore duplicated oprs (which occur in tests) | ||||
| return; | return; | ||||
| @@ -253,7 +253,7 @@ void InternalGraphGenrator::add_opr(cg::OperatorNodeBase* opr) { | |||||
| } | } | ||||
| } | } | ||||
| void InternalGraphGenrator::find_reduce_opr_deps(cg::OperatorNodeBase* opr) { | |||||
| void InternalGraphGenerator::find_reduce_opr_deps(cg::OperatorNodeBase* opr) { | |||||
| mgb_assert(opr->same_type<opr::Reduce>() || | mgb_assert(opr->same_type<opr::Reduce>() || | ||||
| (opr->same_type<jit::JITExecutor>() && | (opr->same_type<jit::JITExecutor>() && | ||||
| try_cast_as_op<jit::JITExecutor>(opr)->has_reduce())); | try_cast_as_op<jit::JITExecutor>(opr)->has_reduce())); | ||||
| @@ -264,7 +264,7 @@ void InternalGraphGenrator::find_reduce_opr_deps(cg::OperatorNodeBase* opr) { | |||||
| cg::DepOprIter{cb}.add(opr); | cg::DepOprIter{cb}.add(opr); | ||||
| } | } | ||||
| void InternalGraphGenrator::find_oprs_depended_by_dimshuffle( | |||||
| void InternalGraphGenerator::find_oprs_depended_by_dimshuffle( | |||||
| cg::OperatorNodeBase* dimshuffle) { | cg::OperatorNodeBase* dimshuffle) { | ||||
| mgb_assert( | mgb_assert( | ||||
| dimshuffle->same_type<opr::Dimshuffle>() || | dimshuffle->same_type<opr::Dimshuffle>() || | ||||
| @@ -287,7 +287,7 @@ void InternalGraphGenrator::find_oprs_depended_by_dimshuffle( | |||||
| cg::DepOprIter{cb}.add(dimshuffle); | cg::DepOprIter{cb}.add(dimshuffle); | ||||
| } | } | ||||
| PlaceholderArray InternalGraphGenrator::to_placeholder_opr_arr( | |||||
| PlaceholderArray InternalGraphGenerator::to_placeholder_opr_arr( | |||||
| const VarNodeArray& vars) { | const VarNodeArray& vars) { | ||||
| PlaceholderArray ret(vars.size()); | PlaceholderArray ret(vars.size()); | ||||
| for (size_t i = 0; i < vars.size(); ++i) { | for (size_t i = 0; i < vars.size(); ++i) { | ||||
| @@ -76,12 +76,12 @@ private: | |||||
| * This object stores intermediate state while visiting the computing graph in | * This object stores intermediate state while visiting the computing graph in | ||||
| * JITFusionPass. | * JITFusionPass. | ||||
| * | * | ||||
| * The graph is iterated in reverse topological order. InternalGraphGenrator | |||||
| * The graph is iterated in reverse topological order. InternalGraphGenerator | |||||
| * starts with a single operator (i.e. the output node of the fused opr), and | * starts with a single operator (i.e. the output node of the fused opr), and | ||||
| * new oprs are gradually added into it. Thus the process is expanding a tree | * new oprs are gradually added into it. Thus the process is expanding a tree | ||||
| * rooted at the output node. | * rooted at the output node. | ||||
| */ | */ | ||||
| class InternalGraphGenrator { | |||||
| class InternalGraphGenerator { | |||||
| //! replace oprs in the graph of m_output and populate m_orig_inps, | //! replace oprs in the graph of m_output and populate m_orig_inps, | ||||
| //! m_placeholders | //! m_placeholders | ||||
| VarNode* replace_graph_by_placeholder(); | VarNode* replace_graph_by_placeholder(); | ||||
| @@ -95,7 +95,7 @@ class InternalGraphGenrator { | |||||
| void find_oprs_depended_by_dimshuffle(cg::OperatorNodeBase* opr); | void find_oprs_depended_by_dimshuffle(cg::OperatorNodeBase* opr); | ||||
| public: | public: | ||||
| explicit InternalGraphGenrator(cg::OperatorNodeBase* opr); | |||||
| explicit InternalGraphGenerator(cg::OperatorNodeBase* opr); | |||||
| //! generate the graph; this method can be called multiple times | //! generate the graph; this method can be called multiple times | ||||
| InternalGraphPtr generate(); | InternalGraphPtr generate(); | ||||
| @@ -54,7 +54,7 @@ void run<simple>(Backend backend, CompNode cn) { | |||||
| VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; | VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; | ||||
| auto ig_gen = | auto ig_gen = | ||||
| std::make_unique<InternalGraphGenrator>(y.node()->owner_opr()); | |||||
| std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | |||||
| for (auto i : get_rev_topo_order(y)) { | for (auto i : get_rev_topo_order(y)) { | ||||
| if (!i->same_type<opr::Host2DeviceCopy>()) { | if (!i->same_type<opr::Host2DeviceCopy>()) { | ||||
| @@ -91,7 +91,7 @@ void run<grad>(Backend backend, CompNode cn) { | |||||
| VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; | VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; | ||||
| auto ig_gen = | auto ig_gen = | ||||
| std::make_unique<InternalGraphGenrator>(y.node()->owner_opr()); | |||||
| std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | |||||
| for (auto i : get_rev_topo_order(y)) { | for (auto i : get_rev_topo_order(y)) { | ||||
| if (!i->same_type<opr::Host2DeviceCopy>()) { | if (!i->same_type<opr::Host2DeviceCopy>()) { | ||||
| @@ -540,7 +540,7 @@ void run<expand_jit_executor>(Backend backend, CompNode cn) { | |||||
| auto make_jit = [](SymbolVar target, const SymbolVarArray& inputs) { | auto make_jit = [](SymbolVar target, const SymbolVarArray& inputs) { | ||||
| auto y = target.node(); | auto y = target.node(); | ||||
| auto ig_gen = std::make_unique<InternalGraphGenrator>(y->owner_opr()); | |||||
| auto ig_gen = std::make_unique<InternalGraphGenerator>(y->owner_opr()); | |||||
| auto inputs_vptr = cg::to_var_node_array(inputs); | auto inputs_vptr = cg::to_var_node_array(inputs); | ||||
| for (auto i : get_rev_topo_order( | for (auto i : get_rev_topo_order( | ||||
| target, {inputs_vptr.begin(), inputs_vptr.end()})) { | target, {inputs_vptr.begin(), inputs_vptr.end()})) { | ||||
| @@ -830,9 +830,9 @@ TEST(TestJITFusionHalide, JITExecutor) { | |||||
| y = opr::reduce_sum(a + b, shape_of_b), | y = opr::reduce_sum(a + b, shape_of_b), | ||||
| z = opr::reduce_sum(a * b, shape_of_a); | z = opr::reduce_sum(a * b, shape_of_a); | ||||
| auto ig_gen_1 = | auto ig_gen_1 = | ||||
| std::make_unique<InternalGraphGenrator>(y.node()->owner_opr()); | |||||
| std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | |||||
| auto ig_gen_2 = | auto ig_gen_2 = | ||||
| std::make_unique<InternalGraphGenrator>(z.node()->owner_opr()); | |||||
| std::make_unique<InternalGraphGenerator>(z.node()->owner_opr()); | |||||
| { | { | ||||
| ThinHashSet<VarNode*> nd_set; | ThinHashSet<VarNode*> nd_set; | ||||
| nd_set.insert(a.node()); | nd_set.insert(a.node()); | ||||
| @@ -85,7 +85,7 @@ void FusionChecker::ensure_init_graph() { | |||||
| SymbolVar jit_y; | SymbolVar jit_y; | ||||
| if (m_direct_build) { | if (m_direct_build) { | ||||
| auto ig_gen = std::make_unique<InternalGraphGenrator>( | |||||
| auto ig_gen = std::make_unique<InternalGraphGenerator>( | |||||
| m_truth_y.node()->owner_opr()); | m_truth_y.node()->owner_opr()); | ||||
| ThinHashSet<VarNode*> endpoints_set; | ThinHashSet<VarNode*> endpoints_set; | ||||
| for (size_t i = 0; i < m_nr_input; ++i) { | for (size_t i = 0; i < m_nr_input; ++i) { | ||||