You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph_extends.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "ir/func_graph.h"
#include <algorithm>
#include <sstream>
#include <unordered_set>
#include <utility>
#include "ir/manager.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
#include "utils/ordered_set.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pipeline/static_analysis/static_analysis.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "debug/anf_ir_dump.h"
#include "debug/trace.h"
#include "debug/draw.h"
#include "debug/label.h"
  31. namespace mindspore {
  32. using mindspore::abstract::AbstractFunction;
  33. using mindspore::abstract::AbstractFunctionPtr;
  34. using mindspore::abstract::AnalysisContextPtr;
  35. using mindspore::abstract::PrimitiveAbstractClosure;
  36. using mindspore::abstract::VirtualAbstractClosure;
  37. AbstractFunctionPtr FuncGraph::abstract() {
  38. AbstractBasePtrList args_spec_list;
  39. for (auto &p : parameters_) {
  40. MS_EXCEPTION_IF_NULL(p);
  41. if (p->abstract() == nullptr) {
  42. MS_LOG(ERROR) << "Error!!";
  43. return nullptr;
  44. }
  45. args_spec_list.push_back(p->abstract());
  46. }
  47. if (nullptr == output()) {
  48. MS_LOG(ERROR) << "Error func graph no output";
  49. return nullptr;
  50. }
  51. return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract());
  52. }
  53. abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) {
  54. AnalysisContextPtr temp_context = context;
  55. if (temp_context == nullptr) {
  56. temp_context = abstract::AnalysisContext::DummyContext();
  57. }
  58. return std::make_shared<abstract::FuncGraphAbstractClosure>(shared_from_base<FuncGraph>(), temp_context);
  59. }
  60. void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
  61. if (force_new_ret || return_ == nullptr) {
  62. std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value});
  63. FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
  64. return_ = this_graph->NewCNode(params);
  65. } else {
  66. if (manager_.lock()) {
  67. manager_.lock()->SetEdge(return_, 1, value);
  68. } else {
  69. return_->set_input(1, value);
  70. }
  71. }
  72. return_->set_abstract(value->abstract());
  73. AnfNodePtr input0 = return_->input(0);
  74. PrimitivePtr return_prim = prim::kPrimReturn;
  75. auto f = std::make_shared<PrimitiveAbstractClosure>(return_prim, input0);
  76. input0->set_abstract(f);
  77. }
  78. void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); }
  79. void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
  80. std::vector<AnfNodePtr> *specialized_parameter_list,
  81. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count,
  82. int pos_args_input_count) {
  83. // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple
  84. if (specialized_graph->has_vararg()) {
  85. TraceManager::DebugTrace(
  86. std::make_shared<TraceGenerateVarArg>(specialized_graph->GetVariableArgParameter()->debug_info()));
  87. std::vector<AnfNodePtr> var_param_tuple_nodes;
  88. var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
  89. if (variable_args_count < 0) {
  90. MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count
  91. << " were given.";
  92. }
  93. // for python variable argument input , there is no upper limit
  94. for (int i = 0; i < variable_args_count; ++i) {
  95. ParameterPtr p = std::make_shared<Parameter>(specialized_graph);
  96. std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i);
  97. p->set_name(param_name);
  98. MS_EXCEPTION_IF_NULL(p->debug_info());
  99. p->debug_info()->set_name(param_name);
  100. var_param_tuple_nodes.push_back(p);
  101. MS_EXCEPTION_IF_NULL(specialized_parameter_list);
  102. specialized_parameter_list->push_back(p);
  103. }
  104. auto var_tuple_param = specialized_graph->NewCNode(var_param_tuple_nodes);
  105. (void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param);
  106. TraceManager::EndTrace();
  107. } else if (variable_args_count > 0) {
  108. MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << this->GetPositionalArgsCount()
  109. << " positional arguments, but " << pos_args_input_count << " were given.";
  110. }
  111. }
  112. void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
  113. std::vector<AnfNodePtr> *specialized_parameter_list,
  114. const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
  115. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) {
  116. std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
  117. std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
  118. for (const auto &kwarg : kwarg_list) {
  119. MS_EXCEPTION_IF_NULL(kwarg);
  120. std::string kw_param_name = kwarg->get_key();
  121. MS_EXCEPTION_IF_NULL(specialized_graph);
  122. AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name);
  123. // if not find correspoding parameter node
  124. if (param_node == nullptr) {
  125. if (!has_kwarg()) {
  126. MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name;
  127. } else {
  128. ParameterPtr p = std::make_shared<Parameter>(specialized_graph);
  129. std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
  130. MS_EXCEPTION_IF_NULL(specialized_parameter_list);
  131. auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
  132. [param_name](const AnfNodePtr &node) {
  133. MS_EXCEPTION_IF_NULL(node);
  134. auto param = node->cast<ParameterPtr>();
  135. return param != nullptr && param->name() == param_name;
  136. });
  137. if (find_kw_arg_in_list) {
  138. MS_LOG(EXCEPTION) << "Multiply values for keyword argument:" << kw_param_name;
  139. }
  140. p->set_name(param_name);
  141. p->debug_info()->set_name(param_name);
  142. kwarg_keys_tuple_nodes.push_back(NewValueNode(kw_param_name));
  143. auto extract_node =
  144. specialized_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), p});
  145. kwarg_values_tuple_nodes.push_back(extract_node);
  146. specialized_parameter_list->push_back(p);
  147. }
  148. } else {
  149. auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
  150. // multiply values found given for parameter
  151. if (node_itr != specialized_parameter_list->end()) {
  152. MS_LOG(EXCEPTION) << "Multiply values for specific argument:" << kw_param_name;
  153. } else {
  154. specialized_parameter_list->push_back(param_node);
  155. auto extract_node = specialized_graph->NewCNode(
  156. {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
  157. (void)repl_nodes->emplace(param_node, extract_node);
  158. }
  159. }
  160. }
  161. GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes);
  162. }
  163. void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
  164. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes,
  165. const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
  166. const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes) {
  167. if (has_kwarg()) {
  168. MS_EXCEPTION_IF_NULL(specialized_graph);
  169. TraceManager::DebugTrace(
  170. std::make_shared<TraceGenerateKwArg>(specialized_graph->GetVariableKwargParameter()->debug_info()));
  171. auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes);
  172. auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes);
  173. auto make_dict_node =
  174. specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values});
  175. MS_EXCEPTION_IF_NULL(repl_nodes);
  176. (void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node);
  177. TraceManager::EndTrace();
  178. }
  179. }
  180. bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) {
  181. // if the function does not have any vararg/kwarg/kwonly/default value/kw args input
  182. // return the original graph
  183. if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) {
  184. return false;
  185. }
  186. // if the graph is generated for specific input, do not need to generate again
  187. if (is_generated()) {
  188. return false;
  189. }
  190. return true;
  191. }
  192. void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
  193. const std::vector<AnfNodePtr> &specialized_parameter_list,
  194. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) {
  195. MS_EXCEPTION_IF_NULL(specialized_graph);
  196. for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
  197. auto param_node = specialized_graph->parameters()[i];
  198. MS_EXCEPTION_IF_NULL(param_node);
  199. auto param_name = param_node->cast<ParameterPtr>()->name();
  200. auto node_itr = std::find(specialized_parameter_list.begin(), specialized_parameter_list.end(), param_node);
  201. if (node_itr != specialized_parameter_list.end()) {
  202. continue;
  203. }
  204. if (param_name == specialized_graph->GetVariableArgName() ||
  205. param_name == specialized_graph->GetVariableKwargName()) {
  206. continue;
  207. }
  208. auto default_value = specialized_graph->GetDefaultValueByName(param_name);
  209. if (default_value == nullptr) {
  210. MS_LOG(EXCEPTION) << "Miss argument input for parameter:" << param_name;
  211. }
  212. MS_EXCEPTION_IF_NULL(repl_nodes);
  213. (void)repl_nodes->emplace(param_node, default_value);
  214. }
  215. }
  216. FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) {
  217. std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
  218. size_t arguments_count = args_spec_list.size();
  219. for (const auto &arg : args_spec_list) {
  220. // if it is a keyword argument
  221. MS_EXCEPTION_IF_NULL(arg);
  222. if (arg->isa<abstract::AbstractKeywordArg>()) {
  223. kwarg_list.push_back(dyn_cast<abstract::AbstractKeywordArg>(arg));
  224. }
  225. }
  226. if (!NeedGenerate(kwarg_list)) {
  227. return shared_from_base<FuncGraph>();
  228. }
  229. FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
  230. size_t kwarg_count = kwarg_list.size();
  231. int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count());
  232. int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
  233. int variable_args_count = pos_args_input_count - pos_args_count;
  234. std::vector<AnfNodePtr> specialized_parameter_list;
  235. std::unordered_map<AnfNodePtr, AnfNodePtr> repl_nodes;
  236. // the parameters that has arg input, copy from original parameters
  237. for (size_t i = 0; i < IntToSize(pos_args_count); ++i) {
  238. specialized_parameter_list.push_back(specialized_graph->parameters()[i]);
  239. }
  240. GenerateVarParams(specialized_graph, &specialized_parameter_list, &repl_nodes, variable_args_count,
  241. pos_args_input_count);
  242. GenerateKwParams(specialized_graph, &specialized_parameter_list, kwarg_list, &repl_nodes);
  243. GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes);
  244. // append hyper parameter to specialized_parameter_list
  245. MS_EXCEPTION_IF_NULL(specialized_graph);
  246. auto params = specialized_graph->parameters();
  247. (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(),
  248. std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; });
  249. std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
  250. auto tr = manager->Transact();
  251. for (auto &node_pair : repl_nodes) {
  252. MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-"
  253. << node_pair.second->DebugString();
  254. (void)tr.Replace(node_pair.first, node_pair.second);
  255. }
  256. tr.SetParameters(specialized_graph, specialized_parameter_list);
  257. tr.Commit();
  258. specialized_graph->set_has_kwarg(false);
  259. specialized_graph->set_has_vararg(false);
  260. specialized_graph->set_kwonlyargs_count(0);
  261. specialized_graph->ClearDefaultValues();
  262. specialized_graph->set_is_generate(true);
  263. return specialized_graph;
  264. }
  265. const char kPrimHasEffect[] = "_side_effect_flag";
  266. bool FuncGraph::HasEffect(const CNodePtr &cnode) {
  267. auto prim = GetCNodePrimitive(cnode);
  268. if (prim != nullptr && prim->isa<prim::DoSignaturePrimitive>()) {
  269. auto do_sig = prim->cast<prim::DoSignaturePrimitivePtr>();
  270. auto prim_val = do_sig->function();
  271. if (prim_val != nullptr && prim_val->isa<Primitive>()) {
  272. prim = prim_val->cast<PrimitivePtr>();
  273. } else {
  274. prim = nullptr;
  275. }
  276. }
  277. if (prim != nullptr) {
  278. auto effect_val = prim->GetAttr(kPrimHasEffect);
  279. if (effect_val && effect_val->isa<BoolImm>()) {
  280. auto effect_bool = GetValue<bool>(effect_val);
  281. return effect_bool;
  282. }
  283. }
  284. return false;
  285. }
  286. std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) {
  287. std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment);
  288. for (const auto &node : segment) {
  289. if (roots->size() == 1) {
  290. return roots;
  291. }
  292. auto input_size = node->size();
  293. for (size_t i = 0; i < input_size; i++) {
  294. auto in_node = node->input(i);
  295. auto in_cnode = in_node->cast<CNodePtr>();
  296. if (in_cnode != nullptr) {
  297. (void)roots->erase(in_cnode);
  298. }
  299. }
  300. }
  301. return roots;
  302. }
  303. std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) {
  304. std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment);
  305. for (const auto &node : segment) {
  306. if (nodes->size() == 1) {
  307. return nodes;
  308. }
  309. if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
  310. (void)nodes->erase(node);
  311. continue;
  312. }
  313. auto input_size = node->size();
  314. for (size_t i = 0; i < input_size; i++) {
  315. auto in_node = node->input(i);
  316. if (!in_node->isa<CNode>()) {
  317. continue;
  318. }
  319. auto in_cnode = in_node->cast<CNodePtr>();
  320. if (in_cnode != nullptr) {
  321. if (std::find(segment.begin(), segment.end(), in_cnode) != segment.end()) {
  322. (void)nodes->erase(node);
  323. break;
  324. }
  325. }
  326. }
  327. }
  328. return nodes;
  329. }
  330. void FuncGraph::ReleaseFullOrderToEffectOrder() {
  331. MS_LOG(DEBUG) << "Flag has_effect " << has_flag(GRAPH_FLAG_HAS_EFFECT) << ".";
  332. if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
  333. std::list<AnfNodePtr> depends_order;
  334. std::vector<CNodePtr> segment;
  335. for (const auto &cnode : order_) {
  336. if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) {
  337. continue;
  338. }
  339. if (HasEffect(cnode)) {
  340. MS_LOG(DEBUG) << "Meet a effect node " << cnode->DebugString() << ".";
  341. if (segment.size() > 0) {
  342. auto roots = FindRoots(segment);
  343. for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) {
  344. depends_order.push_back(*iter);
  345. }
  346. }
  347. segment.clear();
  348. depends_order.push_back(cnode);
  349. } else {
  350. MS_LOG(DEBUG) << "Meet a general node " << cnode->DebugString() << ".";
  351. segment.push_back(cnode);
  352. }
  353. }
  354. if (segment.size() > 1) {
  355. auto roots = FindRoots(segment);
  356. for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) {
  357. depends_order.push_back(*iter);
  358. }
  359. }
  360. std::vector<AnfNodePtr> depend_inputs;
  361. auto old_ret = output();
  362. for (auto iter = depends_order.rbegin(); iter != depends_order.rend(); (void)iter++) {
  363. if (*iter != old_ret) {
  364. depend_inputs.push_back(*iter);
  365. }
  366. }
  367. set_flag(GRAPH_FLAG_HAS_EFFECT, false);
  368. set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true);
  369. if (!depend_inputs.empty()) {
  370. SetEffectDepends(depend_inputs);
  371. }
  372. }
  373. }
  374. void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
  375. auto old_ret = output();
  376. std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimDepend), old_ret};
  377. (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end());
  378. auto new_ret = NewCNode(inputs);
  379. auto mng = manager();
  380. if (mng) {
  381. (void)mng->Replace(old_ret, new_ret);
  382. } else {
  383. return_->set_input(1, new_ret);
  384. }
  385. }
  386. } // namespace mindspore