Signed-off-by: Hoai Linh Tran h00472437 <hoai.linh.tran@huawei.com> Add optimizer checking: For a group of passes in "optimizer", if flagged then it will check and collect the newly generated nodes without types (i.e. abstract() == nullptr). Before calling Renormalize(), the optimizer will check if there is any node needed retyping. If not the Renormalize pass will not be called. Add checking for non-null abstract but still needs renorm; Add flags to Substitution to help watching Renormalize Simpler pass result checker, change Bool to Enum typetags/v0.2.0-alpha
| @@ -52,7 +52,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | |||
| {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, | |||
| prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | |||
| zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor); | |||
| zero_like_fill_zero_ = | |||
| MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM); | |||
| // ops eliminate | |||
| item_tuple_eliminate_ = | |||
| @@ -81,7 +82,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); | |||
| get_make_ref_eliminate_ = | |||
| MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue}); | |||
| replace_refkey_by_param_ = MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>); | |||
| replace_refkey_by_param_ = | |||
| MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); | |||
| replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); | |||
| // Gradient transforms | |||
| @@ -31,14 +31,14 @@ | |||
| namespace mindspore { | |||
| /* namespace to support opt */ | |||
| namespace opt { | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, | |||
| const PrimitivePtr& prim) { | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim, | |||
| const RenormAction& renorm_action) { | |||
| auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); }; | |||
| return std::make_shared<Substitution>(transform, name, fn); | |||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | |||
| } | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, | |||
| const std::vector<PrimitivePtr>& prims) { | |||
| const std::vector<PrimitivePtr>& prims, const RenormAction& renorm_action) { | |||
| auto fn = [prims](const AnfNodePtr& node) -> bool { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| @@ -52,12 +52,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std:: | |||
| return false; | |||
| }; | |||
| return std::make_shared<Substitution>(transform, name, fn); | |||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | |||
| } | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, | |||
| const PredicateFuncType& predicate) { | |||
| return std::make_shared<Substitution>(transform, name, predicate); | |||
| const PredicateFuncType& predicate, const RenormAction& renorm_action) { | |||
| return std::make_shared<Substitution>(transform, name, predicate, renorm_action); | |||
| } | |||
| AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const { | |||
| @@ -74,6 +74,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode | |||
| } | |||
| } | |||
| #endif | |||
| if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) { | |||
| if (renorm_action_ == FORCE_RENORM) { | |||
| optimizer->add_node_to_renormalize(result); | |||
| } else { | |||
| // renorm_action_ is CHECK_RENORM | |||
| if (result->abstract() == nullptr) { | |||
| optimizer->add_node_to_renormalize(result); | |||
| } | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| @@ -36,24 +36,34 @@ using OptimizerWeakPtr = std::weak_ptr<Optimizer>; | |||
| using PredicateFuncType = std::function<bool(const AnfNodePtr &)>; | |||
| using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>; | |||
| // Define the interaction mode between an Optimize pass and Renormalize pass | |||
| // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed | |||
| // CHECK_RENORM: check if the new node is un-typed to decide if the next Renormalize will be executted | |||
| enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; | |||
| class Substitution { | |||
| public: | |||
| TransformFuncType transform_{nullptr}; | |||
| std::string name_; | |||
| PredicateFuncType predicate_{nullptr}; | |||
| explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate) | |||
| : transform_(transform), name_(name), predicate_(predicate) {} | |||
| // an enum to mark this Substitution relation to renormalize pass | |||
| RenormAction renorm_action_; | |||
| explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, | |||
| const RenormAction &renorm_action) | |||
| : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} | |||
| ~Substitution() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; | |||
| }; | |||
| using SubstitutionPtr = std::shared_ptr<Substitution>; | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim); | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, | |||
| const RenormAction &action_renorm = CHECK_RENORM); | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| const std::vector<PrimitivePtr> &prims); | |||
| const std::vector<PrimitivePtr> &prims, | |||
| const RenormAction &action_renorm = CHECK_RENORM); | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| const PredicateFuncType &predicate); | |||
| const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); | |||
| class SubstitutionList { | |||
| public: | |||
| @@ -87,11 +87,12 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; | |||
| class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| public: | |||
| Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) | |||
| : name_(name), resource_(resource_ptr), run_only_once_(false) {} | |||
| : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {} | |||
| virtual ~Optimizer() = default; | |||
| void Init(const OptPassGroupMap &passes, bool run_only_once) { | |||
| run_only_once_ = run_only_once; | |||
| is_watch_renormalize_ = false; | |||
| for (auto &iter : passes) { | |||
| const std::string &name = iter.first; | |||
| @@ -118,9 +119,13 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| } | |||
| static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, | |||
| const OptPassGroupMap &passes, bool run_only_once = false) { | |||
| const OptPassGroupMap &passes, bool run_only_once = false, | |||
| bool watch_renormalize = false) { | |||
| OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr); | |||
| optimizer->Init(passes, run_only_once); | |||
| if (watch_renormalize) { | |||
| optimizer->enable_watch_renormalize(); | |||
| } | |||
| return optimizer; | |||
| } | |||
| @@ -138,7 +143,16 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| if (opt.is_renormalize()) { | |||
| auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_); | |||
| if (resource_ptr != nullptr) { | |||
| func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec); | |||
| if (is_watch_renormalize_) { | |||
| if (untyped_nodes_.size() > 0) { | |||
| func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec); | |||
| clear_untyped_nodes(); | |||
| } else { | |||
| MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty."; | |||
| } | |||
| } else { | |||
| func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec); | |||
| } | |||
| } | |||
| } else if (opt(func_graph, shared_from_this())) { | |||
| changes = true; | |||
| @@ -180,12 +194,26 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| const std::string name() const { return name_; } | |||
| void add_node_to_renormalize(AnfNodePtr anode) { | |||
| if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) { | |||
| untyped_nodes_.push_back(anode); | |||
| } | |||
| } | |||
| void clear_untyped_nodes() { untyped_nodes_.clear(); } | |||
| void enable_watch_renormalize() { is_watch_renormalize_ = true; } | |||
| void disable_watch_renormalize() { is_watch_renormalize_ = false; } | |||
| bool is_watch_renormalize() { return is_watch_renormalize_; } | |||
| private: | |||
| const std::string name_; | |||
| pipeline::ResourceBasePtr resource_; | |||
| std::vector<OptPass> passes_; | |||
| std::vector<std::string> pass_names_; | |||
| bool run_only_once_; | |||
| std::vector<AnfNodePtr> untyped_nodes_; | |||
| bool is_watch_renormalize_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -185,8 +185,8 @@ void InitOpt(const ResourcePtr& res) { | |||
| if (g_pass_opts.size() == 0) { | |||
| opt::irpass::OptimizeIRPassLib irpass; | |||
| g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); | |||
| g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass)); | |||
| g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass)); | |||
| g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); | |||
| g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); | |||
| g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); | |||
| } | |||
| } | |||