You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

function_block.h 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_
  19. #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_
#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "utils/hash_map.h"
#include "ir/meta_func_graph.h"
#include "pipeline/jit/parse/parse_base.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
  32. namespace mindspore {
  33. namespace parse {
  34. class Parser;
  35. class NameSpace;
  36. class Symbol;
  37. class Script;
  38. class FunctionBlock;
  39. using FunctionBlockPtr = std::shared_ptr<FunctionBlock>;
  40. // A function block is a straight-line code sequence with no branches, every block has one one exit point
  41. // which is return. When parsing function, loop or branch , we use function block to track the structure of
  42. // the original source code.
  43. class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
  44. public:
  45. explicit FunctionBlock(const Parser &parser);
  46. virtual ~FunctionBlock() = default;
  47. FuncGraphPtr func_graph() { return func_graph_; }
  48. std::string ToString() const { return func_graph_->ToString(); }
  49. void WriteVariable(const std::string &var_name, const AnfNodePtr &node);
  50. AnfNodePtr ReadVariable(const std::string &var_name);
  51. void AddPrevBlock(const FunctionBlockPtr &block);
  52. void SetPhiArgument(const ParameterPtr &phi);
  53. bool CollectRemovablePhi(const ParameterPtr &phi);
  54. // A block is matured if all its predecessors is generated
  55. void Mature();
  56. CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
  57. CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
  58. void Jump(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &args);
  59. AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
  60. void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock,
  61. bool unroll_loop = true);
  62. // Create cnode for the assign statement like self.target = source.
  63. void SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source);
  64. void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); }
  65. bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); }
  66. AnfNodePtr MakeResolveAstOp(const py::object &op);
  67. AnfNodePtr MakeResolveClassMember(const std::string &attr);
  68. AnfNodePtr MakeResolveSymbol(const std::string &value);
  69. AnfNodePtr MakeResolveOperation(const std::string &value);
  70. AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol);
  71. AnfNodePtr GetResolveNode(const py::tuple &namespace_info);
  72. AnfNodePtr HandleNamespaceInfo(const py::tuple &namespace_info);
  73. AnfNodePtr HandleBuiltinNamespaceInfo(const py::tuple &namespace_info);
  74. AnfNodePtr MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
  75. const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node);
  76. const mindspore::HashMap<ParameterPtr, AnfNodePtr> &removable_phis() const { return removable_phis_; }
  77. void FindIsolatedNodes();
  78. void AddIsolatedNode(const AnfNodePtr &target);
  79. void AttachIsolatedNodesBeforeReturn();
  80. const std::vector<FunctionBlock *> &prev_blocks() const { return prev_blocks_; }
  81. bool is_dead_block() const { return is_dead_block_; }
  82. void SetAsDeadBlock();
  83. py::dict &global_py_params() { return global_py_params_; }
  84. void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; }
  85. void AddGlobalPyParam(const std::string &name, const py::object &obj) { global_py_params_[py::str(name)] = obj; }
  86. void UpdateGlobalPyParam(const py::dict &symbols) {
  87. for (auto &param : symbols) {
  88. if (!global_py_params_.contains(param.first)) {
  89. global_py_params_[param.first] = param.second;
  90. }
  91. }
  92. }
  93. std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() {
  94. return {local_py_params_keys_, local_py_params_values_};
  95. }
  96. void AddLocalPyParam(const std::string &name, const AnfNodePtr &node) {
  97. MS_LOG(DEBUG) << "Add '" << name << "', " << node->DebugString();
  98. local_py_params_keys_.emplace_back(NewValueNode(name));
  99. local_py_params_values_.emplace_back(node);
  100. }
  101. // Call this methon only if you need update a variable. Usually variable override.
  102. void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) {
  103. auto iter = std::find_if(local_py_params_keys_.cbegin(), local_py_params_keys_.cend(),
  104. [&name](const AnfNodePtr node) -> bool {
  105. const auto value_node = dyn_cast<ValueNode>(node);
  106. MS_EXCEPTION_IF_NULL(value_node);
  107. const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value());
  108. MS_EXCEPTION_IF_NULL(str_imm);
  109. return name == str_imm->value();
  110. });
  111. if (iter == local_py_params_keys_.cend()) {
  112. MS_LOG(EXCEPTION) << "Only for updating. Should not call this method if 'name' not exist.";
  113. }
  114. // Find the same position in 'values', and update the node.
  115. auto distance = std::distance(local_py_params_keys_.cbegin(), iter);
  116. auto values_pos_iter = local_py_params_values_.begin() + distance;
  117. MS_LOG(DEBUG) << "Update '" << name << "', " << (*values_pos_iter)->DebugString() << " -> " << node->DebugString();
  118. *values_pos_iter = node;
  119. }
  120. private:
  121. // Block graph
  122. FuncGraphPtr func_graph_;
  123. // Block parser
  124. const Parser &parser_;
  125. // A block is matured if all its prev_blocks is processed
  126. bool matured_;
  127. // Store the nest-level block.
  128. // Refer to comments in Parser::func_block_list_;
  129. std::vector<FunctionBlock *> prev_blocks_;
  130. // Store args and variable's node, use a bool flag to indicate if the variable is used.
  131. std::map<std::string, std::pair<AnfNodePtr, bool>> assigned_vars_;
  132. // Map the parameter node to variable, it can be resolved if the block's predecessors are processed
  133. std::map<ParameterPtr, std::string> phi_nodes_;
  134. // Jumps map the successor block and the function call that perform jump
  135. // Refer to comments in Parser::func_block_list_ that how to break the cyclic reference
  136. std::map<FunctionBlock *, CNodePtr> jumps_;
  137. // Keep all removable phis which will be removed in one pass.
  138. mindspore::HashMap<ParameterPtr, AnfNodePtr> removable_phis_;
  139. // Keep the map for the resolve node to the removable phi node.
  140. // For the case that ReadVariable returns a phi node although this phi node
  141. // generated in the prev block is identified as removable. The other blocks
  142. // should find this phi node.
  143. mindspore::HashMap<AnfNodePtr, ParameterPtr> resolve_to_removable_phis_;
  144. // Hold declared global variables in function
  145. std::set<std::string> global_vars_;
  146. // Keep new made resolve symbol for the variable not found in vars_.
  147. mindspore::HashMap<std::string, AnfNodePtr> var_to_resolve_;
  148. // Collect all python symbols in the block.
  149. // We treat both global symbols and local symbols declared previously as global symbols.
  150. py::dict global_py_params_;
  151. std::vector<AnfNodePtr> local_py_params_keys_;
  152. std::vector<AnfNodePtr> local_py_params_values_;
  153. // Isolated nodes.
  154. OrderedSet<AnfNodePtr> isolated_nodes_;
  155. // If a block can never be executed, it's prev blocks will be empty, so this block is a dead block.
  156. // while x > 5:
  157. // x = x - 2
  158. // if x > 7 :
  159. // break
  160. // else :
  161. // break
  162. // x = x - 1 #This after block is a dead block
  163. bool is_dead_block_{false};
  164. };
  165. } // namespace parse
  166. } // namespace mindspore
  167. #endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_