You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

program_specialize.cc 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/static_analysis/program_specialize.h"
  19. #include <algorithm>
  20. #include <exception>
  21. #include "./common.h"
  22. #include "operator/ops.h"
  23. #include "operator/composite/do_signature.h"
  24. #include "utils/graph_utils.h"
  25. #include "utils/profile.h"
  26. #include "debug/trace.h"
  27. namespace mindspore {
  28. namespace abstract {
  29. namespace {
  30. inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
  31. if (conf->node()->intermediate_abstract()) {
  32. return conf->node()->intermediate_abstract();
  33. }
  34. return conf->GetEvaluatedValue();
  35. }
  36. AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
  37. AnfNodePtr value_node = NewValueNode(v);
  38. value_node->set_abstract(abs_base);
  39. MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
  40. return value_node;
  41. }
  42. bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
  43. while (fg != nullptr && fg != parent) {
  44. fg = fg->parent();
  45. }
  46. return fg == parent;
  47. }
  48. } // namespace
  49. FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  50. MS_EXCEPTION_IF_NULL(fg);
  51. MS_EXCEPTION_IF_NULL(context);
  52. MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString();
  53. return SpecializeFuncGraph(fg, context);
  54. }
  55. FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  56. MS_EXCEPTION_IF_NULL(fg);
  57. MS_EXCEPTION_IF_NULL(context);
  58. auto iter = specializations_.find(context->SpecializeKey());
  59. if (iter != specializations_.end()) {
  60. return iter->second->specialized_func_graph();
  61. }
  62. std::shared_ptr<FuncGraphSpecializer> fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
  63. FuncGraphPtr fg2 = fg_spec->specialized_func_graph();
  64. specializations_[context->SpecializeKey()] = fg_spec;
  65. fg_spec->Run();
  66. return fg2;
  67. }
  68. std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
  69. MS_EXCEPTION_IF_NULL(context);
  70. auto iter = specializations_.find(context->SpecializeKey());
  71. if (iter != specializations_.end()) {
  72. return iter->second;
  73. }
  74. return nullptr;
  75. }
  // Returns the next value of a process-wide counter as a decimal string: "1" on the first call, then "2", ...
  // Used to give every specializer clone a distinct trace suffix. The counter is not synchronized.
  std::string GetNextCounter() {
    static int g_CloneCounter = 1;
    return std::to_string(g_CloneCounter++);
  }
  82. FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
  83. const AnalysisContextPtr &context)
  84. : specializer_(s), func_graph_(fg), context_(context) {
  85. parent_ = s->GetFuncGraphSpecializer(context->parent());
  86. engine_ = s->engine();
  87. cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
  88. repl_node_ = cloner_->cloned_node();
  89. specialized_func_graph_ = cloner_->cloned_func_graph()[fg];
  90. todo_.push_back(fg->get_return());
  91. auto ps = fg->parameters();
  92. (void)todo_.insert(todo_.end(), ps.begin(), ps.end());
  93. }
  94. AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
  95. MS_EXCEPTION_IF_NULL(node);
  96. FuncGraphPtr fg = node->func_graph();
  97. if (node->isa<ValueNode>()) {
  98. return node;
  99. }
  100. std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
  101. while (fg != nullptr && fg != specializer->func_graph_) {
  102. specializer = specializer->parent_;
  103. }
  104. // If had replicated, just return that.
  105. auto iter = specializer->repl_node_->find(node);
  106. if (iter != specializer->repl_node_->end()) {
  107. return iter->second;
  108. }
  109. auto new_node = specializer->cloner_->CloneDisconnected(node);
  110. if (node->isa<CNode>()) {
  111. if (!new_node->isa<CNode>()) {
  112. MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << ".";
  113. }
  114. auto c_node = node->cast<CNodePtr>();
  115. MS_EXCEPTION_IF_NULL(c_node);
  116. auto inputs = c_node->inputs();
  117. std::vector<AnfNodePtr> new_inputs;
  118. (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
  119. [this](const AnfNodePtr &inp) -> AnfNodePtr {
  120. if (inp->isa<ValueNode>()) {
  121. return inp;
  122. }
  123. return ReplicateDisconnectedNode(inp);
  124. });
  125. auto c_new_node = new_node->cast<CNodePtr>();
  126. MS_EXCEPTION_IF_NULL(c_new_node);
  127. c_new_node->set_inputs(new_inputs);
  128. }
  129. iter = specializer->repl_node_->find(node);
  130. if (iter != specializer->repl_node_->end()) {
  131. if (iter->second == node) {
  132. MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString();
  133. }
  134. } else {
  135. MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString();
  136. }
  137. return new_node;
  138. }
  139. AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
  140. MS_EXCEPTION_IF_NULL(node);
  141. FuncGraphPtr fg = node->func_graph();
  142. std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
  143. while (fg != nullptr && fg != specializer->func_graph_) {
  144. specializer = specializer->parent_;
  145. }
  146. MS_EXCEPTION_IF_NULL(specializer->repl_node_);
  147. auto iter = specializer->repl_node_->find(node);
  148. if (iter != specializer->repl_node_->end()) {
  149. return iter->second;
  150. }
  151. return node;
  152. }
  153. void FuncGraphSpecializer::Run() {
  154. MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
  155. << ", cloned func graph name: " << specialized_func_graph_->ToString()
  156. << ", func graph: " << func_graph_->get_return()->DebugString();
  157. FirstPass();
  158. SecondPass();
  159. MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString()
  160. << ", cloned func graph name: " << specialized_func_graph_->ToString()
  161. << ", new func graph: " << specialized_func_graph_->get_return()->DebugString();
  162. }
  163. void FuncGraphSpecializer::FirstPass() {
  164. while (todo_.size()) {
  165. AnfNodePtr node = todo_.back();
  166. todo_.pop_back();
  167. if (node->func_graph() == nullptr) {
  168. // do nothing for ValueNode
  169. continue;
  170. }
  171. if (node->func_graph() != func_graph_) {
  172. if (parent_ == nullptr) {
  173. MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  174. }
  175. parent_->AddTodoItem(node);
  176. parent_->FirstPass();
  177. AnfNodePtr new_node = parent_->GetReplicatedNode(node);
  178. if (node->isa<CNode>()) {
  179. parent_->ProcessCNode(new_node->cast<CNodePtr>());
  180. }
  181. continue;
  182. }
  183. if (marked_.count(node) > 0) {
  184. continue;
  185. }
  186. (void)marked_.insert(node);
  187. ProcessNode(node);
  188. }
  189. }
  190. // Specialize CNode in func graphs
  191. void FuncGraphSpecializer::SecondPass() {
  192. for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) {
  193. if (node->isa<CNode>()) {
  194. ProcessCNode(node->cast<CNodePtr>());
  195. }
  196. }
  197. }
  198. void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
  199. MS_EXCEPTION_IF_NULL(node);
  200. ScopeGuard scope_guard(node->scope());
  201. AnfNodeConfigPtr conf = MakeConfig(node);
  202. AnfNodePtr new_node = GetReplicatedNode(node);
  203. MS_EXCEPTION_IF_NULL(new_node);
  204. if (new_node->func_graph() != specialized_func_graph_) {
  205. MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
  206. << ", new_node: " << new_node->DebugString()
  207. << ", new_node->func_graph(): " << new_node->func_graph()->ToString()
  208. << ", specialized_func_graph_: " << specialized_func_graph_->ToString();
  209. return;
  210. }
  211. new_node->set_abstract(GetEvaluatedValueWrap(conf));
  212. MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
  213. if (node->isa<CNode>()) {
  214. auto c_old = node->cast<CNodePtr>();
  215. auto c_new = new_node->cast<CNodePtr>();
  216. auto new_inputs = c_new->inputs();
  217. auto old_inputs = c_old->inputs();
  218. for (size_t i = 0; i < old_inputs.size(); ++i) {
  219. auto node_input = old_inputs[i];
  220. AnfNodeConfigPtr iconf = MakeConfig(node_input);
  221. AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
  222. // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
  223. // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
  224. AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival);
  225. if (replace_node == nullptr) {
  226. replace_node = BuildReplacedNode(iconf);
  227. MS_EXCEPTION_IF_NULL(replace_node);
  228. replace_node->set_abstract(ival);
  229. MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
  230. } else {
  231. MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
  232. << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString();
  233. }
  234. if (new_inputs[i] != replace_node) {
  235. new_inputs[i] = replace_node;
  236. MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString();
  237. }
  238. }
  239. c_new->set_inputs(new_inputs);
  240. }
  241. }
  242. AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
  243. MS_EXCEPTION_IF_NULL(conf);
  244. auto conf_iter = engine_->anfnode_config_map().find(conf);
  245. AnfNodeConfigPtr new_conf = conf;
  246. while (conf_iter != engine_->anfnode_config_map().end()) {
  247. MS_LOG(DEBUG) << "Origin conf: graph(" << new_conf->node()->func_graph()->ToString() << ", node("
  248. << new_conf->node()->DebugString() << ")";
  249. new_conf = conf_iter->second;
  250. MS_EXCEPTION_IF_NULL(new_conf);
  251. MS_LOG(DEBUG) << "Replaced conf: graph(" << conf->node()->func_graph()->ToString() << ", node("
  252. << conf->node()->DebugString() << ")";
  253. (void)ReplicateDisconnectedNode(new_conf->node());
  254. conf_iter = engine_->anfnode_config_map().find(new_conf);
  255. }
  256. todo_.push_back(new_conf->node());
  257. auto repl = GetReplicatedNode(new_conf->node());
  258. if (repl->func_graph()) {
  259. MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString()
  260. << ") to replace origin:" << new_conf->node()->DebugString();
  261. } else {
  262. MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString()
  263. << ") to replace origin: " << new_conf->node()->DebugString();
  264. }
  265. return repl;
  266. }
  267. namespace {
  268. const StringImmPtr kDeadNode = std::make_shared<StringImm>("Dead Node");
  269. const StringImmPtr kPolyNode = std::make_shared<StringImm>("Poly Node");
  270. inline bool CanSpecializeNode(const AnfNodePtr &node) {
  271. if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
  272. return true;
  273. }
  274. return false;
  275. }
  276. } // namespace
  277. AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
  278. const AbstractBasePtrList &argvals) {
  279. MS_EXCEPTION_IF_NULL(abs);
  280. AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
  281. MS_EXCEPTION_IF_NULL(real_a);
  282. AbstractFunctionPtr func = real_a->GetUnique();
  283. SpecializeStatusCode errcode;
  284. ScopeGuard scope_guard(node->scope());
  285. AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode);
  286. if (repl == nullptr) {
  287. if (errcode == kSpecializeFindUniqueArgvalDead) {
  288. const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
  289. repl = BuildValueNode(kDeadNode, error_dead_node);
  290. MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString();
  291. } else if (errcode == kSpecializeFindUniqueArgvalPoly) {
  292. const auto error_poly_node = std::make_shared<AbstractError>(kPolyNode, node);
  293. repl = BuildValueNode(kPolyNode, error_poly_node);
  294. MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString();
  295. } else {
  296. MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString()
  297. << ", abstract: " << abs->ToString();
  298. }
  299. }
  300. return repl;
  301. }
  302. AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
  303. const AbstractBasePtrList &args,
  304. SpecializeStatusCode *errcode) {
  305. MS_EXCEPTION_IF_NULL(abs);
  306. MS_EXCEPTION_IF_NULL(func);
  307. MS_EXCEPTION_IF_NULL(errcode);
  308. *errcode = kSpecializeSuccess;
  309. auto real_func = dyn_cast<TypedPrimitiveAbstractClosure>(func);
  310. if (real_func != nullptr) {
  311. return BuildValueNode(real_func->prim(), abs);
  312. }
  313. EvaluatorPtr eval;
  314. eval = engine_->GetEvaluatorFor(func);
  315. MS_EXCEPTION_IF_NULL(eval);
  316. AbstractBasePtrList argvals = eval->NormalizeArgs(args);
  317. std::pair<AbstractBasePtrList, AbstractBasePtr> result;
  318. SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result);
  319. if (status != kSpecializeSuccess) {
  320. *errcode = status;
  321. return nullptr;
  322. }
  323. argvals = result.first;
  324. AbstractBasePtr unique_output = result.second;
  325. auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
  326. if (prim_func != nullptr) {
  327. auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), argvals, unique_output);
  328. return BuildValueNode(prim_func->prim(), type_func);
  329. }
  330. if (!eval->isa<BaseFuncGraphEvaluator>()) {
  331. MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString();
  332. }
  333. auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);
  334. if (func->context() != nullptr) {
  335. if (!IsVisible(func_graph_, func->context()->func_graph())) {
  336. MS_LOG(EXCEPTION) << "Func is not visible NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
  337. }
  338. } else {
  339. MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
  340. }
  341. AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
  342. MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
  343. << ", graph: " << context->func_graph()->get_return()->DebugString();
  344. FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
  345. return BuildValueNode(v, abs);
  346. }
  347. const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
  348. auto cache_iter = evalcaches_.find(eval);
  349. if (cache_iter == evalcaches_.end()) {
  350. evalcaches_[eval] = eval->cache();
  351. return eval->cache();
  352. }
  353. return cache_iter->second;
  354. }
  355. std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromBroadedArgsVal(
  356. const EvaluatorPtr &eval) {
  357. MS_EXCEPTION_IF_NULL(eval);
  358. std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
  359. AbstractBasePtr ret = nullptr;
  360. AbstractBasePtrList broaded_argvals;
  361. for (auto &argvals_map : *evalcaches_[eval]) {
  362. auto argvals = argvals_map.first;
  363. broaded_argvals.clear();
  364. (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals),
  365. [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
  366. (void)choices.insert(broaded_argvals);
  367. MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals);
  368. }
  369. if (1 == choices.size()) {
  370. ConfigPtrList args_conf_list;
  371. (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list),
  372. [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
  373. // if broaden return null
  374. ret = eval->Run(engine_, args_conf_list, nullptr);
  375. EvaluatorCacheMapPtr real = std::make_shared<EvaluatorCacheMap>();
  376. (*real)[broaded_argvals] = ret;
  377. evalcaches_[eval] = real;
  378. return std::make_pair(broaded_argvals, ret);
  379. } else {
  380. MS_LOG(DEBUG) << "Choices.size: " << choices.size();
  381. return std::make_pair(AbstractBasePtrList(), nullptr);
  382. }
  383. }
  384. void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
  385. MS_EXCEPTION_IF_NULL(new_node);
  386. if (specializer_->seen().count(new_node) > 0) {
  387. return;
  388. }
  389. specializer_->AddSeen(new_node);
  390. auto new_inputs = new_node->inputs();
  391. if (new_inputs.empty()) {
  392. MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
  393. }
  394. AnfNodePtr func = new_inputs[0];
  395. MS_EXCEPTION_IF_NULL(func);
  396. // First element is func so arg start from 1
  397. std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end());
  398. // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
  399. while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
  400. std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
  401. // First element is partial, second is func so arg is start from 2
  402. (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
  403. func = inputs[1];
  404. new_inputs = args;
  405. (void)new_inputs.insert(new_inputs.begin(), func);
  406. }
  407. AbstractBasePtrList argvals;
  408. MS_EXCEPTION_IF_NULL(new_inputs[0]);
  409. AbstractBasePtr fnval = new_inputs[0]->abstract();
  410. MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", "
  411. << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString();
  412. // First element is func so function arguments start from 1
  413. for (size_t i = 1; i < new_inputs.size(); ++i) {
  414. argvals.push_back(new_inputs[i]->abstract());
  415. MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", "
  416. << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
  417. }
  418. if (CanSpecializeNode(func)) {
  419. new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
  420. }
  421. for (size_t i = 0; i < argvals.size();) {
  422. size_t next = i + 1;
  423. if (CanSpecializeNode(args[i])) {
  424. new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
  425. }
  426. // support for partial(Multitype) which Multitype should not be inferred to POLY.
  427. // after one or more times clone, Multitype metafuncgraph evaluator will specialized to one type only,
  428. // so even with partial parameter, it will specialize to that graph.
  429. // Maybe a better idea should inline graph with partial node first, then it will have full
  430. // parameter list to infer and specialize.
  431. MS_EXCEPTION_IF_NULL(new_inputs[next]);
  432. if (new_inputs[next]->isa<ValueNode>() && (GetValueNode(new_inputs[next]) == kPolyNode) &&
  433. IsPrimitive(func, prim::kPrimPartial)) {
  434. new_inputs[next] = args[i];
  435. }
  436. i = next;
  437. }
  438. new_node->set_inputs(new_inputs);
  439. }
  440. namespace {
  441. void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) {
  442. MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
  443. int i = 0;
  444. for (const auto &item : evaluator_cache_map) {
  445. MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first;
  446. }
  447. }
  448. bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
  449. if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
  450. MS_LOG(DEBUG) << "High order primitive return POLY.";
  451. return true;
  452. }
  453. if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
  454. auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
  455. auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
  456. if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
  457. auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
  458. if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
  459. MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
  460. return true;
  461. }
  462. }
  463. }
  464. return false;
  465. }
  466. } // end anonymous namespace
  467. SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
  468. const AbstractBasePtrList &argvals,
  469. std::pair<AbstractBasePtrList, AbstractBasePtr> *result) {
  470. MS_EXCEPTION_IF_NULL(func);
  471. MS_EXCEPTION_IF_NULL(eval);
  472. MS_EXCEPTION_IF_NULL(result);
  473. EvaluatorCacheMap evaluator_cache_map = *eval->cache();
  474. if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
  475. *result = std::make_pair(argvals, evaluator_cache_map[argvals]);
  476. return kSpecializeSuccess;
  477. }
  478. DumpEvaluatorCache(evaluator_cache_map, argvals);
  479. const EvaluatorCacheMapPtr &choices = GetEvalCache(eval);
  480. MS_EXCEPTION_IF_NULL(choices);
  481. if (choices->count(argvals)) {
  482. *result = std::make_pair(argvals, (*choices)[argvals]);
  483. return kSpecializeSuccess;
  484. } else if (choices->size() == 1) {
  485. MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
  486. *result = std::make_pair(choices->begin()->first, choices->begin()->second);
  487. return kSpecializeSuccess;
  488. } else if (choices->empty()) {
  489. MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
  490. return kSpecializeFindUniqueArgvalDead;
  491. } else {
  492. if (IsPolyFunc(func, argvals)) {
  493. return kSpecializeFindUniqueArgvalPoly;
  494. }
  495. MS_LOG(DEBUG) << "Try to find generalized argvals.";
  496. *result = BuildFromBroadedArgsVal(eval);
  497. if (!result->first.empty()) {
  498. return kSpecializeSuccess;
  499. }
  500. MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism.";
  501. return kSpecializeFindUniqueArgvalPoly;
  502. }
  503. }
  504. AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) {
  505. MS_EXCEPTION_IF_NULL(origin_node);
  506. MS_EXCEPTION_IF_NULL(ival);
  507. AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
  508. if (abs != nullptr) {
  509. // Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction.
  510. if (abs->isa<AbstractFuncUnion>()) {
  511. return nullptr;
  512. }
  513. ValuePtr value = nullptr;
  514. if (abs->isa<PrimitiveAbstractClosure>()) {
  515. auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
  516. value = real_fn->prim();
  517. } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
  518. auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
  519. value = real_fn->meta_func_graph();
  520. } else if (abs->isa<FuncGraphAbstractClosure>()) {
  521. auto real_fn = dyn_cast<FuncGraphAbstractClosure>(abs);
  522. value = real_fn->func_graph();
  523. } else {
  524. return nullptr;
  525. }
  526. if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
  527. (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
  528. return BuildValueNode(value, ival);
  529. } else {
  530. return nullptr;
  531. }
  532. } else {
  533. ValuePtr val = ival->BuildValue();
  534. if (val->isa<AnyValue>()) {
  535. return nullptr;
  536. } else {
  537. return BuildValueNode(val, ival);
  538. }
  539. }
  540. }
  541. AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
  542. return engine_->MakeConfig(node, context_);
  543. }
  544. } // namespace abstract
  545. } // namespace mindspore