GitOrigin-RevId: eaad25a7ef
@@ -514,6 +514,16 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
     optimizer.add_passes_for_optimize_options(options().graph_opt, true);
     optimizer.apply_inplace(dest_vars);
 
+    if (sopr_stat.has_shape_hint) {
+        // FIXME(zhangxuanrun): strictly speaking, ShapeHints could and would
+        // have to be removed even when they occur in a subgraph
+        mgb_assert(!m_parent_graph, "cannot use ShapeHint in a subgraph");
+        // ShapeHint only carries compile-time information, so it always has
+        // to be removed before the computing sequence is built
+        gopt::GraphOptimizer opt;
+        opt.add_pass<gopt::RemoveShapeHintPass>();
+        opt.apply_inplace(dest_vars);
+    }
+
     const OprNodeArray* opr_seq = nullptr;
     CompSeqExtraInfo extra_info;
     cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);
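The hunk above strips any ShapeHint oprs from the endpoint vars right before the computing sequence is built, since the hints only matter during graph optimization. For reference, the same pass can also be run by hand on a set of endpoint vars; the sketch below is illustrative only, assumes a SymbolVar y whose graph may contain ShapeHint oprs, and mirrors the calls used in compile_prepare():

    // minimal sketch: strip ShapeHint oprs from a set of endpoint vars
    VarNodeArray dest_vars{y.node()};
    gopt::GraphOptimizer opt;
    opt.add_pass<gopt::RemoveShapeHintPass>();
    opt.apply_inplace(dest_vars);
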
@@ -564,6 +564,9 @@ void ExtraDependencyMerger::on_opr(OperatorNodeBase* opr) {
             sopr_stat->has_virtual_grad = true;
         }
 #endif
+        if (sopr_stat && opr->same_type<opr::ShapeHint>()) {
+            sopr_stat->has_shape_hint = true;
+        }
     }
 }
@@ -149,6 +149,7 @@ SymbolVar current_grad_target(ComputingGraph &graph);
 struct SpecialOprStat {
     bool has_virtual_grad = false;
+    bool has_shape_hint = false;
 };
 
 /*!
@@ -678,6 +678,11 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
         add_pass<ParamMergePass>();
         add_pass<FuseDeconvCvtPass>();
     }
 
+    if (inference_opt) {
+        // remove shape hint after inference optimization
+        add_pass<RemoveShapeHintPass>();
+    }
+
     return *this;
 }
@@ -1055,4 +1055,30 @@ void PackAllReduceReplacePass::insert_packed_oprs(
 
 #endif // MGB_ENABLE_OPR_MM
 
+/* ======================= RemoveShapeHintPass ====================== */
+
+const char* RemoveShapeHintPass::name() const {
+    return "remove_shape_hint";
+}
+
+void RemoveShapeHintPass::apply(OptState& opt) const {
+    MIDOUT_B("RemoveShapeHintPass::apply")
+    opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_DTYPE);
+    auto rewriter = opt.graph().make_rewriter();
+    auto on_opr = [&](OperatorNodeBase* opr) {
+        if (auto sh = try_cast_as_op<opr::ShapeHint>(opr)) {
+            auto inp = rewriter.get_var(sh->input(0));
+            rewriter.replace_var(sh->output(0), inp,
+                    mgb_cstr_log("remove shape hint"));
+            return;
+        }
+        rewriter.auto_replace_outputs(opr);
+    };
+    opt.graph().iter(on_opr);
+    rewriter.apply_inplace();
+    MIDOUT_E
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -141,6 +141,12 @@ namespace gopt {
                 ThinHashMap<VarNode*, VarNode*>& replace_map, int priority);
     };
 
+    class RemoveShapeHintPass final : public Pass {
+        public:
+            const char* name() const override;
+            void apply(OptState& opt) const override;
+    };
+
 } // namespace gopt
 } // namespace mgb
@@ -840,4 +840,57 @@ SymbolVar RequireInputDynamicStorage::make(const SymbolVar input,
             input.node(), config);
 }
 
+/* ===================== ShapeHint ===================== */
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShapeHint);
+
+void ShapeHint::scn_do_execute() {
+    // ShapeHint only carries compile-time shape information and should have
+    // been removed before the graph gets executed
+    mgb_assert(0);
+}
+
+void ShapeHint::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+
+    auto infer_shp = [this](TensorShape& dest, const InpVal&) -> bool {
+        const TensorShape* inferred = nullptr;
+        if (cg::is_static_var_shape(input(0))) {
+            inferred = owner_graph()->static_infer_manager()
+                    .infer_shape_fallible(input(0));
+        }
+        if (inferred) {
+            // the statically inferred shape takes precedence over the hint
+            dest = *inferred;
+            if (!dest.eq_shape(m_shape)) {
+                mgb_log_warn(
+                        "given shape hint on var %s is different from inferred shape, "
+                        "hint %s vs inferred %s", cg::dump_var_info({input(0)}).c_str(),
+                        m_shape.to_string().c_str(), dest.to_string().c_str());
+            }
+        } else {
+            dest = m_shape;
+        }
+        return dest.ndim;
+    };
+
+    owner_graph()->static_infer_manager().register_shape_infer(
+            output(0),
+            {m_is_const ? SourceType::CONSTANT : SourceType::MUTABLE, {}, infer_shp});
+}
+
+ShapeHint::ShapeHint(VarNode* inp, TensorShape shape, bool is_const,
+                     const OperatorNodeConfig& config)
+        : Super{inp->owner_graph(), config, "shape_hint", {inp}},
+          m_shape(shape), m_is_const(is_const) {
+    add_input({inp});
+    add_output(None);
+}
+
+SymbolVar ShapeHint::make(SymbolVar inp, TensorShape shape, bool is_const,
+                          const OperatorNodeConfig& config) {
+    return inp.insert_single_output_opr<ShapeHint>(
+            inp.node(), shape, is_const, config);
+}
+
+#if MGB_ENABLE_GRAD
+MGB_IMPL_OPR_GRAD(ShapeHint) {
+    // since the shape of output(0) can be inferred, there is no need to put
+    // another hint on out_grad(0)
+    return out_grad.at(0);
+}
+#endif
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
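The shape-inference rule implemented in init_output_static_infer_desc() above is: if the shape of input(0) can be statically inferred, the inferred shape wins (with a warning when it differs from the hint); the hinted shape is only used when the input shape is unknown at compile time. A minimal sketch of that behaviour, using only APIs exercised by the test further below (host_x and the shapes are illustrative values):

    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();
    auto host_x = gen({16});
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    // statically inferable input: the hint {32} is ignored, the symbolic shape stays {16}
    auto y_static = opr::ShapeHint::make(x, TensorShape{32});
    // dynamic input: the shape cannot be inferred, so the hint {32} is used
    auto y_dynamic = opr::ShapeHint::make(opr::MarkDynamicVar::make(x), TensorShape{32});
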
@@ -90,4 +90,15 @@ decl_opr(
     params='Empty'
 )
 
+decl_raw_opr(
+    'shape_hint',
+    desc='a special op that provides a shape hint; only used during graph compilation',
+    inputs=[Doc('input', 'input var on which the shape hint is given'),
+            Doc('shape', 'the hinted shape', 'list of int'),
+            Doc('is_const', 'whether to treat the given shape as constant', 'bool', 'False')],
+    body=[
+        'output = _mgb._Opr.shape_hint(input, shape, is_const, config)'
+    ]
+)
+
 # vim: ft=python
@@ -153,6 +153,17 @@ namespace opr {
 #endif
     MGB_SEREG_OPR(PersistentOutputStorage, 1);
 
+    cg::OperatorNodeBase* opr_shallow_copy_shape_hint(
+            const serialization::OprShallowCopyContext &ctx,
+            const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+            const OperatorNodeConfig &config) {
+        auto &&opr = opr_.cast_final_safe<ShapeHint>();
+        mgb_assert(inputs.size() == 1);
+        return ShapeHint::make(inputs[0], opr.shape(), opr.is_const(), config)
+                .node()->owner_opr();
+    }
+    MGB_REG_OPR_SHALLOW_COPY(ShapeHint, opr_shallow_copy_shape_hint);
+
 } // namespace opr
 } // namespace mgb
@@ -512,6 +512,27 @@ public:
             const OperatorNodeConfig& config = {});
 };
 
+/*!
+ * \brief a special op that provides a shape hint; only used during graph
+ *      compilation (gopt)
+ */
+MGB_DEFINE_OPR_CLASS(ShapeHint, cg::SingleCNOperatorNodeBase) // {
+    TensorShape m_shape;
+    bool m_is_const;
+
+    void scn_do_execute() override;
+    void init_output_static_infer_desc() override;
+
+    public:
+        ShapeHint(VarNode* inp, const TensorShape shape, bool is_const,
+                  const OperatorNodeConfig& config);
+
+        static SymbolVar make(SymbolVar inp, const TensorShape shape,
+                              bool is_const = false,
+                              const OperatorNodeConfig& config = {});
+
+        TensorShape shape() const { return m_shape; }
+        bool is_const() const { return m_is_const; }
+};
+
 } // namespace opr
 } // namespace mgb
@@ -12,6 +12,7 @@
 #include "megbrain/opr/utility.h"
 #include "megbrain/gopt/framework.h"
 #include "megbrain/opr/io.h"
+#include "megbrain/serialization/opr_shallow_copy.h"
 #include "megbrain/test/helper.h"
 
 using namespace mgb;
 
@@ -467,4 +468,64 @@ TEST(TestOprUtility, RequireInputDynamicStorage) {
     ASSERT_LT(nr_opr(func), nr0);
 }
+TEST(TestOprUtility, ShapeHint) {
+    HostTensorGenerator<> gen;
+    HostTensorGenerator<dtype::Int32> gen_int;
+    constexpr size_t length = 233;
+
+    { // basic
+        for (bool dynamic : {false, true}) {
+            auto host_x = gen_int({length});
+            auto graph = ComputingGraph::make();
+            SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x),
+                      x_shape_hint, y;
+            if (dynamic) {
+                x_shape_hint = opr::ShapeHint::make(
+                        opr::MarkDynamicVar::make(x), TensorShape{length * 2});
+            } else {
+                x_shape_hint = opr::ShapeHint::make(x, TensorShape{length * 2});
+            }
+            y = x_shape_hint * 2 + 1;
+            if (dynamic) {
+                // shape of a dynamic var cannot be inferred, so the hint is used
+                ASSERT_TRUE(y.shape().eq_shape({length * 2}));
+            } else {
+                // statically inferred shape takes precedence over the hint
+                ASSERT_TRUE(y.shape().eq_shape({length}));
+            }
+            HostTensorND host_y;
+            auto func = graph->compile({make_callback_copy(y, host_y)});
+            func->execute();
+            ASSERT_TRUE(host_y.shape().eq_shape({length}));
+            for (size_t i = 0; i < length; ++i) {
+                ASSERT_EQ(host_x->ptr<int32_t>()[i] * 2 + 1,
+                          host_y.ptr<int32_t>()[i]);
+            }
+        }
+    }
+
+    { // shallow copy
+        auto graph = ComputingGraph::make();
+        auto host_x = gen({length});
+        SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x),
+                  y = opr::ShapeHint::make(x, TensorShape{length * 2}),
+                  x_unknown = opr::MarkDynamicVar::make(x),
+                  y_copy = serialization::copy_opr_shallow(
+                          *y.node()->owner_opr(), {x_unknown.node()})->output(0);
+        ASSERT_TRUE(y.shape().eq_shape({length}));
+        ASSERT_TRUE(y_copy.shape().eq_shape({length * 2}));
+    }
+
+    { // grad
+        auto host_x = gen({1}), host_y = gen({1});
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Host2DeviceCopy::make(*graph, host_y),
+             x_shape_hint = opr::ShapeHint::make(
+                     opr::MarkDynamicVar::make(x), TensorShape{1}),
+             y_shape_hint = opr::ShapeHint::make(y, TensorShape{1}),
+             t = x_shape_hint * y_shape_hint;
+        HostTensorND host_gx, host_gy;
+        auto func = graph->compile({
+                make_callback_copy(cg::grad(t, x), host_gx),
+                make_callback_copy(cg::grad(t, y), host_gy)});
+        func->execute();
+        ASSERT_TRUE(host_gx.shape().is_scalar());
+        ASSERT_TRUE(host_gy.shape().is_scalar());
+        ASSERT_FLOAT_EQ(*host_x->ptr<float>(), *host_gy.ptr<float>());
+        ASSERT_FLOAT_EQ(*host_y->ptr<float>(), *host_gx.ptr<float>());
+    }
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}