diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 96d88f6e61..be9c8f787a 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -52,7 +52,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); - zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor); + zero_like_fill_zero_ = + MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM); // ops eliminate item_tuple_eliminate_ = @@ -81,7 +82,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue}); - replace_refkey_by_param_ = MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>); + + replace_refkey_by_param_ = + MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); // Gradient transforms diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index a0faa2bf46..24339ddb84 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -31,14 +31,14 @@ namespace mindspore { /* namespace to support opt */ namespace opt { -SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, - const PrimitivePtr& prim) { +SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim, + const RenormAction& renorm_action) { 
auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); }; - return std::make_shared<Substitution>(transform, name, fn); + return std::make_shared<Substitution>(transform, name, fn, renorm_action); } SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, - const std::vector<PrimitivePtr>& prims) { + const std::vector<PrimitivePtr>& prims, const RenormAction& renorm_action) { auto fn = [prims](const AnfNodePtr& node) -> bool { if (!node->isa<CNode>()) { return false; } @@ -52,12 +52,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std:: return false; }; - return std::make_shared<Substitution>(transform, name, fn); + return std::make_shared<Substitution>(transform, name, fn, renorm_action); } SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, - const PredicateFuncType& predicate) { - return std::make_shared<Substitution>(transform, name, predicate); + const PredicateFuncType& predicate, const RenormAction& renorm_action) { + return std::make_shared<Substitution>(transform, name, predicate, renorm_action); } AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const { @@ -74,6 +74,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode } } #endif + if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) { + if (renorm_action_ == FORCE_RENORM) { + optimizer->add_node_to_renormalize(result); + } else { + // renorm_action_ is CHECK_RENORM + if (result->abstract() == nullptr) { + optimizer->add_node_to_renormalize(result); + } + } + } return result; } diff --git a/mindspore/ccsrc/optimizer/opt.h b/mindspore/ccsrc/optimizer/opt.h index bd548645f4..24191998e8 100644 --- a/mindspore/ccsrc/optimizer/opt.h +++ b/mindspore/ccsrc/optimizer/opt.h @@ -36,24 +36,34 @@ using OptimizerWeakPtr = std::weak_ptr<Optimizer>; using PredicateFuncType = std::function<bool(const AnfNodePtr &)>; using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>; +// Define the interaction mode between an Optimize pass and Renormalize 
pass +// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed +// CHECK_RENORM: check if the new node is un-typed to decide if the next Renormalize will be executed +enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; + class Substitution { public: TransformFuncType transform_{nullptr}; std::string name_; PredicateFuncType predicate_{nullptr}; - explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate) - : transform_(transform), name_(name), predicate_(predicate) {} + // an enum to mark this Substitution relation to renormalize pass + RenormAction renorm_action_; + explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, + const RenormAction &renorm_action) + : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} ~Substitution() = default; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; }; using SubstitutionPtr = std::shared_ptr<Substitution>; -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim); +SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, + const RenormAction &action_renorm = CHECK_RENORM); SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, - const std::vector<PrimitivePtr> &prims); + const std::vector<PrimitivePtr> &prims, + const RenormAction &action_renorm = CHECK_RENORM); SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, - const PredicateFuncType &predicate); + const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); class SubstitutionList { public: diff --git a/mindspore/ccsrc/optimizer/optimizer.h b/mindspore/ccsrc/optimizer/optimizer.h index d821e826cf..f67466efba 100644 --- 
a/mindspore/ccsrc/optimizer/optimizer.h +++ b/mindspore/ccsrc/optimizer/optimizer.h @@ -87,11 +87,12 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; class Optimizer : public std::enable_shared_from_this<Optimizer> { public: Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) - : name_(name), resource_(resource_ptr), run_only_once_(false) {} + : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {} virtual ~Optimizer() = default; void Init(const OptPassGroupMap &passes, bool run_only_once) { run_only_once_ = run_only_once; + is_watch_renormalize_ = false; for (auto &iter : passes) { const std::string &name = iter.first; @@ -118,9 +119,13 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { } static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, - const OptPassGroupMap &passes, bool run_only_once = false) { + const OptPassGroupMap &passes, bool run_only_once = false, + bool watch_renormalize = false) { OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr); optimizer->Init(passes, run_only_once); + if (watch_renormalize) { + optimizer->enable_watch_renormalize(); + } return optimizer; } @@ -138,7 +143,16 @@ if (opt.is_renormalize()) { auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_); if (resource_ptr != nullptr) { - func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec); + if (is_watch_renormalize_) { + if (untyped_nodes_.size() > 0) { + func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec); + clear_untyped_nodes(); + } else { + MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty."; + } + } else { + func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec); + } } } else if (opt(func_graph, shared_from_this())) { changes = true; @@ -180,12 +194,26 @@ 
const std::string name() const { return name_; } + void add_node_to_renormalize(AnfNodePtr anode) { + if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) { + untyped_nodes_.push_back(anode); + } + } + + void clear_untyped_nodes() { untyped_nodes_.clear(); } + + void enable_watch_renormalize() { is_watch_renormalize_ = true; } + void disable_watch_renormalize() { is_watch_renormalize_ = false; } + bool is_watch_renormalize() { return is_watch_renormalize_; } + private: const std::string name_; pipeline::ResourceBasePtr resource_; std::vector<OptPass> passes_; std::vector<std::string> pass_names_; bool run_only_once_; + std::vector<AnfNodePtr> untyped_nodes_; + bool is_watch_renormalize_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index 9248590f27..b3eda4c37b 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -185,8 +185,8 @@ void InitOpt(const ResourcePtr& res) { if (g_pass_opts.size() == 0) { opt::irpass::OptimizeIRPassLib irpass; g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); - g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass)); - g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass)); + g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); + g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); } }