You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

program_specialize.cc 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/static_analysis/program_specialize.h"
  19. #include <algorithm>
  20. #include <exception>
  21. #include "./common.h"
  22. #include "operator/ops.h"
  23. #include "operator/composite/do_signature.h"
  24. #include "pipeline/static_analysis/abstract_function.h"
  25. #include "utils/graph_utils.h"
  26. #include "utils/log_adapter.h"
  27. #include "utils/profile.h"
  28. #include "debug/trace.h"
  29. namespace mindspore {
  30. namespace abstract {
  31. namespace {
  32. inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
  33. if (conf->node()->intermediate_abstract()) {
  34. return conf->node()->intermediate_abstract();
  35. }
  36. return conf->GetEvaluatedValue()->abstract();
  37. }
  38. AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
  39. AnfNodePtr value_node = NewValueNode(v);
  40. value_node->set_abstract(abs_base);
  41. MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
  42. return value_node;
  43. }
  44. bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
  45. while (fg != nullptr && fg != parent) {
  46. fg = fg->parent();
  47. }
  48. return fg == parent;
  49. }
  50. } // namespace
  51. FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  52. MS_EXCEPTION_IF_NULL(fg);
  53. MS_EXCEPTION_IF_NULL(context);
  54. MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString();
  55. return SpecializeFuncGraph(fg, context);
  56. }
  57. FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  58. MS_EXCEPTION_IF_NULL(fg);
  59. MS_EXCEPTION_IF_NULL(context);
  60. auto iter = specializations_.find(context->SpecializeKey());
  61. if (iter != specializations_.end()) {
  62. return iter->second->specialized_func_graph();
  63. }
  64. std::shared_ptr<FuncGraphSpecializer> fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
  65. FuncGraphPtr fg2 = fg_spec->specialized_func_graph();
  66. specializations_[context->SpecializeKey()] = fg_spec;
  67. fg_spec->Run();
  68. return fg2;
  69. }
  70. std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
  71. MS_EXCEPTION_IF_NULL(context);
  72. auto iter = specializations_.find(context->SpecializeKey());
  73. if (iter != specializations_.end()) {
  74. return iter->second;
  75. }
  76. return nullptr;
  77. }
  78. std::string GetNextCounter() {
  79. static int g_CloneCounter = 1;
  80. std::string str_count = std::to_string(g_CloneCounter);
  81. g_CloneCounter++;
  82. return str_count;
  83. }
  84. FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
  85. const AnalysisContextPtr &context)
  86. : specializer_(s), func_graph_(fg), context_(context) {
  87. parent_ = s->GetFuncGraphSpecializer(context->parent());
  88. engine_ = s->engine();
  89. cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
  90. repl_node_ = cloner_->cloned_node();
  91. specialized_func_graph_ = cloner_->cloned_func_graph()[fg];
  92. todo_.push_back(fg->get_return());
  93. auto ps = fg->parameters();
  94. (void)todo_.insert(todo_.end(), ps.begin(), ps.end());
  95. }
  96. AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
  97. MS_EXCEPTION_IF_NULL(node);
  98. FuncGraphPtr fg = node->func_graph();
  99. if (node->isa<ValueNode>()) {
  100. return node;
  101. }
  102. std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
  103. while (fg != nullptr && fg != specializer->func_graph_) {
  104. specializer = specializer->parent_;
  105. }
  106. // If had replicated, just return that.
  107. auto iter = specializer->repl_node_->find(node);
  108. if (iter != specializer->repl_node_->end()) {
  109. return iter->second;
  110. }
  111. auto new_node = specializer->cloner_->CloneDisconnected(node);
  112. if (node->isa<CNode>()) {
  113. if (!new_node->isa<CNode>()) {
  114. MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << ".";
  115. }
  116. auto c_node = node->cast<CNodePtr>();
  117. MS_EXCEPTION_IF_NULL(c_node);
  118. auto inputs = c_node->inputs();
  119. std::vector<AnfNodePtr> new_inputs;
  120. (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
  121. [this](const AnfNodePtr &inp) -> AnfNodePtr {
  122. if (inp->isa<ValueNode>()) {
  123. return inp;
  124. }
  125. return ReplicateDisconnectedNode(inp);
  126. });
  127. auto c_new_node = new_node->cast<CNodePtr>();
  128. MS_EXCEPTION_IF_NULL(c_new_node);
  129. c_new_node->set_inputs(new_inputs);
  130. }
  131. iter = specializer->repl_node_->find(node);
  132. if (iter != specializer->repl_node_->end()) {
  133. if (iter->second == node) {
  134. MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString();
  135. }
  136. } else {
  137. MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString();
  138. }
  139. return new_node;
  140. }
  141. AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
  142. MS_EXCEPTION_IF_NULL(node);
  143. FuncGraphPtr fg = node->func_graph();
  144. std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
  145. while (fg != nullptr && fg != specializer->func_graph_) {
  146. specializer = specializer->parent_;
  147. }
  148. MS_EXCEPTION_IF_NULL(specializer->repl_node_);
  149. auto iter = specializer->repl_node_->find(node);
  150. if (iter != specializer->repl_node_->end()) {
  151. return iter->second;
  152. }
  153. return node;
  154. }
  155. void FuncGraphSpecializer::Run() {
  156. MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
  157. << ", cloned func graph name: " << specialized_func_graph_->ToString()
  158. << ", func graph: " << func_graph_->get_return()->DebugString();
  159. FirstPass();
  160. SecondPass();
  161. MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString()
  162. << ", cloned func graph name: " << specialized_func_graph_->ToString()
  163. << ", new func graph: " << specialized_func_graph_->get_return()->DebugString();
  164. }
  165. void FuncGraphSpecializer::FirstPass() {
  166. while (todo_.size()) {
  167. AnfNodePtr node = todo_.back();
  168. todo_.pop_back();
  169. if (node->func_graph() == nullptr) {
  170. // do nothing for ValueNode
  171. continue;
  172. }
  173. if (node->func_graph() != func_graph_) {
  174. if (parent_ == nullptr) {
  175. MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  176. }
  177. parent_->AddTodoItem(node);
  178. parent_->FirstPass();
  179. AnfNodePtr new_node = parent_->GetReplicatedNode(node);
  180. if (node->isa<CNode>()) {
  181. parent_->ProcessCNode(new_node->cast<CNodePtr>());
  182. }
  183. continue;
  184. }
  185. if (marked_.count(node) > 0) {
  186. continue;
  187. }
  188. (void)marked_.insert(node);
  189. ProcessNode(node);
  190. }
  191. }
  192. // Specialize CNode in func graphs
  193. void FuncGraphSpecializer::SecondPass() {
  194. for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) {
  195. if (node->isa<CNode>()) {
  196. ProcessCNode(node->cast<CNodePtr>());
  197. }
  198. }
  199. }
  200. void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
  201. MS_EXCEPTION_IF_NULL(node);
  202. ScopeGuard scope_guard(node->scope());
  203. AnfNodeConfigPtr conf = MakeConfig(node);
  204. AnfNodePtr new_node = GetReplicatedNode(node);
  205. MS_EXCEPTION_IF_NULL(new_node);
  206. if (new_node->func_graph() != specialized_func_graph_) {
  207. MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
  208. << ", new_node: " << new_node->DebugString()
  209. << ", new_node->func_graph(): " << new_node->func_graph()->ToString()
  210. << ", specialized_func_graph_: " << specialized_func_graph_->ToString();
  211. return;
  212. }
  213. new_node->set_abstract(GetEvaluatedValueWrap(conf));
  214. if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
  215. auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
  216. if (partial_abstract->node() == node) {
  217. partial_abstract->set_node(new_node);
  218. }
  219. }
  220. MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
  221. if (node->isa<CNode>()) {
  222. auto attrs = conf->GetEvaluatedValue()->attribute();
  223. auto c_old = node->cast<CNodePtr>();
  224. auto c_new = new_node->cast<CNodePtr>();
  225. auto new_inputs = c_new->inputs();
  226. auto old_inputs = c_old->inputs();
  227. for (size_t i = 0; i < old_inputs.size(); ++i) {
  228. auto node_input = old_inputs[i];
  229. AnfNodeConfigPtr iconf = MakeConfig(node_input);
  230. AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
  231. // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
  232. // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
  233. AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
  234. if (replace_node == nullptr) {
  235. replace_node = BuildReplacedNode(iconf);
  236. MS_EXCEPTION_IF_NULL(replace_node);
  237. replace_node->set_abstract(ival);
  238. MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
  239. } else {
  240. MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
  241. << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString();
  242. }
  243. if (new_inputs[i] != replace_node) {
  244. new_inputs[i] = replace_node;
  245. MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString();
  246. }
  247. }
  248. c_new->set_inputs(new_inputs);
  249. }
  250. }
  251. AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
  252. MS_EXCEPTION_IF_NULL(conf);
  253. auto conf_iter = engine_->anfnode_config_map().find(conf);
  254. AnfNodeConfigPtr new_conf = conf;
  255. while (conf_iter != engine_->anfnode_config_map().end()) {
  256. MS_LOG(DEBUG) << "Origin conf: graph(" << new_conf->node()->func_graph()->ToString() << ", node("
  257. << new_conf->node()->DebugString() << ")";
  258. new_conf = conf_iter->second;
  259. MS_EXCEPTION_IF_NULL(new_conf);
  260. MS_LOG(DEBUG) << "Replaced conf: graph(" << conf->node()->func_graph()->ToString() << ", node("
  261. << conf->node()->DebugString() << ")";
  262. (void)ReplicateDisconnectedNode(new_conf->node());
  263. conf_iter = engine_->anfnode_config_map().find(new_conf);
  264. }
  265. todo_.push_back(new_conf->node());
  266. auto repl = GetReplicatedNode(new_conf->node());
  267. if (repl->func_graph()) {
  268. MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString()
  269. << ") to replace origin:" << new_conf->node()->DebugString();
  270. } else {
  271. MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString()
  272. << ") to replace origin: " << new_conf->node()->DebugString();
  273. }
  274. return repl;
  275. }
  276. namespace {
  277. const StringImmPtr kDeadNode = std::make_shared<StringImm>("Dead Node");
  278. const StringImmPtr kPolyNode = std::make_shared<StringImm>("Poly Node");
  279. inline bool CanSpecializeNode(const AnfNodePtr &node) {
  280. if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
  281. return true;
  282. }
  283. return false;
  284. }
  285. } // namespace
  286. AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
  287. const AbstractBasePtrList &argvals) {
  288. MS_EXCEPTION_IF_NULL(abs);
  289. AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
  290. MS_EXCEPTION_IF_NULL(real_a);
  291. AbstractFunctionPtr func = real_a->GetUnique();
  292. SpecializeStatusCode errcode;
  293. ScopeGuard scope_guard(node->scope());
  294. AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode);
  295. if (repl == nullptr) {
  296. if (errcode == kSpecializeFindUniqueArgvalDead) {
  297. const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
  298. repl = BuildValueNode(kDeadNode, error_dead_node);
  299. MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString();
  300. } else if (errcode == kSpecializeFindUniqueArgvalPoly) {
  301. const auto error_poly_node = std::make_shared<AbstractError>(kPolyNode, node);
  302. repl = BuildValueNode(kPolyNode, error_poly_node);
  303. MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString();
  304. } else {
  305. MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString()
  306. << ", abstract: " << abs->ToString();
  307. }
  308. }
  309. return repl;
  310. }
  311. AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
  312. const AbstractFunctionPtr &func,
  313. const AbstractBasePtrList &args,
  314. SpecializeStatusCode *errcode) {
  315. MS_EXCEPTION_IF_NULL(abs);
  316. MS_EXCEPTION_IF_NULL(func);
  317. MS_EXCEPTION_IF_NULL(errcode);
  318. *errcode = kSpecializeSuccess;
  319. auto real_func = dyn_cast<TypedPrimitiveAbstractClosure>(func);
  320. if (real_func != nullptr) {
  321. return BuildValueNode(real_func->prim(), abs);
  322. }
  323. EvaluatorPtr eval;
  324. eval = engine_->GetEvaluatorFor(func);
  325. MS_EXCEPTION_IF_NULL(eval);
  326. AbstractBasePtrList argvals = eval->NormalizeArgs(args);
  327. std::pair<AbstractBasePtrList, AbstractBasePtr> result;
  328. SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result);
  329. if (status != kSpecializeSuccess) {
  330. *errcode = status;
  331. return nullptr;
  332. }
  333. argvals = result.first;
  334. AbstractBasePtr unique_output = result.second;
  335. auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
  336. if (prim_func != nullptr) {
  337. auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), argvals, unique_output);
  338. return BuildValueNode(prim_func->prim(), type_func);
  339. }
  340. if (!eval->isa<BaseFuncGraphEvaluator>()) {
  341. MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString();
  342. }
  343. auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);
  344. if (func->context() == nullptr) {
  345. MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
  346. }
  347. AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
  348. MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
  349. << ", graph: " << context->func_graph()->get_return()->DebugString();
  350. if (context->func_graph()->stub()) {
  351. MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
  352. << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString()
  353. << ", " << node->ToString();
  354. return node;
  355. }
  356. FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
  357. v->set_flag(kFuncGraphFlagUndetermined, false);
  358. return BuildValueNode(v, abs);
  359. }
  360. AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
  361. auto new_inputs = new_node->inputs();
  362. AnfNodePtr func = new_inputs[0];
  363. AbstractBasePtr fnval = new_inputs[0]->abstract();
  364. AbstractBasePtrList args;
  365. auto backed_fnval = fnval;
  366. if (fnval->isa<PartialAbstractClosure>()) {
  367. auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
  368. backed_fnval = partial_closure->fn();
  369. args = partial_closure->args();
  370. }
  371. std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args),
  372. [](const AnfNodePtr &inp) { return inp->abstract(); });
  373. ScopeGuard scope_guard(new_node->scope());
  374. auto specialized_node = BuildSpecializedNode(func, backed_fnval, args);
  375. auto wrapped_node = specialized_node;
  376. if (fnval->isa<PartialAbstractClosure>()) {
  377. auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
  378. AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)),
  379. specialized_node};
  380. auto anf_node = partial_closure->node();
  381. if (!anf_node->isa<CNode>()) {
  382. MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString();
  383. }
  384. auto cnode = anf_node->cast<CNodePtr>();
  385. if (cnode->size() != partial_closure->args().size() + 2) {
  386. MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
  387. << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
  388. }
  389. auto attrs = std::make_shared<AttrValueMap>();
  390. for (size_t i = 0; i < partial_closure->args().size(); i++) {
  391. auto old_node = cnode->input(i + 2);
  392. auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs);
  393. if (possibile_value_node != nullptr) {
  394. partial_node_list.push_back(possibile_value_node);
  395. } else {
  396. if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) {
  397. MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString();
  398. }
  399. partial_node_list.push_back(old_node);
  400. }
  401. }
  402. wrapped_node = new_node->func_graph()->NewCNode(partial_node_list);
  403. wrapped_node->set_abstract(partial_closure);
  404. }
  405. return wrapped_node;
  406. }
  407. const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
  408. auto cache_iter = evalcaches_.find(eval);
  409. if (cache_iter == evalcaches_.end()) {
  410. evalcaches_[eval] = eval->cache();
  411. return eval->cache();
  412. }
  413. return cache_iter->second;
  414. }
  415. std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromBroadedArgsVal(
  416. const EvaluatorPtr &eval) {
  417. MS_EXCEPTION_IF_NULL(eval);
  418. std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
  419. EvalResultPtr ret = nullptr;
  420. AbstractBasePtrList broaded_argvals;
  421. for (auto &argvals_map : *evalcaches_[eval]) {
  422. auto argvals = argvals_map.first;
  423. broaded_argvals.clear();
  424. (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals),
  425. [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
  426. (void)choices.insert(broaded_argvals);
  427. MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals);
  428. }
  429. if (1 == choices.size()) {
  430. ConfigPtrList args_conf_list;
  431. (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list),
  432. [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
  433. // if broaden return null
  434. ret = eval->Run(engine_, args_conf_list, nullptr);
  435. EvaluatorCacheMapPtr real = std::make_shared<EvaluatorCacheMap>();
  436. (*real)[broaded_argvals] = ret;
  437. evalcaches_[eval] = real;
  438. return std::make_pair(broaded_argvals, ret->abstract());
  439. } else {
  440. MS_LOG(DEBUG) << "Choices.size: " << choices.size();
  441. return std::make_pair(AbstractBasePtrList(), nullptr);
  442. }
  443. }
  444. void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
  445. MS_EXCEPTION_IF_NULL(new_node);
  446. if (specializer_->seen().count(new_node) > 0) {
  447. return;
  448. }
  449. specializer_->AddSeen(new_node);
  450. auto new_inputs = new_node->inputs();
  451. if (new_inputs.empty()) {
  452. MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
  453. }
  454. AnfNodePtr func = new_inputs[0];
  455. MS_EXCEPTION_IF_NULL(func);
  456. // First element is func so arg start from 1
  457. std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end());
  458. // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
  459. while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
  460. std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
  461. // First element is partial, second is func so arg is start from 2
  462. (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
  463. func = inputs[1];
  464. }
  465. new_inputs = args;
  466. (void)new_inputs.insert(new_inputs.begin(), func);
  467. AbstractBasePtrList argvals;
  468. MS_EXCEPTION_IF_NULL(new_inputs[0]);
  469. AbstractBasePtr fnval = new_inputs[0]->abstract();
  470. MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", "
  471. << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString();
  472. // First element is func so function arguments start from 1
  473. for (size_t i = 1; i < new_inputs.size(); ++i) {
  474. argvals.push_back(new_inputs[i]->abstract());
  475. MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", "
  476. << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
  477. }
  478. if (!func->isa<ValueNode>()) {
  479. MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString();
  480. if (func->abstract()->isa<AbstractFunction>() && !func->abstract()->isa<AbstractFuncUnion>()) {
  481. auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
  482. EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
  483. std::pair<AbstractBasePtrList, AbstractBasePtr> result;
  484. AbstractBasePtrList empty_args;
  485. auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
  486. MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
  487. // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
  488. if (status == kSpecializeFindUniqueArgvalPoly ||
  489. (func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) ||
  490. func->abstract()->isa<PartialAbstractClosure>()))) {
  491. auto wrapped_node = BuildSpecializedParameterNode(new_node);
  492. new_inputs[0] = wrapped_node;
  493. }
  494. }
  495. }
  496. if (CanSpecializeNode(func)) {
  497. // for primitive node , we build the primitive node with infered attributes in the first pass
  498. // so we do not build replaced node again here in second pass
  499. if (IsValueNode<Primitive>(func)) {
  500. new_inputs[0] = func;
  501. } else {
  502. new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
  503. }
  504. }
  505. for (size_t i = 0; i < argvals.size();) {
  506. size_t next = i + 1;
  507. if (CanSpecializeNode(args[i])) {
  508. new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
  509. }
  510. i = next;
  511. }
  512. new_node->set_inputs(new_inputs);
  513. }
  514. namespace {
  515. void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) {
  516. MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
  517. int i = 0;
  518. for (const auto &item : evaluator_cache_map) {
  519. MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first;
  520. }
  521. }
  522. bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
  523. if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
  524. MS_LOG(DEBUG) << "High order primitive return POLY.";
  525. return true;
  526. }
  527. if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
  528. auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
  529. auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
  530. if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
  531. auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
  532. if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
  533. MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
  534. return true;
  535. }
  536. }
  537. }
  538. return false;
  539. }
  540. } // end anonymous namespace
  541. SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
  542. const AbstractBasePtrList &argvals,
  543. std::pair<AbstractBasePtrList, AbstractBasePtr> *result) {
  544. MS_EXCEPTION_IF_NULL(func);
  545. MS_EXCEPTION_IF_NULL(eval);
  546. MS_EXCEPTION_IF_NULL(result);
  547. EvaluatorCacheMap evaluator_cache_map = *eval->cache();
  548. if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
  549. *result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract());
  550. return kSpecializeSuccess;
  551. }
  552. DumpEvaluatorCache(evaluator_cache_map, argvals);
  553. const EvaluatorCacheMapPtr &choices = GetEvalCache(eval);
  554. MS_EXCEPTION_IF_NULL(choices);
  555. if (choices->count(argvals)) {
  556. *result = std::make_pair(argvals, (*choices)[argvals]->abstract());
  557. return kSpecializeSuccess;
  558. } else if (choices->size() == 1) {
  559. MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
  560. *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
  561. return kSpecializeSuccess;
  562. } else if (choices->empty()) {
  563. MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
  564. << func->type_name();
  565. return kSpecializeFindUniqueArgvalDead;
  566. } else {
  567. if (IsPolyFunc(func, argvals)) {
  568. return kSpecializeFindUniqueArgvalPoly;
  569. }
  570. MS_LOG(DEBUG) << "Try to find generalized argvals.";
  571. *result = BuildFromBroadedArgsVal(eval);
  572. if (!result->first.empty()) {
  573. return kSpecializeSuccess;
  574. }
  575. MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism.";
  576. return kSpecializeFindUniqueArgvalPoly;
  577. }
  578. }
  579. static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
  580. auto &prim_attrs = prim->attrs();
  581. bool is_attr_same = true;
  582. for (auto &item : *attrs) {
  583. auto itr = prim_attrs.find(item.first);
  584. if (itr != prim_attrs.end()) {
  585. if (!(*(itr->second) == *(item.second))) {
  586. is_attr_same = false;
  587. break;
  588. }
  589. } else {
  590. is_attr_same = false;
  591. break;
  592. }
  593. }
  594. if (!is_attr_same) {
  595. if (prim->isa<PrimitivePy>()) {
  596. PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
  597. auto clone_fn = prim_py->GetPyObj().attr("_clone");
  598. py::object new_obj = clone_fn();
  599. auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
  600. for (auto &item : *attrs) {
  601. cloned_prim->AddAttr(item.first, item.second);
  602. }
  603. return cloned_prim;
  604. }
  605. auto cloned_prim = std::make_shared<Primitive>(*prim);
  606. for (auto &item : *attrs) {
  607. cloned_prim->AddAttr(item.first, item.second);
  608. }
  609. return cloned_prim;
  610. }
  611. return prim;
  612. }
  613. AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
  614. const AttrValueMapPtr &attrs) {
  615. MS_EXCEPTION_IF_NULL(origin_node);
  616. MS_EXCEPTION_IF_NULL(ival);
  617. AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
  618. if (abs != nullptr) {
  619. // Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction.
  620. if (abs->isa<AbstractFuncUnion>()) {
  621. return nullptr;
  622. }
  623. ValuePtr value = nullptr;
  624. if (abs->isa<PrimitiveAbstractClosure>()) {
  625. auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
  626. // for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one
  627. if (attrs != nullptr) {
  628. value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
  629. } else {
  630. value = real_fn->prim();
  631. }
  632. } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
  633. auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
  634. value = real_fn->meta_func_graph();
  635. } else if (abs->isa<FuncGraphAbstractClosure>()) {
  636. auto real_fn = dyn_cast<FuncGraphAbstractClosure>(abs);
  637. value = real_fn->func_graph();
  638. } else {
  639. return nullptr;
  640. }
  641. if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
  642. (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
  643. return BuildValueNode(value, ival);
  644. } else {
  645. return nullptr;
  646. }
  647. } else {
  648. ValuePtr val = ival->BuildValue();
  649. if (val->isa<AnyValue>()) {
  650. return nullptr;
  651. }
  652. // keep primitive 'depend' not to be optimized
  653. if (IsPrimitiveCNode(origin_node, prim::kPrimDepend)) {
  654. return nullptr;
  655. }
  656. return BuildValueNode(val, ival);
  657. }
  658. }
  659. AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
  660. return engine_->MakeConfig(node, context_);
  661. }
  662. } // namespace abstract
  663. } // namespace mindspore