You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

program_specialize.cc 27 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/static_analysis/program_specialize.h"
  19. #include <algorithm>
  20. #include <exception>
  21. #include "./common.h"
  22. #include "operator/ops.h"
  23. #include "operator/composite/do_signature.h"
  24. #include "pipeline/static_analysis/abstract_function.h"
  25. #include "utils/graph_utils.h"
  26. #include "utils/log_adapter.h"
  27. #include "utils/profile.h"
  28. #include "debug/trace.h"
  29. namespace mindspore {
  30. namespace abstract {
  31. namespace {
  32. inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
  33. if (conf->node()->intermediate_abstract()) {
  34. return conf->node()->intermediate_abstract();
  35. }
  36. return conf->GetEvaluatedValue();
  37. }
  38. AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
  39. AnfNodePtr value_node = NewValueNode(v);
  40. value_node->set_abstract(abs_base);
  41. MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
  42. return value_node;
  43. }
  44. bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
  45. while (fg != nullptr && fg != parent) {
  46. fg = fg->parent();
  47. }
  48. return fg == parent;
  49. }
  50. } // namespace
  51. FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  52. MS_EXCEPTION_IF_NULL(fg);
  53. MS_EXCEPTION_IF_NULL(context);
  54. MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString();
  55. return SpecializeFuncGraph(fg, context);
  56. }
  57. FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  58. MS_EXCEPTION_IF_NULL(fg);
  59. MS_EXCEPTION_IF_NULL(context);
  60. auto iter = specializations_.find(context->SpecializeKey());
  61. if (iter != specializations_.end()) {
  62. return iter->second->specialized_func_graph();
  63. }
  64. std::shared_ptr<FuncGraphSpecializer> fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
  65. FuncGraphPtr fg2 = fg_spec->specialized_func_graph();
  66. specializations_[context->SpecializeKey()] = fg_spec;
  67. fg_spec->Run();
  68. return fg2;
  69. }
  70. std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
  71. MS_EXCEPTION_IF_NULL(context);
  72. auto iter = specializations_.find(context->SpecializeKey());
  73. if (iter != specializations_.end()) {
  74. return iter->second;
  75. }
  76. return nullptr;
  77. }
  // Returns a process-wide sequence number as a decimal string: "1" on the first
  // call, "2" on the second, and so on. The counter is a plain function-local static.
  std::string GetNextCounter() {
    static int next_id = 1;
    return std::to_string(next_id++);
  }
  84. FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
  85. const AnalysisContextPtr &context)
  86. : specializer_(s), func_graph_(fg), context_(context) {
  87. parent_ = s->GetFuncGraphSpecializer(context->parent());
  88. engine_ = s->engine();
  89. cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
  90. repl_node_ = cloner_->cloned_node();
  91. specialized_func_graph_ = cloner_->cloned_func_graph()[fg];
  92. todo_.push_back(fg->get_return());
  93. auto ps = fg->parameters();
  94. (void)todo_.insert(todo_.end(), ps.begin(), ps.end());
  95. }
  96. AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
  97. MS_EXCEPTION_IF_NULL(node);
  98. FuncGraphPtr fg = node->func_graph();
  99. if (node->isa<ValueNode>()) {
  100. return node;
  101. }
  102. std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
  103. while (fg != nullptr && fg != specializer->func_graph_) {
  104. specializer = specializer->parent_;
  105. }
  106. // If had replicated, just return that.
  107. auto iter = specializer->repl_node_->find(node);
  108. if (iter != specializer->repl_node_->end()) {
  109. return iter->second;
  110. }
  111. auto new_node = specializer->cloner_->CloneDisconnected(node);
  112. if (node->isa<CNode>()) {
  113. if (!new_node->isa<CNode>()) {
  114. MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << ".";
  115. }
  116. auto c_node = node->cast<CNodePtr>();
  117. MS_EXCEPTION_IF_NULL(c_node);
  118. auto inputs = c_node->inputs();
  119. std::vector<AnfNodePtr> new_inputs;
  120. (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
  121. [this](const AnfNodePtr &inp) -> AnfNodePtr {
  122. if (inp->isa<ValueNode>()) {
  123. return inp;
  124. }
  125. return ReplicateDisconnectedNode(inp);
  126. });
  127. auto c_new_node = new_node->cast<CNodePtr>();
  128. MS_EXCEPTION_IF_NULL(c_new_node);
  129. c_new_node->set_inputs(new_inputs);
  130. }
  131. iter = specializer->repl_node_->find(node);
  132. if (iter != specializer->repl_node_->end()) {
  133. if (iter->second == node) {
  134. MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString();
  135. }
  136. } else {
  137. MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString();
  138. }
  139. return new_node;
  140. }
  141. AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
  142. MS_EXCEPTION_IF_NULL(node);
  143. FuncGraphPtr fg = node->func_graph();
  144. std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
  145. while (fg != nullptr && fg != specializer->func_graph_) {
  146. specializer = specializer->parent_;
  147. }
  148. MS_EXCEPTION_IF_NULL(specializer->repl_node_);
  149. auto iter = specializer->repl_node_->find(node);
  150. if (iter != specializer->repl_node_->end()) {
  151. return iter->second;
  152. }
  153. return node;
  154. }
  155. void FuncGraphSpecializer::Run() {
  156. MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
  157. << ", cloned func graph name: " << specialized_func_graph_->ToString()
  158. << ", func graph: " << func_graph_->get_return()->DebugString();
  159. FirstPass();
  160. SecondPass();
  161. MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString()
  162. << ", cloned func graph name: " << specialized_func_graph_->ToString()
  163. << ", new func graph: " << specialized_func_graph_->get_return()->DebugString();
  164. }
  165. void FuncGraphSpecializer::FirstPass() {
  166. while (todo_.size()) {
  167. AnfNodePtr node = todo_.back();
  168. todo_.pop_back();
  169. if (node->func_graph() == nullptr) {
  170. // do nothing for ValueNode
  171. continue;
  172. }
  173. if (node->func_graph() != func_graph_) {
  174. if (parent_ == nullptr) {
  175. MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  176. }
  177. parent_->AddTodoItem(node);
  178. parent_->FirstPass();
  179. AnfNodePtr new_node = parent_->GetReplicatedNode(node);
  180. if (node->isa<CNode>()) {
  181. parent_->ProcessCNode(new_node->cast<CNodePtr>());
  182. }
  183. continue;
  184. }
  185. if (marked_.count(node) > 0) {
  186. continue;
  187. }
  188. (void)marked_.insert(node);
  189. ProcessNode(node);
  190. }
  191. }
  192. // Specialize CNode in func graphs
  193. void FuncGraphSpecializer::SecondPass() {
  194. for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) {
  195. if (node->isa<CNode>()) {
  196. ProcessCNode(node->cast<CNodePtr>());
  197. }
  198. }
  199. }
  200. void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
  201. MS_EXCEPTION_IF_NULL(node);
  202. ScopeGuard scope_guard(node->scope());
  203. AnfNodeConfigPtr conf = MakeConfig(node);
  204. AnfNodePtr new_node = GetReplicatedNode(node);
  205. MS_EXCEPTION_IF_NULL(new_node);
  206. if (new_node->func_graph() != specialized_func_graph_) {
  207. MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
  208. << ", new_node: " << new_node->DebugString()
  209. << ", new_node->func_graph(): " << new_node->func_graph()->ToString()
  210. << ", specialized_func_graph_: " << specialized_func_graph_->ToString();
  211. return;
  212. }
  213. new_node->set_abstract(GetEvaluatedValueWrap(conf));
  214. if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
  215. auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
  216. if (partial_abstract->node() == node) {
  217. partial_abstract->set_node(new_node);
  218. }
  219. }
  220. MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
  221. if (node->isa<CNode>()) {
  222. auto c_old = node->cast<CNodePtr>();
  223. auto c_new = new_node->cast<CNodePtr>();
  224. auto new_inputs = c_new->inputs();
  225. auto old_inputs = c_old->inputs();
  226. for (size_t i = 0; i < old_inputs.size(); ++i) {
  227. auto node_input = old_inputs[i];
  228. AnfNodeConfigPtr iconf = MakeConfig(node_input);
  229. AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
  230. // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
  231. // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
  232. AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival);
  233. if (replace_node == nullptr) {
  234. replace_node = BuildReplacedNode(iconf);
  235. MS_EXCEPTION_IF_NULL(replace_node);
  236. replace_node->set_abstract(ival);
  237. MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
  238. } else {
  239. MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
  240. << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString();
  241. }
  242. if (new_inputs[i] != replace_node) {
  243. new_inputs[i] = replace_node;
  244. MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString();
  245. }
  246. }
  247. c_new->set_inputs(new_inputs);
  248. }
  249. }
  250. AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
  251. MS_EXCEPTION_IF_NULL(conf);
  252. auto conf_iter = engine_->anfnode_config_map().find(conf);
  253. AnfNodeConfigPtr new_conf = conf;
  254. while (conf_iter != engine_->anfnode_config_map().end()) {
  255. MS_LOG(DEBUG) << "Origin conf: graph(" << new_conf->node()->func_graph()->ToString() << ", node("
  256. << new_conf->node()->DebugString() << ")";
  257. new_conf = conf_iter->second;
  258. MS_EXCEPTION_IF_NULL(new_conf);
  259. MS_LOG(DEBUG) << "Replaced conf: graph(" << conf->node()->func_graph()->ToString() << ", node("
  260. << conf->node()->DebugString() << ")";
  261. (void)ReplicateDisconnectedNode(new_conf->node());
  262. conf_iter = engine_->anfnode_config_map().find(new_conf);
  263. }
  264. todo_.push_back(new_conf->node());
  265. auto repl = GetReplicatedNode(new_conf->node());
  266. if (repl->func_graph()) {
  267. MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString()
  268. << ") to replace origin:" << new_conf->node()->DebugString();
  269. } else {
  270. MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString()
  271. << ") to replace origin: " << new_conf->node()->DebugString();
  272. }
  273. return repl;
  274. }
  275. namespace {
  276. const StringImmPtr kDeadNode = std::make_shared<StringImm>("Dead Node");
  277. const StringImmPtr kPolyNode = std::make_shared<StringImm>("Poly Node");
  278. inline bool CanSpecializeNode(const AnfNodePtr &node) {
  279. if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
  280. return true;
  281. }
  282. return false;
  283. }
  284. } // namespace
  285. AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
  286. const AbstractBasePtrList &argvals) {
  287. MS_EXCEPTION_IF_NULL(abs);
  288. AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
  289. MS_EXCEPTION_IF_NULL(real_a);
  290. AbstractFunctionPtr func = real_a->GetUnique();
  291. SpecializeStatusCode errcode;
  292. ScopeGuard scope_guard(node->scope());
  293. AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode);
  294. if (repl == nullptr) {
  295. if (errcode == kSpecializeFindUniqueArgvalDead) {
  296. const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
  297. repl = BuildValueNode(kDeadNode, error_dead_node);
  298. MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString();
  299. } else if (errcode == kSpecializeFindUniqueArgvalPoly) {
  300. const auto error_poly_node = std::make_shared<AbstractError>(kPolyNode, node);
  301. repl = BuildValueNode(kPolyNode, error_poly_node);
  302. MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString();
  303. } else {
  304. MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString()
  305. << ", abstract: " << abs->ToString();
  306. }
  307. }
  308. return repl;
  309. }
  310. AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
  311. const AbstractBasePtrList &args,
  312. SpecializeStatusCode *errcode) {
  313. MS_EXCEPTION_IF_NULL(abs);
  314. MS_EXCEPTION_IF_NULL(func);
  315. MS_EXCEPTION_IF_NULL(errcode);
  316. *errcode = kSpecializeSuccess;
  317. auto real_func = dyn_cast<TypedPrimitiveAbstractClosure>(func);
  318. if (real_func != nullptr) {
  319. return BuildValueNode(real_func->prim(), abs);
  320. }
  321. EvaluatorPtr eval;
  322. eval = engine_->GetEvaluatorFor(func);
  323. MS_EXCEPTION_IF_NULL(eval);
  324. AbstractBasePtrList argvals = eval->NormalizeArgs(args);
  325. std::pair<AbstractBasePtrList, AbstractBasePtr> result;
  326. SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result);
  327. if (status != kSpecializeSuccess) {
  328. *errcode = status;
  329. return nullptr;
  330. }
  331. argvals = result.first;
  332. AbstractBasePtr unique_output = result.second;
  333. auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
  334. if (prim_func != nullptr) {
  335. auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), argvals, unique_output);
  336. return BuildValueNode(prim_func->prim(), type_func);
  337. }
  338. if (!eval->isa<BaseFuncGraphEvaluator>()) {
  339. MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString();
  340. }
  341. auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);
  342. if (func->context() != nullptr) {
  343. if (!IsVisible(func_graph_, func->context()->func_graph())) {
  344. MS_LOG(EXCEPTION) << "Func is not visible NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
  345. }
  346. } else {
  347. MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
  348. }
  349. AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
  350. MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
  351. << ", graph: " << context->func_graph()->get_return()->DebugString();
  352. FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
  353. return BuildValueNode(v, abs);
  354. }
  355. AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
  356. auto new_inputs = new_node->inputs();
  357. AnfNodePtr func = new_inputs[0];
  358. AbstractBasePtr fnval = new_inputs[0]->abstract();
  359. AbstractBasePtrList args;
  360. auto backed_fnval = fnval;
  361. if (fnval->isa<PartialAbstractClosure>()) {
  362. auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
  363. backed_fnval = partial_closure->fn();
  364. args = partial_closure->args();
  365. }
  366. std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args),
  367. [](const AnfNodePtr &inp) { return inp->abstract(); });
  368. ScopeGuard scope_guard(new_node->scope());
  369. auto specialized_node = BuildSpecializedNode(func, backed_fnval, args);
  370. auto wrapped_node = specialized_node;
  371. if (fnval->isa<PartialAbstractClosure>()) {
  372. auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
  373. AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)),
  374. specialized_node};
  375. auto anf_node = partial_closure->node();
  376. if (!anf_node->isa<CNode>()) {
  377. MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString();
  378. }
  379. auto cnode = anf_node->cast<CNodePtr>();
  380. if (cnode->size() != partial_closure->args().size() + 2) {
  381. MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
  382. << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
  383. }
  384. for (size_t i = 0; i < partial_closure->args().size(); i++) {
  385. auto old_node = cnode->input(i + 2);
  386. auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]);
  387. if (possibile_value_node != nullptr) {
  388. partial_node_list.push_back(possibile_value_node);
  389. } else {
  390. if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) {
  391. MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString();
  392. }
  393. partial_node_list.push_back(old_node);
  394. }
  395. }
  396. wrapped_node = new_node->func_graph()->NewCNode(partial_node_list);
  397. wrapped_node->set_abstract(partial_closure);
  398. }
  399. return wrapped_node;
  400. }
  401. const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
  402. auto cache_iter = evalcaches_.find(eval);
  403. if (cache_iter == evalcaches_.end()) {
  404. evalcaches_[eval] = eval->cache();
  405. return eval->cache();
  406. }
  407. return cache_iter->second;
  408. }
  409. std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromBroadedArgsVal(
  410. const EvaluatorPtr &eval) {
  411. MS_EXCEPTION_IF_NULL(eval);
  412. std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
  413. AbstractBasePtr ret = nullptr;
  414. AbstractBasePtrList broaded_argvals;
  415. for (auto &argvals_map : *evalcaches_[eval]) {
  416. auto argvals = argvals_map.first;
  417. broaded_argvals.clear();
  418. (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals),
  419. [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
  420. (void)choices.insert(broaded_argvals);
  421. MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals);
  422. }
  423. if (1 == choices.size()) {
  424. ConfigPtrList args_conf_list;
  425. (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list),
  426. [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
  427. // if broaden return null
  428. ret = eval->Run(engine_, args_conf_list, nullptr);
  429. EvaluatorCacheMapPtr real = std::make_shared<EvaluatorCacheMap>();
  430. (*real)[broaded_argvals] = ret;
  431. evalcaches_[eval] = real;
  432. return std::make_pair(broaded_argvals, ret);
  433. } else {
  434. MS_LOG(DEBUG) << "Choices.size: " << choices.size();
  435. return std::make_pair(AbstractBasePtrList(), nullptr);
  436. }
  437. }
  438. void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
  439. MS_EXCEPTION_IF_NULL(new_node);
  440. if (specializer_->seen().count(new_node) > 0) {
  441. return;
  442. }
  443. specializer_->AddSeen(new_node);
  444. auto new_inputs = new_node->inputs();
  445. if (new_inputs.empty()) {
  446. MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
  447. }
  448. AnfNodePtr func = new_inputs[0];
  449. MS_EXCEPTION_IF_NULL(func);
  450. // First element is func so arg start from 1
  451. std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end());
  452. // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
  453. while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
  454. std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
  455. // First element is partial, second is func so arg is start from 2
  456. (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
  457. func = inputs[1];
  458. new_inputs = args;
  459. (void)new_inputs.insert(new_inputs.begin(), func);
  460. }
  461. AbstractBasePtrList argvals;
  462. MS_EXCEPTION_IF_NULL(new_inputs[0]);
  463. AbstractBasePtr fnval = new_inputs[0]->abstract();
  464. MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", "
  465. << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString();
  466. // First element is func so function arguments start from 1
  467. for (size_t i = 1; i < new_inputs.size(); ++i) {
  468. argvals.push_back(new_inputs[i]->abstract());
  469. MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", "
  470. << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
  471. }
  472. if (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) {
  473. auto wrapped_node = BuildSpecializedParameterNode(new_node);
  474. new_inputs[0] = wrapped_node;
  475. }
  476. if (CanSpecializeNode(func)) {
  477. new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
  478. }
  479. for (size_t i = 0; i < argvals.size();) {
  480. size_t next = i + 1;
  481. if (CanSpecializeNode(args[i])) {
  482. new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
  483. }
  484. i = next;
  485. }
  486. new_node->set_inputs(new_inputs);
  487. }
  488. namespace {
  489. void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) {
  490. MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
  491. int i = 0;
  492. for (const auto &item : evaluator_cache_map) {
  493. MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first;
  494. }
  495. }
  496. bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
  497. if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
  498. MS_LOG(DEBUG) << "High order primitive return POLY.";
  499. return true;
  500. }
  501. if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
  502. auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
  503. auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
  504. if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
  505. auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
  506. if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
  507. MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
  508. return true;
  509. }
  510. }
  511. }
  512. return false;
  513. }
  514. } // end anonymous namespace
  515. SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
  516. const AbstractBasePtrList &argvals,
  517. std::pair<AbstractBasePtrList, AbstractBasePtr> *result) {
  518. MS_EXCEPTION_IF_NULL(func);
  519. MS_EXCEPTION_IF_NULL(eval);
  520. MS_EXCEPTION_IF_NULL(result);
  521. EvaluatorCacheMap evaluator_cache_map = *eval->cache();
  522. if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
  523. *result = std::make_pair(argvals, evaluator_cache_map[argvals]);
  524. return kSpecializeSuccess;
  525. }
  526. DumpEvaluatorCache(evaluator_cache_map, argvals);
  527. const EvaluatorCacheMapPtr &choices = GetEvalCache(eval);
  528. MS_EXCEPTION_IF_NULL(choices);
  529. if (choices->count(argvals)) {
  530. *result = std::make_pair(argvals, (*choices)[argvals]);
  531. return kSpecializeSuccess;
  532. } else if (choices->size() == 1) {
  533. MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
  534. *result = std::make_pair(choices->begin()->first, choices->begin()->second);
  535. return kSpecializeSuccess;
  536. } else if (choices->empty()) {
  537. MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
  538. return kSpecializeFindUniqueArgvalDead;
  539. } else {
  540. if (IsPolyFunc(func, argvals)) {
  541. return kSpecializeFindUniqueArgvalPoly;
  542. }
  543. MS_LOG(DEBUG) << "Try to find generalized argvals.";
  544. *result = BuildFromBroadedArgsVal(eval);
  545. if (!result->first.empty()) {
  546. return kSpecializeSuccess;
  547. }
  548. MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism.";
  549. return kSpecializeFindUniqueArgvalPoly;
  550. }
  551. }
  552. AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) {
  553. MS_EXCEPTION_IF_NULL(origin_node);
  554. MS_EXCEPTION_IF_NULL(ival);
  555. AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
  556. if (abs != nullptr) {
  557. // Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction.
  558. if (abs->isa<AbstractFuncUnion>()) {
  559. return nullptr;
  560. }
  561. ValuePtr value = nullptr;
  562. if (abs->isa<PrimitiveAbstractClosure>()) {
  563. auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
  564. value = real_fn->prim();
  565. } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
  566. auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
  567. value = real_fn->meta_func_graph();
  568. } else if (abs->isa<FuncGraphAbstractClosure>()) {
  569. auto real_fn = dyn_cast<FuncGraphAbstractClosure>(abs);
  570. value = real_fn->func_graph();
  571. } else {
  572. return nullptr;
  573. }
  574. if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
  575. (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
  576. return BuildValueNode(value, ival);
  577. } else {
  578. return nullptr;
  579. }
  580. } else {
  581. ValuePtr val = ival->BuildValue();
  582. if (val->isa<AnyValue>()) {
  583. return nullptr;
  584. }
  585. // keep primitive 'depend' not to be optimized
  586. if (IsPrimitiveCNode(origin_node, prim::kPrimDepend)) {
  587. return nullptr;
  588. }
  589. return BuildValueNode(val, ival);
  590. }
  591. }
  592. AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
  593. return engine_->MakeConfig(node, context_);
  594. }
  595. } // namespace abstract
  596. } // namespace mindspore