You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

program_specialize.h 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_
  19. #define PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_
  20. #include <memory>
  21. #include <string>
  22. #include <stdexcept>
  23. #include <unordered_set>
  24. #include <unordered_map>
  25. #include <utility>
  26. #include <vector>
  27. #include "ir/anf.h"
  28. #include "ir/func_graph_cloner.h"
  29. #include "pipeline/static_analysis/evaluator.h"
  30. namespace mindspore {
  31. namespace abstract {
  32. enum SpecializeStatusCode {
  33. kSpecializeSuccess = 0,
  34. kSpecializeFindUniqueArgvalDead = 1, // Dead Node
  35. kSpecializeFindUniqueArgvalPoly = 2, // Poly Node
  36. kSpecializeFailure = 0xFF
  37. };
  38. class FuncGraphSpecializer;
  39. // Specialize a func graph using analyzed abstract values.
  40. class ProgramSpecializer {
  41. public:
  42. explicit ProgramSpecializer(const std::shared_ptr<AnalysisEngine> &engine) : engine_(engine) {
  43. mng_ = engine_->func_graph_manager();
  44. }
  45. ~ProgramSpecializer() = default;
  46. // Run the program specializer on the topmost graph in the given context.
  47. FuncGraphPtr Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context);
  48. const std::unordered_set<AnfNodePtr> &seen() const { return seen_; }
  49. void AddSeen(const AnfNodePtr &node) { (void)seen_.insert(node); }
  50. std::shared_ptr<FuncGraphSpecializer> GetFuncGraphSpecializer(const AnalysisContextPtr &context);
  51. // Specialze one FuncGraph in a given context.
  52. FuncGraphPtr SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context);
  53. std::shared_ptr<AnalysisEngine> engine() { return engine_; }
  54. private:
  55. std::shared_ptr<AnalysisEngine> engine_;
  56. std::unordered_set<AnfNodePtr> seen_;
  57. FuncGraphManagerPtr mng_;
  58. std::unordered_map<AnalysisContextPtr, std::shared_ptr<FuncGraphSpecializer>, ContextHasher, ContextEqual>
  59. specializations_;
  60. };
  61. class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
  62. public:
  63. FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, const AnalysisContextPtr &context);
  64. virtual ~FuncGraphSpecializer() {
  65. specializer_ = nullptr;
  66. repl_node_ = nullptr;
  67. }
  68. void Run();
  69. FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; }
  70. private:
  71. ProgramSpecializer *specializer_;
  72. FuncGraphPtr func_graph_;
  73. FuncGraphPtr specialized_func_graph_;
  74. AnalysisContextPtr context_;
  75. std::shared_ptr<FuncGraphSpecializer> parent_;
  76. std::shared_ptr<AnalysisEngine> engine_;
  77. ClonerPtr cloner_;
  78. // ProcessNode-> [cloner_->CloneDisconnected] will clone AnfNode again.
  79. // So, repl_node_ should pointer to GraphCloner->repl_node_ other than a copy of that.
  80. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_node_;
  81. std::vector<AnfNodePtr> todo_;
  82. std::unordered_set<AnfNodePtr> marked_;
  83. std::unordered_map<EvaluatorPtr, EvaluatorCacheMapPtr> evalcaches_;
  84. void FirstPass();
  85. void SecondPass();
  86. void ProcessNode(const AnfNodePtr &node);
  87. void ProcessCNode(const CNodePtr &new_node);
  88. AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node);
  89. inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); }
  90. // Get node replicated by Cloner.
  91. AnfNodePtr GetReplicatedNode(const AnfNodePtr &node);
  92. // Replicated node which is not used directly by a func graph, so it's not searchable from it's return node
  93. // (disconnected).
  94. AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node);
  95. // Build a value node if ival is constant and not any-value
  96. AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival);
  97. // Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a
  98. // replicated node.
  99. AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);
  100. // Build a specialized node from given argvals;
  101. AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
  102. const AbstractBasePtrList &argvals);
  103. AnfNodePtr BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
  104. const AbstractBasePtrList &args, SpecializeStatusCode *errcode);
  105. // Find the unique argument values which can be used to specialize a primitive or graph function.
  106. SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval,
  107. const AbstractBasePtrList &argvals,
  108. std::pair<AbstractBasePtrList, AbstractBasePtr> *result);
  109. // Get cache, it may be eval's cache or cache built from broaded argument values.
  110. const EvaluatorCacheMapPtr &GetEvalCache(const EvaluatorPtr &eval);
  111. // Try to build unique argvals from the broaded arg vals if it is unique.
  112. std::pair<AbstractBasePtrList, AbstractBasePtr> BuildFromBroadedArgsVal(const EvaluatorPtr &eval);
  113. };
  114. } // namespace abstract
  115. } // namespace mindspore
  116. #endif // PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_