diff --git a/src/jit/impl/executor_opr.cpp b/src/jit/impl/executor_opr.cpp
index d4c948c2..744af62f 100644
--- a/src/jit/impl/executor_opr.cpp
+++ b/src/jit/impl/executor_opr.cpp
@@ -88,15 +88,34 @@ JITExecutor::JITExecutor(const InternalGraphPtr& internal_graph,
         cg::add_workspace_output(this);
     }
 
+    // check that the output of internal_graph depends on all placeholders
+    size_t nr_placeholders = internal_graph_ptr()->placeholders().size();
+    std::vector<bool> used(nr_placeholders, false);
     // check if there is reduce or dimshuffle opr
-    cg::DepOprIter{[this](cg::OperatorNodeBase* opr) {
+    cg::DepOprIter{[this, nr_placeholders, &used](cg::OperatorNodeBase* opr) {
         if (opr->same_type<opr::Reduce>()) {
             m_feature_bits |= JITFeatureBits::REDUCE;
         }
         if (opr->same_type<opr::Dimshuffle>()) {
             m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
         }
+        if (auto ph = opr->try_cast_final<JITPlaceholder>()) {
+            mgb_assert(ph->input_id() < nr_placeholders,
+                       "bad placeholder %s in JITExecutor %s",
+                       ph->cname(), cname());
+            used[ph->input_id()] = true;
+        }
     }}.add(internal_graph->output());
+
+    for (size_t i = 0; i < nr_placeholders; ++ i) {
+        mgb_assert(used[i],
+                   "placeholder %s is not used by the output of %s",
+                   internal_graph_ptr()->placeholders()[i]->cname(), cname());
+    }
+
+    if (has_dimshuffle()) {
+        prepare_dimshuffle();
+    }
 }
 
 void JITExecutor::add_input_layout_constraint() {
@@ -151,14 +170,14 @@ void JITExecutor::scn_do_execute() {
 
 //! can be ignored
 void JITExecutor::do_dimshuffle() {
-    auto get_dimshuffled_layout = [](const TensorLayout& ily, int32_t* pattern,
-                                     size_t pattern_len) {
+    static auto get_dimshuffled_layout = [](const TensorLayout& ily,
+                                            std::vector<int> pattern) {
         TensorLayout oly{ily.dtype};
-        oly.ndim = pattern_len;
+        oly.ndim = pattern.size();
         bool input_used[TensorLayout::MAX_NDIM] = {0};
-        for (uint32_t idx = 0; idx < pattern_len; ++idx) {
+        for (uint32_t idx = 0; idx < pattern.size(); ++idx) {
             auto i = pattern[idx];
             if (i < 0) {
                 oly.shape[idx] = 1;
@@ -179,53 +198,20 @@ void JITExecutor::do_dimshuffle() {
         return oly;
     };
 
-    // DFS to make sure traverse the dimshuffles in one branch
-    std::unordered_set<VarNode*> visited;
-    std::vector<cg::OperatorNodeBase*> stack(0);
-    std::vector<size_t> idx(0);  // input index
-    stack.push_back(m_internal_graph->output()->owner_opr());
-    idx.push_back(0);
-
-    while (!stack.empty()) {
-        if (idx.back() < stack.back()->input().size() &&
-            !visited.count(stack.back()->input(idx.back()))) {
-            visited.insert(stack.back()->input(idx.back()));
-            stack.push_back(stack.back()->input(idx.back())->owner_opr());
-            if (stack.back()->same_type<JITPlaceholder>()) {
-                auto jitph = gopt::try_cast_as_op<JITPlaceholder>(stack.back());
-                size_t input_id = jitph->input_id();
-                auto&& input = m_args.inputs[input_id];
-
-                for (int i = stack.size() - 1; i >= 0; --i) {
-                    if (stack[i]->same_type<opr::Dimshuffle>()) {
-                        auto param =
-                                stack[i]->cast_final_safe<opr::Dimshuffle>()
-                                        .param();
-
-                        mgb_assert(input.layout.ndim == param.ndim,
-                                   "input ndim mismatch for Dimshuffle: "
-                                   "expect=%u "
-                                   "actual=%zu",
-                                   param.ndim, input.layout.ndim);
-                        auto dimshuffled_layout = get_dimshuffled_layout(
-                                input.layout, param.pattern, param.pattern_len);
-                        input.layout = dimshuffled_layout;
-                    }
-                }
-
-                stack.pop_back();
-                ++idx.back();
-            } else {
-                idx.push_back(0);
-            }
-        } else {
-            stack.pop_back();
-            idx.pop_back();
-            if (!stack.empty())
-                ++idx.back();
-        }
+    for (auto&& i : m_internal_graph->placeholders()) {
+        auto&& input = m_args.inputs[i->input_id()];
+        auto&& iter = m_jitph2dimshuffle.find(i);
+        if (iter == m_jitph2dimshuffle.end()) continue;
+        auto&& param = iter->second;
+        mgb_assert(input.layout.ndim == param.second,
+                   "input ndim mismatch for Dimshuffle: "
+                   "expect=%u "
+                   "actual=%zu",
+                   param.second, input.layout.ndim);
+        auto dimshuffled_layout = get_dimshuffled_layout(
+                input.layout, param.first);
+        input.layout = dimshuffled_layout;
     }
-
 }
 
 void JITExecutor::update_args() {
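
A minimal, self-contained C++ sketch of the shape mapping that get_dimshuffled_layout
above computes (dtype, strides and the input_used bookkeeping are omitted;
dimshuffled_shape is an illustrative helper name, not an API in this patch):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Output axis idx takes input axis pattern[idx]; a negative entry inserts
    // a new axis of extent 1, mirroring the loop in get_dimshuffled_layout.
    std::vector<size_t> dimshuffled_shape(const std::vector<size_t>& in_shape,
                                          const std::vector<int>& pattern) {
        std::vector<size_t> out(pattern.size());
        for (size_t idx = 0; idx < pattern.size(); ++idx) {
            int i = pattern[idx];
            out[idx] = i < 0 ? 1 : in_shape[i];
        }
        return out;
    }

    int main() {
        // same pattern and shapes as the TestJITNvrtc DimshuffleGrad case
        // added later in this patch
        assert((dimshuffled_shape({1, 2, 3, 4}, {1, 2, 3, 0}) ==
                std::vector<size_t>{2, 3, 4, 1}));
    }
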
@@ -259,7 +245,9 @@ void JITExecutor::update_args() {
     }
 
     //! dimshuffle opr need to change the input.
-    do_dimshuffle();
+    if (has_dimshuffle()) {
+        do_dimshuffle();
+    }
 
     if (m_compiler->property().contain_flag(CPFlag::NEED_INPUT_COLLAPSE)) {
         // collective collapse datum layout, try to reduce the output ndim
@@ -304,6 +292,82 @@ void JITExecutor::update_args() {
     m_args.need_update = false;
 }
 
+void JITExecutor::prepare_dimshuffle() {
+    std::unordered_set<cg::OperatorNodeBase*> visited;
+    std::vector<cg::OperatorNodeBase*> stack(0);
+    std::vector<size_t> idx(0);  // input index
+    using Param = DimshuffleParam;
+    std::vector<Param> dimshuffle_stack;
+
+    auto merge_dimshuffle = [&](const opr::Dimshuffle::Param& p) {
+        if (dimshuffle_stack.empty()) {
+            dimshuffle_stack.emplace_back();
+            auto&& param = dimshuffle_stack.back();
+            param.first.insert(param.first.end(), p.pattern,
+                               p.pattern + p.pattern_len);
+            param.second = p.ndim;
+        } else {
+            // merge(p, src) -> param, so that performing
+            // dimshuffle(dimshuffle(x, p), src) is equivalent to performing
+            // dimshuffle(x, param)
+            dimshuffle_stack.emplace_back();
+            auto&& param = dimshuffle_stack.back();
+            auto&& src = dimshuffle_stack[dimshuffle_stack.size() - 2];
+            mgb_assert(p.pattern_len == src.second);
+            param.first.resize(src.first.size());
+            for (size_t i = 0; i < src.first.size(); ++ i) {
+                if (src.first[i] == -1) {
+                    param.first[i] = -1;
+                } else {
+                    param.first[i] = p.pattern[src.first[i]];
+                }
+            }
+            param.second = p.ndim;
+        }
+    };
+    auto push_back = [&](cg::OperatorNodeBase* op) {
+        mgb_assert(!op->same_type<JITPlaceholder>());
+        if (auto o = op->try_cast_final<opr::Dimshuffle>()) {
+            merge_dimshuffle(o->param());
+        }
+        stack.push_back(op);
+        idx.push_back(0);
+    };
+    auto pop_back = [&]() {
+        auto&& op = stack.back();
+        if (op->same_type<opr::Dimshuffle>()) {
+            dimshuffle_stack.pop_back();
+        }
+        stack.pop_back();
+        idx.pop_back();
+    };
+
+    push_back(m_internal_graph->output()->owner_opr());
+
+    while (!stack.empty()) {
+        if (idx.back() < stack.back()->input().size()) {
+            auto cur_opr = stack.back()->input(idx.back())->owner_opr();
+            if (visited.insert(cur_opr).second) {
+                if (auto jitph = cur_opr->try_cast_final<JITPlaceholder>()) {
+                    if (!dimshuffle_stack.empty()) {
+                        mgb_assert(
+                                m_jitph2dimshuffle.emplace(jitph, dimshuffle_stack.back()).second,
+                                "already visited JITPlaceholder %s",
+                                jitph->cname());
+                    }
+                    ++ idx.back();
+                } else {
+                    push_back(cur_opr);
+                }
+            } else {
+                ++ idx.back();
+            }
+        } else {
+            pop_back();
+            if (!stack.empty())
+                ++ idx.back();
+        }
+    }
+}
+
 const JITExecutor::Args& JITExecutor::args() const {
     if (m_args.need_update) {
         const_cast<JITExecutor*>(this)->update_args();
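
merge_dimshuffle above folds a chain of Dimshuffle oprs into a single pattern
via c[i] = p[src[i]], with -1 entries (new axes of extent 1) passed through.
A self-contained sketch of that rule using shape-only stand-ins (compose and
apply are illustrative helpers, not part of this patch):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // c[i] = p[src[i]]; -1 entries of src pass through unchanged.
    std::vector<int> compose(const std::vector<int>& p,
                             const std::vector<int>& src) {
        std::vector<int> c(src.size());
        for (size_t i = 0; i < src.size(); ++i)
            c[i] = src[i] < 0 ? -1 : p[src[i]];
        return c;
    }

    // shape-only application of a single pattern, as in the previous sketch
    std::vector<size_t> apply(const std::vector<size_t>& s,
                              const std::vector<int>& pat) {
        std::vector<size_t> out(pat.size());
        for (size_t i = 0; i < pat.size(); ++i)
            out[i] = pat[i] < 0 ? 1 : s[pat[i]];
        return out;
    }

    int main() {
        std::vector<size_t> x{5, 7, 9};
        std::vector<int> p{2, 0, 1}, src{1, 2, 0};
        // applying p and then src equals applying the merged pattern once
        assert(apply(apply(x, p), src) == apply(x, compose(p, src)));
    }
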
@@ -383,6 +447,56 @@ megdnn::TensorShape JITExecutor::broadcasted_input_shape() const {
 
 #if MGB_ENABLE_GRAD
 
+namespace {
+class InternalGraphRewriter {
+    ThinHashMap<VarNode*, VarNode*> m_var_map;
+    VarNode* m_dest_var;
+    VarNodeArray m_new_inp;
+    VarNode* get_var(VarNode* var) {
+        auto&& iter = m_var_map.find(var);
+        if (iter != m_var_map.end()) {
+            return iter->second;
+        }
+        return var;
+    }
+public:
+    InternalGraphRewriter(VarNode* dest_var)
+            : m_dest_var{dest_var} {}
+    void iter(thin_function<void(cg::OperatorNodeBase*)>&& cb) {
+        m_var_map.clear();
+        cg::DepOprIter{std::move(cb)}.add(m_dest_var->owner_opr());
+        m_dest_var = get_var(m_dest_var);
+    }
+    VarNode* dest_var() {
+        return m_dest_var;
+    }
+    void replace_var(VarNode* src, VarNode* dst) {
+        // Note: do not perform var replacement recursively.
+        // When we extract the used placeholders from the internal graph, we
+        // don't treat replacement pairs (a to b), (b to c) as a replacement
+        // chain (a to b to c), but as an injective mapping from (a, b) to
+        // (b, c); in other cases, each var node is passed as \p src or
+        // \p dst at most once.
+        m_var_map[src] = dst;
+    }
+    void auto_replace_outputs(cg::OperatorNodeBase* opr) {
+        // in the JIT internal graph, the output size of an opr is always 1
+        mgb_assert(opr->usable_output().size() == 1);
+        m_new_inp.clear();
+        bool need_replace = false;
+        for (auto&& i : opr->input()) {
+            auto inp = get_var(i);
+            m_new_inp.push_back(inp);
+            need_replace |= (inp != i);
+        }
+        if (need_replace) {
+            auto new_op = serialization::copy_opr_shallow(*opr, m_new_inp);
+            replace_var(opr->output(0), new_op->output(0));
+        }
+    }
+};
+}  // anonymous namespace
+
 MGB_IMPL_OPR_GRAD(JITExecutor) {
     VarNodeArray grad_inputs;
     for (auto input : opr.input())
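
The replace_var note above relies on lookups being non-recursive: a recorded
pair (a to b) is never chained with a later pair (b to c). A toy model of just
that lookup behaviour, with plain ints standing in for VarNode* (ToyRewriter is
illustrative only):

    #include <cassert>
    #include <unordered_map>

    // get() resolves a var through at most one recorded replacement;
    // chains are deliberately not followed.
    struct ToyRewriter {
        std::unordered_map<int, int> var_map;
        int get(int var) const {
            auto it = var_map.find(var);
            return it == var_map.end() ? var : it->second;
        }
        void replace_var(int src, int dst) { var_map[src] = dst; }
    };

    int main() {
        ToyRewriter r;
        r.replace_var(1, 2);
        r.replace_var(2, 3);
        assert(r.get(1) == 2);  // not 3: replacement is non-recursive
        assert(r.get(2) == 3);
        assert(r.get(4) == 4);  // unmapped vars resolve to themselves
    }
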
@@ -404,49 +518,120 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
     if (gx.node()->owner_opr()->same_type<opr::InvalidGrad>()) {
         return opr::InvalidGrad::make(opr, wrt_idx);
     }
+    // early return if the grad expression is a single node
+    for (size_t i = 0; i < fwd_igraph_ptr->placeholders().size(); ++i) {
+        if (gx.node() == fwd_igraph_ptr->placeholders()[i]->output(0)) {
+            return grad_inputs[i];
+        }
+    }
+    if (gx.node() == og_ph.node()) {
+        return out_grad[0];
+    }
+    if (gx.node() == fwd_igraph_ptr->output()) {
+        return opr.output(0);
+    }
+    if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(gx.node()->owner_opr())) {
+        HostTensorND hval{grad_inputs[0]->comp_node()};
+        hval.copy_from(imm->value()).sync();
+        return opr::ImmutableTensor::make(*imm->owner_graph(), hval).node();
+    }
+
+    // replace the output var in the internal graph with the output
+    // placeholder, so we can forward opr.output (computed by the forward
+    // JITExecutor) into the placeholder to avoid redundant computation
+    InternalGraphRewriter rewriter{gx.node()};
+    rewriter.iter([&rewriter, &fwd_igraph_ptr,
+                   &output_ph](cg::OperatorNodeBase* opr) {
+        if (opr == fwd_igraph_ptr->output()->owner_opr()) {
+            rewriter.replace_var(opr->output(0), output_ph.node());
+            return;
+        }
+        rewriter.auto_replace_outputs(opr);
+    });
+
+    static auto expand_into_origin_graph = [](cg::OperatorNodeBase* opr,
+            InternalGraphRewriter& rewriter, const VarNodeArray& grad_inputs) {
+        if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
+            rewriter.replace_var(
+                    opr->output(0), grad_inputs.at(ph->input_id()));
+            return;
+        }
+        if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
+            HostTensorND hval{grad_inputs[0]->comp_node()};
+            hval.copy_from(imm->value()).sync();
+            rewriter.replace_var(opr->output(0),
+                    opr::ImmutableTensor::make(*opr->owner_graph(), hval).node());
+            return;
+        }
+        rewriter.auto_replace_outputs(opr);
+    };
+
     if (opr.compiler()->property().feature_bits & JITFeatureBits::REDUCE) {
         // expand the gradient graph into the original graph to handle bcast
         // oprs
-        ThinHashMap<VarNode*, VarNode*> old2new;
-        VarNodeArray new_inp;
-        auto on_opr = [&old2new, &grad_inputs,
-                       &new_inp](cg::OperatorNodeBase* opr) {
+        using namespace std::placeholders;
+        rewriter.iter(std::bind(expand_into_origin_graph, _1,
+                                std::ref(rewriter), std::cref(grad_inputs)));
+        return rewriter.dest_var();
+    } else {
+        VarNodeArray new_grad_inputs;
+        PlaceholderArray placeholders;
+        bool all_inp_const = true;
+        // gx may not depend on all JITPlaceholders, so we extract the used
+        // placeholders and build a new internal graph
+        rewriter.iter([&rewriter, &grad_inputs, &new_grad_inputs,
+                       &placeholders, &all_inp_const](cg::OperatorNodeBase* opr) {
             if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
-                old2new[opr->output(0)] = grad_inputs.at(ph->input_id());
-                return;
-            }
-            if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
-                HostTensorND hval{grad_inputs[0]->comp_node()};
-                hval.copy_from(imm->value()).sync();
-                old2new[opr->output(0)] =
-                        opr::ImmutableTensor::make(*opr->owner_graph(), hval)
-                                .node();
+                new_grad_inputs.push_back(grad_inputs[ph->input_id()]);
+                auto new_ph = JITPlaceholder::make(
+                        new_grad_inputs.back(), placeholders.size())
+                        .node()->owner_opr();
+                placeholders.push_back(new_ph->try_cast_final<JITPlaceholder>());
+                mgb_assert(placeholders.back());
+                rewriter.replace_var(opr->output(0), new_ph->output(0));
+                if (!cg::is_const_var_value(new_grad_inputs.back())) {
+                    all_inp_const = false;
+                }
                 return;
             }
-            new_inp.clear();
-            for (auto inp : opr->input()) {
-                new_inp.push_back(old2new.at(inp));
-            }
-            auto new_opr = serialization::copy_opr_shallow(*opr, new_inp);
-            old2new[opr->output(0)] = new_opr->output(0);
-        };
-        cg::DepOprIter{on_opr}.add(gx.node());
-        return old2new.at(gx.node());
-    } else {
-        PlaceholderArray placeholders = fwd_igraph_ptr->placeholders();
-        for (SymbolVar i : {output_ph, og_ph}) {
-            placeholders.push_back(
-                    &i.node()->owner_opr()->cast_final_safe<JITPlaceholder>());
+            rewriter.auto_replace_outputs(opr);
+        });
+        if (all_inp_const) {
+            // if all inputs are const, expand the grad graph into the origin
+            // graph by replacing placeholders with the const inputs, so it can
+            // benefit from static infer and const folding
+            using namespace std::placeholders;
+            rewriter.iter(std::bind(expand_into_origin_graph, _1,
+                                    std::ref(rewriter),
+                                    std::cref(new_grad_inputs)));
+            return rewriter.dest_var();
         }
-        for (size_t i = 0; i < placeholders.size(); ++i) {
-            if (gx.node() == placeholders[i]->output(0)) {
-                return grad_inputs[i];
+        gx = rewriter.dest_var();
+
+        auto shape_infer = fwd_igraph_ptr->shape_infer();
+        if (opr.has_dimshuffle()) {
+            auto&& iter = opr.dimshuffle_params().find(
+                    fwd_igraph_ptr->placeholders()[wrt_idx]);
+            if (iter != opr.dimshuffle_params().end()) {
+                auto&& pattern = iter->second.first;
+                auto&& ndim = iter->second.second;
+                std::vector<int> back(ndim, -1);
+                for (size_t i = 0; i < pattern.size(); i ++) {
+                    // outdim[i] is indim[j]
+                    auto j = pattern[i];
+                    if (j >= 0) {
+                        mgb_assert(back[j] == -1,
+                                   "taking grad for Dimshuffle with duplicated "
+                                   "input axis unsupported");
+                        back[j] = i;
+                    }
+                }
+                shape_infer = opr::Dimshuffle::make(
+                        shape_infer, back, pattern.size()).node();
             }
         }
         auto grad_ig = std::make_shared<InternalGraph>(
-                gx.node(), fwd_igraph_ptr->shape_infer(), nullptr,
+                gx.node(), shape_infer, nullptr,
                 std::move(placeholders));
-        auto grad_jit = JITExecutor::make(grad_ig, grad_inputs);
+        auto grad_jit = JITExecutor::make(grad_ig, new_grad_inputs);
 
         if (opr.input_broadcastable()[wrt_idx]) {
             grad_jit = opr::reduce_sum(
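
For the shape_infer adjustment above, the recorded forward pattern is inverted
so that the inferred shape is mapped back into the input's axis order. A
standalone sketch of that inversion (invert is an illustrative helper, not
part of this patch's API):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // If forward output axis i reads input axis j = pattern[i], the inverse
    // pattern maps axis j back to i; axes created by -1 entries have no
    // preimage and stay -1.
    std::vector<int> invert(const std::vector<int>& pattern, size_t ndim) {
        std::vector<int> back(ndim, -1);
        for (size_t i = 0; i < pattern.size(); ++i) {
            int j = pattern[i];
            if (j >= 0) {
                assert(back[j] == -1);  // duplicated input axis: unsupported
                back[j] = static_cast<int>(i);
            }
        }
        return back;
    }

    int main() {
        // pattern used by the new TestJITNvrtc DimshuffleGrad case
        assert((invert({1, 2, 3, 0}, 4) == std::vector<int>{3, 0, 1, 2}));
    }
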
diff --git a/src/jit/impl/placeholder_opr.cpp b/src/jit/impl/placeholder_opr.cpp
index 85a1000e..9de43199 100644
--- a/src/jit/impl/placeholder_opr.cpp
+++ b/src/jit/impl/placeholder_opr.cpp
@@ -26,7 +26,6 @@ JITPlaceholder::JITPlaceholder(VarNode* src_var, size_t id, InpType inp_type)
                   {}),
           m_inp_type{inp_type},
           m_id{id} {
-    add_equivalence_component<ScalarHash<size_t>>(m_id);
     mgb_assert(src_var->dtype().category() == DTypeCategory::FLOAT ||
                        src_var->dtype().category() == DTypeCategory::INT,
                "JIT can only be applied to float/int operators, got %s",
diff --git a/src/jit/include/megbrain/jit/executor_opr.h b/src/jit/include/megbrain/jit/executor_opr.h
index 9f9f5159..dabfede6 100644
--- a/src/jit/include/megbrain/jit/executor_opr.h
+++ b/src/jit/include/megbrain/jit/executor_opr.h
@@ -35,6 +35,7 @@ MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase) // {
     using ModeTrait = megdnn::Elemwise::ModeTrait;
 
     InternalGraphPtr m_internal_graph;
+    using DimshuffleParam = std::pair<std::vector<int>, uint32_t>;
 
 public:
     using Mode = opr::Elemwise::Mode;
@@ -112,6 +113,11 @@ public:
         return static_cast<bool>(m_feature_bits & JITFeatureBits::DIMSHUFFLE);
     }
 
+    const ThinHashMap<JITPlaceholder*, DimshuffleParam>&
+    dimshuffle_params() const {
+        return m_jitph2dimshuffle;
+    }
+
     //! get broadcasted shape of inputs
     megdnn::TensorShape broadcasted_input_shape() const;
 
@@ -124,8 +130,14 @@ private:
     Compiler* const m_compiler = nullptr;
     Executable* m_executable = nullptr;
    std::vector<bool> m_input_broadcastable;
+    // JITPlaceholder -> pair of (dimshuffle pattern, ndim);
+    // the DFS over the internal graph is done only once, in
+    // prepare_dimshuffle(), so the dimshuffle param to apply to a given
+    // JITPlaceholder can be looked up directly
+    ThinHashMap<JITPlaceholder*, DimshuffleParam> m_jitph2dimshuffle;
     void update_args();
     void do_dimshuffle();
+    void prepare_dimshuffle();
 
     NodeProp* do_make_node_prop() const override;
 };
diff --git a/src/jit/include/megbrain/jit/internal_graph.h b/src/jit/include/megbrain/jit/internal_graph.h
index 5b41d0df..94715536 100644
--- a/src/jit/include/megbrain/jit/internal_graph.h
+++ b/src/jit/include/megbrain/jit/internal_graph.h
@@ -61,8 +61,6 @@ public:
 
     const PlaceholderArray& placeholders() const { return m_placeholders; }
 
-    static InternalGraphPtr expand_excutor_op(const InternalGraphPtr&);
-
 private:
     // For compilation cache, if the output_for_cache is same means the
     // expression tree is same.
diff --git a/src/jit/test/fusion.cpp b/src/jit/test/fusion.cpp
index 60a4a7ac..82606b5b 100644
--- a/src/jit/test/fusion.cpp
+++ b/src/jit/test/fusion.cpp
@@ -1435,6 +1435,16 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
         funcs.second->execute();
         MGB_ASSERT_TENSOR_NEAR(host_y1, host_y2, 1e-3);
     }
+    {
+        FusionChecker checker{2,
+                              [](const SymbolVarArray& inp) -> SymbolVar {
+                                  auto var = opr::Dimshuffle::make(inp[0], {1, 2, 3, 0});
+                                  return inp[1] * var;
+                              },
+                              CompNode::load("gpu0")};
+        checker.set_jit_level(1)
+                .run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}});
+    }
 }
 
 #endif  // MGB_JIT
diff --git a/src/jit/test/helper.cpp b/src/jit/test/helper.cpp
index 6889c04d..c0311a5d 100644
--- a/src/jit/test/helper.cpp
+++ b/src/jit/test/helper.cpp
@@ -98,7 +98,7 @@ void FusionChecker::ensure_init_graph() {
     } else {
         ComputingGraph::Options opt;
         opt.graph_opt_level = 3;
-        opt.graph_opt.jit = 2;
+        opt.graph_opt.jit = m_jit_level;
         unpack_vector(gopt::GraphOptimizer{}
                               .add_preset_passes(true, nullptr, &opt)
                               .apply({{m_truth_y}})
diff --git a/src/jit/test/helper.h b/src/jit/test/helper.h
index 6c3b8eca..33d3e1b7 100644
--- a/src/jit/test/helper.h
+++ b/src/jit/test/helper.h
@@ -65,6 +65,13 @@ public:
         return *this;
     }
 
+    //! set the jit level; default is 2, see graph_opt.jit in the graph
+    //! options for more details
+    FusionChecker& set_jit_level(uint8_t jit_level) {
+        m_jit_level = jit_level;
+        return *this;
+    }
+
     /*!
      * \brief run and check correctness
      *
@@ -76,6 +83,7 @@ private:
     bool m_check_opr_type = true;
     bool m_direct_build = false;
     const size_t m_nr_input;
+    uint8_t m_jit_level = 2;
     const CompNode m_comp_node;
     HostTensorGenerator<> m_input_gen;
     SmallVector<std::shared_ptr<HostTensorND>> m_inputs_val;
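
For reference, a shape-only sketch of how the DimshuffleParam bookkeeping
declared in executor_opr.h is meant to be consumed, mirroring the ndim check
and pattern application in do_dimshuffle (an int key stands in for
JITPlaceholder*; this is illustrative, not MegBrain API):

    #include <cassert>
    #include <cstdint>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    // DimshuffleParam is pair<pattern, ndim>, as declared in the header above.
    using DimshuffleParam = std::pair<std::vector<int>, std::uint32_t>;

    int main() {
        std::unordered_map<int, DimshuffleParam> params;
        params.emplace(0, DimshuffleParam{{1, 2, 3, 0}, 4});

        std::vector<std::size_t> input_shape{1, 2, 3, 4};
        auto it = params.find(0);
        if (it != params.end()) {
            const auto& pattern = it->second.first;
            // same ndim check that do_dimshuffle performs before reshuffling
            assert(input_shape.size() == it->second.second);
            std::vector<std::size_t> shuffled(pattern.size());
            for (std::size_t i = 0; i < pattern.size(); ++i)
                shuffled[i] = pattern[i] < 0 ? 1 : input_shape[pattern[i]];
            assert((shuffled == std::vector<std::size_t>{2, 3, 4, 1}));
        }
    }
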