You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

function_block.cc 26 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/function_block.h"
  19. #include <string>
  20. #include <memory>
  21. #include <algorithm>
  22. #include "pybind11/pybind11.h"
  23. #include "pipeline/jit/parse/resolve.h"
  24. #include "pipeline/jit/parse/parse.h"
  25. #include "pipeline/jit/parse/data_converter.h"
  26. #include "frontend/operator/ops.h"
  27. #include "utils/info.h"
  28. #include "debug/trace.h"
  29. #include "utils/utils.h"
  30. namespace mindspore {
  31. namespace py = pybind11;
  32. namespace parse {
  33. FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
  34. func_graph_ = std::make_shared<FuncGraph>();
  35. matured_ = false;
  36. }
  37. void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }
  38. static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &node) {
  39. auto cnode = dyn_cast<CNode>(node);
  40. if (cnode == nullptr || cnode->inputs().empty()) {
  41. // Not a valid cnode, can not be isolate node.
  42. return false;
  43. }
  44. auto prim = GetValueNode<PrimitivePtr>(cnode->inputs().at(0));
  45. if (prim == nullptr) {
  46. // Not a primitive cnode, it may have side effects or not,
  47. // We add it as an isolate node if its name is not '_' or empty.
  48. // this means that code like:
  49. // _ = func_call()
  50. // will be ignored even if func_call() has side effects.
  51. return !var_name.empty() && var_name != "_";
  52. }
  53. // Primitive cnode with side effects can be isolate nodes.
  54. auto effect_info = GetPrimEffectInfo(prim);
  55. bool has_effects = (effect_info.memory || effect_info.io);
  56. if (has_effects) {
  57. return true;
  58. }
  59. // Primitive cnode with 'no_eliminate' flag can be isolate nodes.
  60. return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
  61. }
  62. // Write variable records the variable name to corresponding node
  63. void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
  64. MS_EXCEPTION_IF_NULL(node);
  65. MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name << "` with node "
  66. << node->DebugString();
  67. // The fallback feature is enabled in default.
  68. // Not support change the flag during the process is alive.
  69. static const auto use_fallback = (parser_.support_fallback() != "0");
  70. auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
  71. if (!is_new_name) {
  72. // If a cnode variable with same name already existed but not used,
  73. // add it as an isolate node. for example:
  74. // a = print(x)
  75. // a = print(y)
  76. // When we write variable 'a = print(y)',
  77. // the cnode 'print(x)' should added as an isolate node.
  78. auto is_used = iter->second.second;
  79. auto hidden_node = iter->second.first;
  80. auto is_isolated = CanBeIsolatedNode(var_name, hidden_node);
  81. if (!is_used && is_isolated) {
  82. MS_EXCEPTION_IF_NULL(hidden_node);
  83. MS_LOG(INFO) << "Isolated node found(Hidden), hidden_node: " << hidden_node->DebugString(2) << " is hidden by "
  84. << node->DebugString(2) << " with the same name, var_name: " << var_name << ", block: " << this
  85. << "/" << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
  86. << ", Line: " << trace::GetDebugInfo(hidden_node->debug_info(), "", kSourceLineTipDiscard);
  87. AddIsolatedNode(hidden_node);
  88. }
  89. iter->second = std::make_pair(node, false);
  90. if (use_fallback) {
  91. UpdateLocalPyParam(var_name, node);
  92. }
  93. } else {
  94. if (use_fallback) {
  95. AddLocalPyParam(var_name, node);
  96. }
  97. }
  98. }
  99. // Read variable from predecessors
  100. AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
  101. MS_LOG(DEBUG) << "Read begin, var: " << var_name << ", block: " << ToString();
  102. // Get var node if it is found
  103. auto found = assigned_vars_.find(var_name);
  104. if (found != assigned_vars_.end()) {
  105. auto &node = found->second.first;
  106. MS_EXCEPTION_IF_NULL(node);
  107. // Mark the variable as used.
  108. found->second.second = true;
  109. auto iter = resolve_to_removable_phis_.find(node);
  110. if (iter != resolve_to_removable_phis_.end()) {
  111. return iter->second;
  112. }
  113. return node;
  114. }
  115. // Get var from predecessor block, if can't get then make a resolve node to it
  116. if (matured_) {
  117. // If only one predecessor block, read the definition of var from it.
  118. if (prev_blocks_.size() == 1) {
  119. auto block = prev_blocks_[0];
  120. MS_EXCEPTION_IF_NULL(block);
  121. auto res = block->ReadVariable(var_name);
  122. // The fallback feature is enabled in default.
  123. // Not support change the flag during the process is alive.
  124. static const auto use_fallback = (parser_.support_fallback() != "0");
  125. if (use_fallback) {
  126. MS_LOG(DEBUG) << "Update global params of block: " << ToString()
  127. << ", with previous block: " << block->ToString() << ",\nCurrent: " << py::str(global_py_params())
  128. << "\nInsert: " << py::str(block->global_py_params());
  129. UpdateGlobalPyParam(block->global_py_params());
  130. }
  131. return res;
  132. } else if (prev_blocks_.empty()) {
  133. // Get namespace and make Resolve
  134. auto it = var_to_resolve_.find(var_name);
  135. if (it != var_to_resolve_.end()) {
  136. return it->second;
  137. }
  138. MS_LOG(DEBUG) << "var: " << var_name;
  139. auto tmp_node = MakeResolveSymbol(var_name);
  140. var_to_resolve_[var_name] = tmp_node;
  141. return tmp_node;
  142. }
  143. }
  144. // If have more than one predecessor blocks then build a phi node.
  145. auto debug_info = std::make_shared<NodeDebugInfo>();
  146. debug_info->set_name(var_name);
  147. TraceGuard guard(std::make_shared<TracePhi>(debug_info));
  148. ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
  149. MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
  150. << phi_param->ToString() << " for " << var_name;
  151. func_graph()->add_parameter(phi_param);
  152. phi_nodes_[phi_param] = var_name;
  153. WriteVariable(var_name, phi_param);
  154. if (matured_) {
  155. SetPhiArgument(phi_param);
  156. }
  157. return phi_param;
  158. }
  159. // Resolve Ast operator node
  160. AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
  161. auto ast = parser_.ast();
  162. MS_EXCEPTION_IF_NULL(ast);
  163. TraceGuard trace_guard(parser_.GetLocation(op));
  164. py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
  165. if (namespace_var.size() != 2) {
  166. MS_LOG(EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size();
  167. }
  168. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, namespace_var[0]);
  169. SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
  170. MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
  171. return MakeResolve(name_space, symbol);
  172. }
  173. // Resolve class member, two possible: method, member variable
  174. AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
  175. auto ast = parser_.ast();
  176. MS_EXCEPTION_IF_NULL(ast);
  177. py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
  178. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  179. SymbolPtr symbol = std::make_shared<Symbol>(attr);
  180. MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
  181. return MakeResolve(name_space, symbol);
  182. }
  183. AnfNodePtr FunctionBlock::GetResolveNode(const py::tuple &info) {
  184. constexpr size_t namespace_index = 0;
  185. constexpr size_t symbol_index = 1;
  186. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, info[namespace_index]);
  187. SymbolPtr symbol = std::make_shared<Symbol>(info[symbol_index].cast<std::string>());
  188. return MakeResolve(name_space, symbol);
  189. }
  190. AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &info) {
  191. constexpr size_t namespace_index = 0;
  192. constexpr size_t symbol_index = 1;
  193. constexpr size_t namespace_info_size = 2;
  194. if (info.size() != namespace_info_size) {
  195. MS_EXCEPTION(NameError) << "namespace info size should be 2, but got " << info.size();
  196. }
  197. // If namespace is None, the symbol is an undefined name.
  198. if (info[namespace_index].is_none()) {
  199. MS_EXCEPTION(NameError) << info[symbol_index].cast<std::string>();
  200. }
  201. return GetResolveNode(info);
  202. }
  203. AnfNodePtr FunctionBlock::HandleBuiltinNamespaceInfo(const py::tuple &info) {
  204. constexpr size_t closure_info_size = 2;
  205. constexpr size_t unsupported_info_size = 3;
  206. constexpr size_t supported_info_size = 4;
  207. constexpr size_t namespace_index = 0;
  208. constexpr size_t symbol_index = 1;
  209. constexpr size_t value_index = 2;
  210. if (info.size() < closure_info_size || info.size() > supported_info_size) {
  211. MS_EXCEPTION(NameError) << "namespace info size should be 2, 3 or 4, but got " << info.size();
  212. }
  213. // Handle closure namespace info.
  214. if (info.size() == closure_info_size) {
  215. // If namespace is None, the symbol is an undefined name.
  216. if (info[namespace_index].is_none()) {
  217. MS_EXCEPTION(NameError) << info[symbol_index].cast<std::string>();
  218. }
  219. return GetResolveNode(info);
  220. }
  221. // Handle global namespace info.
  222. auto resolved_node = GetResolveNode(info);
  223. if (info.size() == unsupported_info_size) {
  224. resolved_node->set_interpret(true);
  225. }
  226. SymbolPtr symbol = std::make_shared<Symbol>(info[symbol_index].cast<std::string>());
  227. py::object py_obj = info[value_index];
  228. AddGlobalPyParam(symbol->name(), py_obj);
  229. MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symbol: {" << symbol->name() << " : "
  230. << py::str(py_obj) << "}";
  231. return resolved_node;
  232. }
  233. // Make a resolve node for symbol string
  234. AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
  235. MS_LOG(DEBUG) << "value: " << value;
  236. if (value.compare(0, strlen("self"), "self") == 0) {
  237. auto start = value.find_first_of('.') + 1;
  238. if (start >= value.size()) {
  239. MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value;
  240. return nullptr;
  241. }
  242. auto bits_str = value.substr(start);
  243. return MakeResolveClassMember(bits_str);
  244. }
  245. auto ast = parser_.ast();
  246. MS_EXCEPTION_IF_NULL(ast);
  247. // The fallback feature is enabled in default.
  248. // Not support change the flag during the process is alive.
  249. static const auto use_fallback = (parser_.support_fallback() != "0");
  250. if (!use_fallback) {
  251. py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
  252. return HandleNamespaceInfo(namespace_info);
  253. } else {
  254. py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL, value);
  255. return HandleBuiltinNamespaceInfo(namespace_info);
  256. }
  257. }
  258. AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
  259. auto ast = parser_.ast();
  260. MS_EXCEPTION_IF_NULL(ast);
  261. py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
  262. const size_t namespace_var_size = 2;
  263. if (namespace_var.size() < namespace_var_size) {
  264. MS_EXCEPTION(NameError) << "namespace_var is less than 2";
  265. }
  266. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
  267. SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
  268. MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
  269. return MakeResolve(name_space, symbol);
  270. }
  271. AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
  272. MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
  273. << " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resolve symbol.");
  274. ValueNodePtr module_node = NewValueNode(name_space);
  275. ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
  276. auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
  277. return node;
  278. }
  279. AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
  280. const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node) {
  281. MS_LOG(DEBUG) << "MakeInterpret for " << script_text;
  282. ScriptPtr script = std::make_shared<Script>(script_text);
  283. auto script_node = NewValueNode(script);
  284. auto node = func_graph_->NewCNodeInOrder(
  285. {NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
  286. node->set_interpreted_node(orig_node);
  287. return node;
  288. }
  289. // Add input for the block's phi parameter
  290. void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
  291. MS_EXCEPTION_IF_NULL(phi);
  292. TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info()));
  293. std::string var = phi_nodes_[phi];
  294. MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
  295. << " for var `" << var << "`";
  296. auto removable = CollectRemovablePhi(phi);
  297. // If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
  298. if (removable) {
  299. MS_LOG(DEBUG) << "remove the phi when call graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
  300. << " var `" << var << "`";
  301. return;
  302. }
  303. for (auto &pred : prev_blocks_) {
  304. MS_EXCEPTION_IF_NULL(pred);
  305. MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " pred_blocks_ "
  306. << (pred->func_graph_ ? pred->func_graph_->ToString() : "FG(Null)");
  307. AnfNodePtr arg_node = pred->ReadVariable(var);
  308. CNodePtr jump = pred->jumps_[this];
  309. MS_EXCEPTION_IF_NULL(jump);
  310. jump->add_input(arg_node);
  311. }
  312. }
  313. AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) {
  314. AnfNodePtr arg_node = nullptr;
  315. MS_LOG(DEBUG) << "Prev_blocks size: " << prev_blocks_.size();
  316. for (auto &prev : prev_blocks_) {
  317. MS_EXCEPTION_IF_NULL(prev);
  318. AnfNodePtr temp_node = prev->ReadVariable(var);
  319. MS_EXCEPTION_IF_NULL(temp_node);
  320. if (temp_node != phi) {
  321. if (arg_node == nullptr) {
  322. arg_node = temp_node;
  323. MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
  324. << (phi ? phi->ToString() : "null") << " may be replaced by node " << arg_node->DebugString();
  325. } else if (temp_node == arg_node) {
  326. MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
  327. << (phi ? phi->ToString() : "null") << " is same as node " << arg_node->DebugString();
  328. } else {
  329. MS_LOG(DEBUG) << "phi " << (phi ? phi->ToString() : "null")
  330. << " cannot be removed as it assigns to different node. node1: " << arg_node->DebugString()
  331. << ", node2: " << temp_node->DebugString();
  332. return nullptr;
  333. }
  334. }
  335. }
  336. return arg_node;
  337. }
  338. // Check if there is removable unnecessary phi node in this graph.
  339. // As per the FIRM TR 3.2, a phi node can be remove if:
  340. // <Quote>
  341. // If all arguments of a φ-function are the same value s or the φfunction itself,
  342. // then we remove the φ-function and let all users directly uses. We call such a
  343. // φ-function obviously unnecessary.
  344. // When we removed a φ-function p, then we recursively try to apply this simplification
  345. // rule with all (former) users of p, because they may have become obviously unnecessary
  346. // due to the removal of p
  347. // <Quote>
  348. // phi node in graph will be removed after the whole function is parsed in a DFS visit
  349. // of that graph.The reason is :
  350. // 1. when this function is called, not all usage of this phi node had bound to the
  351. // graph of this function block, some may stay in vars_ in other blocks.
  352. // 2. it's costly to iterate the graph to replace the phi for each phi.
  353. // Args: phi: This parameter node is functioning as a phi node.
  354. bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
  355. MS_EXCEPTION_IF_NULL(phi);
  356. std::string var_name = phi_nodes_[phi];
  357. MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var_name;
  358. if (prev_blocks_.empty()) {
  359. MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var_name;
  360. return false;
  361. }
  362. AnfNodePtr arg_node = SearchReplaceNode(var_name, phi);
  363. if (arg_node != nullptr) {
  364. arg_node->set_debug_info(phi->debug_info());
  365. MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi " << phi->ToString()
  366. << " can be replaced with " << arg_node->DebugString();
  367. // Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
  368. WriteVariable(var_name, arg_node);
  369. removable_phis_[phi] = arg_node;
  370. resolve_to_removable_phis_[arg_node] = phi;
  371. // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized
  372. // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node.
  373. for (auto &prev : prev_blocks_) {
  374. MS_EXCEPTION_IF_NULL(prev);
  375. if (!prev->matured_) {
  376. continue;
  377. }
  378. for (auto &phi_iter : prev->removable_phis_) {
  379. MS_EXCEPTION_IF_NULL(phi_iter.second);
  380. if (phi_iter.second->isa<Parameter>()) {
  381. const auto &param = phi_iter.second->cast<ParameterPtr>();
  382. if (param == phi) {
  383. MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " var "
  384. << phi_iter.first->DebugString() << " can be replaced from " << param->DebugString()
  385. << " with " << arg_node->DebugString() << " in graph "
  386. << (arg_node->func_graph() ? arg_node->func_graph()->ToString() : "FG(Null)");
  387. prev->removable_phis_[phi_iter.first] = arg_node;
  388. }
  389. }
  390. }
  391. }
  392. return true;
  393. }
  394. return false;
  395. }
  396. // A block should be marked matured if its predecessor blocks have been processed
  397. void FunctionBlock::Mature() {
  398. const auto &graph_params = func_graph_->parameters();
  399. for (auto &param_itr : graph_params) {
  400. MS_EXCEPTION_IF_NULL(param_itr);
  401. auto param = param_itr->cast<ParameterPtr>();
  402. if (phi_nodes_.find(param) != phi_nodes_.cend()) {
  403. SetPhiArgument(param);
  404. }
  405. }
  406. matured_ = true;
  407. }
  408. // Force the condition node to bool using bool operation
  409. CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
  410. MS_EXCEPTION_IF_NULL(cond);
  411. CNodePtr op_apply_node = func_graph_->NewCNodeInOrder({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond});
  412. return op_apply_node;
  413. }
  414. CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
  415. MS_EXCEPTION_IF_NULL(cond);
  416. TraceGuard trace_guard(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
  417. CNodePtr op_apply_node = func_graph_->NewCNodeInOrder({MakeResolveOperation("while_cond"), cond});
  418. return op_apply_node;
  419. }
  420. // Perform a jump from this block to target block
  421. void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
  422. MS_LOG(DEBUG) << "Jump from block: " << ToString() << " to block: " << target_block->ToString();
  423. MS_EXCEPTION_IF_NULL(target_block);
  424. if (is_dead_block_) {
  425. MS_LOG(DEBUG) << "Dead code block should not jump to other block! block: " << ToString();
  426. return;
  427. }
  428. if (func_graph_->get_return() != nullptr) {
  429. MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
  430. << trace::GetDebugInfo(func_graph_->get_return()->debug_info());
  431. }
  432. std::vector<AnfNodePtr> input_nodes;
  433. input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
  434. (void)std::copy(args.begin(), args.end(), std::back_inserter(input_nodes));
  435. CNodePtr jump = func_graph_->NewCNodeInOrder(std::move(input_nodes));
  436. jumps_[target_block.get()] = jump;
  437. target_block->AddPrevBlock(shared_from_this());
  438. func_graph_->set_output(jump);
  439. }
  440. // Perform a conditional jump using switch operation.
  441. // The first CNode select graph with condition, and than execute this graph
  442. void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block,
  443. const FunctionBlockPtr &false_block, bool) {
  444. MS_EXCEPTION_IF_NULL(true_block);
  445. MS_EXCEPTION_IF_NULL(false_block);
  446. if (func_graph_->get_return() != nullptr) {
  447. MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
  448. << trace::GetDebugInfo(func_graph_->get_return()->debug_info());
  449. }
  450. CNodePtr switch_app =
  451. func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
  452. NewValueNode(false_block->func_graph())});
  453. CNodePtr switch_app_new = func_graph_->NewCNodeInOrder({switch_app});
  454. func_graph_->set_output(switch_app_new);
  455. }
  456. // Create cnode for the assign statement like 'self.target = source'.
  457. // convert it to 'P.Assign(self.target, source)' and then add the cnode as isolate node.
  458. void FunctionBlock::SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source) {
  459. const std::string primitive_name("assign");
  460. const std::string module_name("mindspore.ops.functional");
  461. ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
  462. auto assign_node = func_graph_->NewCNodeInOrder({assign_op, target, source});
  463. MS_LOG(DEBUG) << "Isolated node found(Assign), assign_node: " << assign_node->DebugString(2) << ", block: " << this
  464. << "/" << func_graph_->ToString()
  465. << ", Line: " << trace::GetDebugInfo(assign_node->debug_info(), "", kSourceLineTipDiscard);
  466. AddIsolatedNode(assign_node);
  467. }
  468. void FunctionBlock::FindIsolatedNodes() {
  469. //
  470. // Search isolate nodes from variables, for example,
  471. // variable 'a' is an isolate node in below code:
  472. //
  473. // def construct(self, x, y):
  474. // a = print(x) # isolate node
  475. // return x + y
  476. //
  477. std::set<AnfNodePtr> used;
  478. // Find used variables.
  479. for (const auto &var : assigned_vars_) {
  480. auto &node = var.second.first;
  481. if (node == nullptr) {
  482. continue;
  483. }
  484. bool is_used = var.second.second;
  485. if (is_used) {
  486. used.emplace(node);
  487. }
  488. }
  489. // Add isolated nodes which is unused var but not found in used set.
  490. for (const auto &var : assigned_vars_) {
  491. auto &node = var.second.first;
  492. bool is_used = var.second.second;
  493. if (node == nullptr || is_used) {
  494. continue;
  495. }
  496. auto &var_name = var.first;
  497. if (used.find(node) == used.end() && CanBeIsolatedNode(var_name, node)) {
  498. MS_LOG(INFO) << "Isolated node found(NoUse), node: " << node->DebugString(2) << ", var_name: " << var_name
  499. << ", block: " << this << "/" << (func_graph() ? func_graph()->ToString() : "FG(Null)")
  500. << ", Line: " << trace::GetDebugInfo(node->debug_info(), "", kSourceLineTipDiscard);
  501. AddIsolatedNode(node);
  502. }
  503. }
  504. }
  505. void FunctionBlock::AddIsolatedNode(const AnfNodePtr &target) { isolated_nodes_.add(target); }
  506. void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
  507. if (isolated_nodes_.empty()) {
  508. return;
  509. }
  510. std::vector<AnfNodePtr> states;
  511. states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  512. constexpr int recursive_level = 2;
  513. for (auto &node : isolated_nodes_) {
  514. MS_EXCEPTION_IF_NULL(node);
  515. MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(recursive_level) << " in "
  516. << func_graph_->ToString();
  517. if (node->func_graph() == func_graph_) {
  518. states.emplace_back(node);
  519. } else {
  520. MS_LOG(INFO) << "Ignored FV dependency, node: " << node->DebugString(recursive_level) << " in "
  521. << func_graph_->ToString();
  522. }
  523. }
  524. isolated_nodes_.clear();
  525. AnfNodePtr state = nullptr;
  526. constexpr size_t no_state_size = 1;
  527. constexpr size_t only_one_state_size = 2;
  528. if (states.size() == no_state_size) {
  529. // Only MakeTuple, no state left.
  530. return;
  531. } else if (states.size() == only_one_state_size) {
  532. // If there are only MakeTuple and another node in states(the states size is 2),
  533. // do not need to MakeTuple, just use the node.
  534. state = states[1];
  535. } else {
  536. state = func_graph_->NewCNode(std::move(states));
  537. }
  538. AnfNodePtr old_output = nullptr;
  539. auto return_node = func_graph_->get_return();
  540. if (return_node) {
  541. const size_t return_input_size = 2;
  542. if (return_node->inputs().size() < return_input_size) {
  543. MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2";
  544. }
  545. old_output = return_node->input(1);
  546. } else {
  547. old_output = NewValueNode(kNone);
  548. }
  549. AnfNodePtr stop_grad_node = func_graph_->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
  550. CNodePtr depend_node = func_graph_->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
  551. // We add this attribute for @constexpr use scene, since we must infer them before other nodes.
  552. // That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
  553. depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
  554. MS_EXCEPTION_IF_NULL(state);
  555. MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
  556. << ", state: " << state->DebugString(2);
  557. func_graph_->set_output(depend_node, true);
  558. }
  559. void FunctionBlock::SetAsDeadBlock() { is_dead_block_ = true; }
  560. } // namespace parse
  561. } // namespace mindspore