You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

function_block.cc 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/function_block.h"
  19. #include <string>
  20. #include <memory>
  21. #include "pipeline/jit/parse/resolve.h"
  22. #include "pipeline/jit/parse/parse.h"
  23. #include "frontend/operator/ops.h"
  24. #include "utils/info.h"
  25. #include "debug/trace.h"
  26. #include "pybind11/pybind11.h"
  27. namespace mindspore {
  28. namespace py = pybind11;
  29. namespace parse {
  30. FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
  31. func_graph_ = std::make_shared<FuncGraph>();
  32. matured_ = false;
  33. }
  34. void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }
  35. // write variable records the variable name to corresponding node
  36. void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
  37. MS_LOG(DEBUG) << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString();
  38. vars_[var_name] = node;
  39. }
  40. // read variable from predecessors
  41. AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
  42. // get var node if it is found
  43. if (vars_.count(var)) {
  44. AnfNodePtr node = vars_[var];
  45. MS_EXCEPTION_IF_NULL(node);
  46. if (node->isa<ValueNode>()) {
  47. return NewValueNode(GetValueNode(node));
  48. } else {
  49. return node;
  50. }
  51. }
  52. // get var from predecessor block ,if can't get the make a resolve node to it
  53. if (matured_) {
  54. // If only one predecessor block, read the definition of var from it.
  55. if (prev_blocks_.size() == 1) {
  56. auto block = prev_blocks_[0];
  57. MS_EXCEPTION_IF_NULL(block);
  58. return block->ReadVariable(var);
  59. } else if (prev_blocks_.empty()) {
  60. // get namespace and make Reslove
  61. return MakeResolveSymbol(var);
  62. }
  63. }
  64. // If have more than one predecessor blocks then build a phi node.
  65. auto debug_info = std::make_shared<NodeDebugInfo>();
  66. debug_info->set_name(var);
  67. TraceManager::DebugTrace(std::make_shared<TracePhi>(debug_info));
  68. ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
  69. TraceManager::EndTrace();
  70. MS_LOG(DEBUG) << func_graph_->ToString() << " generate phi node " << phi_param->ToString() << " for " << var;
  71. func_graph()->add_parameter(phi_param);
  72. phi_nodes_[phi_param] = var;
  73. WriteVariable(var, phi_param);
  74. if (matured_) {
  75. SetPhiArgument(phi_param);
  76. }
  77. return phi_param;
  78. }
  79. // Resolve Ast operator node
  80. AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
  81. auto ast = parser_.ast();
  82. MS_EXCEPTION_IF_NULL(ast);
  83. TraceGuard trace_guard(parser_.GetLocation(op));
  84. py::tuple namespace_var = ast->CallParserObjMethod(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
  85. if (namespace_var.size() != 2) {
  86. MS_LOG(EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size();
  87. }
  88. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, namespace_var[0]);
  89. SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
  90. return MakeResolve(name_space, symbol);
  91. }
  92. // Resolve class member, two possible: method, member variable
  93. AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
  94. py::object namespace_var =
  95. parser_.ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, parser_.ast()->obj());
  96. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  97. SymbolPtr symbol = std::make_shared<Symbol>(attr);
  98. return MakeResolve(name_space, symbol);
  99. }
// Make a resolve node for a symbol string.
// A name of the form "self.xxx" is resolved as a class member; anything else
// is looked up in the parse-time namespace.
AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
  // NOTE(review): this is a prefix comparison, so any name starting with
  // "self" (not only "self.xxx") takes this branch — confirm callers only
  // pass attribute accesses here.
  if (value.compare(0, strlen("self"), "self") == 0) {
    // find_first_of returns npos when there is no '.', and npos + 1 wraps
    // around to 0, so a dot-less "self"-prefixed name falls through with
    // start == 0 and the whole string is used as the attribute name.
    auto start = value.find_first_of('.') + 1;
    if (start >= value.size()) {
      MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value;
      return nullptr;
    }
    auto bits_str = value.substr(start);
    return MakeResolveClassMember(bits_str);
  }
  py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
  NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_var[0]);
  SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
  return MakeResolve(name_space, symbol);
}
  116. AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
  117. py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
  118. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
  119. SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
  120. return MakeResolve(name_space, symbol);
  121. }
  122. AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
  123. MS_LOG(DEBUG) << "MakeResolve for " << ((std::string)py::str(name_space->obj())) << " , "
  124. << ((std::string)resolve_symbol->symbol());
  125. ValueNodePtr module_node = NewValueNode(name_space);
  126. ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
  127. auto node = func_graph()->NewCNode({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
  128. return node;
  129. }
// Add inputs for the block's phi parameter: each predecessor block supplies
// its own definition of the variable as an extra argument on the jump CNode
// into this block.
void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
  std::string var = phi_nodes_[phi];
  MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var;
  auto removable = CollectRemovablePhi(phi);
  // If the phi node is not necessary, no need to add it to jumps_ of the prev blocks.
  if (removable) {
    MS_LOG(DEBUG) << "remove the phi when call graph " << func_graph_->ToString() << " var " << var;
    return;
  }
  for (auto &pred : prev_blocks_) {
    MS_EXCEPTION_IF_NULL(pred);
    MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString();
    AnfNodePtr arg_node = pred->ReadVariable(var);
    // NOTE(review): operator[] default-inserts a null CNodePtr when `pred`
    // has not recorded a jump to this block — presumably every predecessor
    // has jumped by the time its phis are set; confirm.
    CNodePtr jump = pred->jumps_[this];
    jump->add_input(arg_node);
  }
}
  148. AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) {
  149. AnfNodePtr arg_node = nullptr;
  150. for (auto &prev : prev_blocks_) {
  151. MS_EXCEPTION_IF_NULL(prev);
  152. AnfNodePtr temp_node = prev->ReadVariable(var);
  153. MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var
  154. << " is " << temp_node->DebugString();
  155. if (temp_node != phi) {
  156. if (arg_node == nullptr) {
  157. arg_node = temp_node;
  158. MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString()
  159. << " may be replaced by node " << arg_node->DebugString();
  160. } else if (temp_node == arg_node) {
  161. MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " is same as node "
  162. << arg_node->DebugString();
  163. } else {
  164. MS_LOG(DEBUG) << "phi " << phi->ToString()
  165. << " cannot be removed as it assigns to different node. node1: " << arg_node->DebugString()
  166. << ", node2: " << temp_node->DebugString();
  167. return nullptr;
  168. }
  169. }
  170. }
  171. return arg_node;
  172. }
// Check if there is a removable, unnecessary phi node in this graph.
// As per the FIRM TR 3.2, a phi node can be removed if:
// <Quote>
//   If all arguments of a phi-function are the same value s or the
//   phi-function itself, then we remove the phi-function and let all users
//   directly use s. We call such a phi-function obviously unnecessary.
//   When we removed a phi-function p, then we recursively try to apply this
//   simplification rule with all (former) users of p, because they may have
//   become obviously unnecessary due to the removal of p.
// <Quote>
// Phi nodes in the graph are removed after the whole function is parsed, in
// a DFS visit of that graph. The reasons are:
// 1. when this function is called, not all usages of this phi node have been
//    bound to the graph of this function block; some may stay in vars_ of
//    other blocks.
// 2. it is costly to iterate the graph to replace each phi individually.
// Args:
//   phi: the parameter node that is functioning as a phi node.
// Returns true when the phi was recorded as removable (and the variable was
// rebound to the replacement node), false otherwise.
bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
  MS_EXCEPTION_IF_NULL(phi);
  std::string var = phi_nodes_[phi];
  MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var;
  if (prev_blocks_.size() == 0) {
    MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var;
    return false;
  }
  AnfNodePtr arg_node = SearchReplaceNode(var, phi);
  if (arg_node != nullptr) {
    MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with "
                  << arg_node->DebugString();
    // Replace var with the new node. This is equal to the statement in the TR:
    // "v0 is immediately replaced by v1."
    WriteVariable(var, arg_node);
    removable_phis_[phi] = arg_node;
    // The following is equal to the statement "The phi-function defining v1,
    // which now reads phi(v2, v1), is optimized recursively": check whether
    // some earlier-recorded removable phi was assigned this phi as its
    // replacement; if so, it can be replaced with arg_node directly.
    for (auto &prev : prev_blocks_) {
      MS_EXCEPTION_IF_NULL(prev);
      if (!prev->matured_) {
        continue;
      }
      for (auto &phi_iter : prev->removable_phis_) {
        MS_EXCEPTION_IF_NULL(phi_iter.second);
        if (phi_iter.second->isa<Parameter>()) {
          const auto &param = phi_iter.second->cast<ParameterPtr>();
          if (param == phi) {
            MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString()
                          << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString()
                          << " in graph " << arg_node->func_graph()->ToString();
            prev->removable_phis_[phi_iter.first] = arg_node;
          }
        }
      }
    }
    return true;
  }
  return false;
}
  229. // A block should be marked matured if its predecessor blocks have been processed
  230. void FunctionBlock::Mature() {
  231. const auto &graphParamVec = func_graph_->parameters();
  232. for (auto &paramItr : graphParamVec) {
  233. MS_EXCEPTION_IF_NULL(paramItr);
  234. ParameterPtr param = paramItr->cast<ParameterPtr>();
  235. if (phi_nodes_.find(param) != phi_nodes_.cend()) {
  236. SetPhiArgument(param);
  237. }
  238. }
  239. matured_ = true;
  240. }
  241. // Force the conditIon node to bool using bool operation
  242. CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
  243. TraceManager::DebugTrace(std::make_shared<TraceForceBool>(cond->debug_info()));
  244. CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond});
  245. TraceManager::EndTrace();
  246. return op_apply_node;
  247. }
  248. CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
  249. TraceManager::DebugTrace(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
  250. CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation("while_cond"), cond});
  251. TraceManager::EndTrace();
  252. return op_apply_node;
  253. }
  254. // Perform a jump from this block to target block
  255. void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) {
  256. if (func_graph()->get_return() != nullptr) {
  257. MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
  258. << trace::GetDebugInfo(func_graph()->get_return()->debug_info());
  259. }
  260. std::vector<AnfNodePtr> input_nodes;
  261. input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
  262. if (node != nullptr) {
  263. input_nodes.emplace_back(node);
  264. }
  265. CNodePtr jump = func_graph()->NewCNode(input_nodes);
  266. jumps_[target_block.get()] = jump;
  267. target_block->AddPrevBlock(shared_from_this());
  268. func_graph()->set_output(jump);
  269. InsertDependItemsBeforeReturn();
  270. }
  271. // Perform a conditional jump using switch operation.
  272. // The first CNode select graph with condition, and than execute this graph
  273. void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block,
  274. const FunctionBlockPtr &false_block, bool unroll_loop) {
  275. if (func_graph()->get_return() != nullptr) {
  276. MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
  277. << trace::GetDebugInfo(func_graph()->get_return()->debug_info());
  278. }
  279. CNodePtr switch_app =
  280. func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
  281. NewValueNode(false_block->func_graph())});
  282. CNodePtr switch_app_new = func_graph()->NewCNode({switch_app});
  283. func_graph()->set_output(switch_app_new);
  284. InsertDependItemsBeforeReturn();
  285. }
  286. void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) {
  287. const std::string primitive_name("assign");
  288. const std::string module_name("mindspore.ops.functional");
  289. ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
  290. auto source = ReadVariable(readid);
  291. auto assign = func_graph()->NewCNode({assign_op, target, source});
  292. WriteVariable(readid, assign);
  293. MS_LOG(INFO) << "SetState read " << target->DebugString() << ", " << readid;
  294. AddAutoDepend(assign);
  295. }
  296. void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); }
  297. void FunctionBlock::InsertDependItemsBeforeReturn() {
  298. if (!prev_blocks_.empty()) {
  299. for (auto &prev_block : prev_blocks_) {
  300. MS_LOG(DEBUG) << "Has prev_block " << prev_block->func_graph()->debug_info().get();
  301. }
  302. }
  303. ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
  304. ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
  305. ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
  306. if (auto_depends_.size() == 0) {
  307. return;
  308. }
  309. AnfNodePtr state = nullptr;
  310. std::vector<AnfNodePtr> vec_states;
  311. vec_states.emplace_back(make_tuple_op);
  312. for (auto &item : auto_depends_) {
  313. MS_LOG(DEBUG) << "auto_depends " << item->ToString();
  314. vec_states.emplace_back(item);
  315. }
  316. // if there are only make_tuple_op and another node in vec_states(the vec_states size is 2)
  317. // do not need to make_tuple, just use the node.
  318. if (vec_states.size() == 2) {
  319. state = vec_states[1];
  320. } else {
  321. state = func_graph()->NewCNode(vec_states);
  322. }
  323. AnfNodePtr old_ret = nullptr;
  324. auto return_node = func_graph()->get_return();
  325. if (return_node) {
  326. if (return_node->inputs().size() < 1) {
  327. MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2";
  328. }
  329. old_ret = return_node->input(1);
  330. } else {
  331. old_ret = NewValueNode(kNone);
  332. }
  333. AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state});
  334. AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped});
  335. func_graph()->set_output(ret, true);
  336. }
  337. } // namespace parse
  338. } // namespace mindspore