You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

static_analysis.cc 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/static_analysis/static_analysis.h"
  19. #include <algorithm>
  20. #include "pipeline/static_analysis/utils.h"
  21. #include "pipeline/static_analysis/prim.h"
  22. #include "operator/ops.h"
  23. #include "utils/symbolic.h"
  24. #include "ir/meta_tensor.h"
  25. #include "ir/func_graph_cloner.h"
  26. #include "./common.h"
  27. #include "pipeline/parse/data_converter.h"
  28. #include "debug/draw.h"
  29. #include "pipeline/static_analysis/evaluator.h"
  30. #include "debug/trace.h"
  31. namespace mindspore {
  32. namespace abstract {
  33. bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
  34. if (dyn_cast<AbstractScalar>(arg_spec)) {
  35. auto v = arg_spec->GetValueTrack();
  36. if (v->isa<SymbolicKeyInstance>()) {
  37. return true;
  38. } else {
  39. return false;
  40. }
  41. } else {
  42. return false;
  43. }
  44. }
  45. AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
  46. if (dyn_cast<AbstractScalar>(arg1) && dyn_cast<AbstractScalar>(arg2)) {
  47. return arg1->Join(arg2);
  48. }
  49. return nullptr;
  50. }
  51. void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
  52. MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
  53. << ", Context: " << conf->context()->ToString() << ", Value: " << arg->ToString()
  54. << ", Pointer: " << arg.get();
  55. cache_[conf] = arg;
  56. // Set intermediate abstract value.
  57. if (IsIntermediateAbstract(arg)) {
  58. if (conf->node()->intermediate_abstract() == nullptr) {
  59. conf->node()->set_intermediate_abstract(arg);
  60. MS_LOG(DEBUG) << "Set intermediate abstract: " << arg->ToString();
  61. } else {
  62. auto old_spec = conf->node()->intermediate_abstract();
  63. auto joined_spec = IntermediateJoin(arg, old_spec);
  64. conf->node()->set_intermediate_abstract(joined_spec);
  65. MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
  66. << arg->ToString() << "\njoined_spec:\t"
  67. << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
  68. }
  69. }
  70. }
  71. AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
  72. auto value = cache_.find(conf);
  73. if (value == cache_.end()) {
  74. return nullptr;
  75. }
  76. return value->second;
  77. }
  78. std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const {
  79. MS_EXCEPTION_IF_NULL(conf);
  80. MS_EXCEPTION_IF_NULL(conf->node());
  81. std::size_t hash_value = hash_combine(conf->node()->hash(), conf->context()->hash());
  82. if (conf->context() != nullptr && conf->context()->func_graph() != nullptr) {
  83. MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString()
  84. << ", Graph: " << conf->context()->func_graph()->ToString() << " ### , hash value: " << hash_value;
  85. } else {
  86. MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() << " ### , hash value: " << hash_value;
  87. }
  88. return hash_value;
  89. }
  90. bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const {
  91. if (lhs == nullptr || rhs == nullptr) {
  92. return false;
  93. }
  94. if (lhs == rhs) {
  95. return true;
  96. }
  97. return (*lhs == *rhs);
  98. }
  99. AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) {
  100. ConfigPtrList args_conf_list;
  101. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
  102. [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
  103. MS_EXCEPTION_IF_NULL(func_graph_manager_);
  104. func_graph_manager_->AddFuncGraph(func_graph);
  105. AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
  106. // Running the analyzer.
  107. AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
  108. MS_EXCEPTION_IF_NULL(root_context);
  109. MS_EXCEPTION_IF_NULL(root_context->func_graph());
  110. AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context);
  111. MS_EXCEPTION_IF_NULL(func_graph);
  112. MS_LOG(INFO) << func_graph->ToString() << ": Run finished.";
  113. AnalysisResult result;
  114. MS_EXCEPTION_IF_NULL(output_conf);
  115. result.inferred = output_conf->GetEvaluatedValue();
  116. result.context = root_context;
  117. return result;
  118. }
  119. AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
  120. const ConfigPtrList &args_conf_list) {
  121. std::shared_ptr<FuncGraphEvaluator> eval = std::make_shared<FuncGraphEvaluator>(func_graph, context);
  122. (void)eval->Run(shared_from_this(), args_conf_list, nullptr);
  123. return eval->graph_context();
  124. }
  125. AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
  126. MS_EXCEPTION_IF_NULL(conf);
  127. auto value = cache_.GetValue(conf);
  128. if (value != nullptr) {
  129. MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value.get() << ", "
  130. << value->ToString();
  131. return value;
  132. }
  133. MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
  134. value = Eval(conf);
  135. if (value == nullptr) {
  136. MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr";
  137. }
  138. cache_.set_value(conf, value);
  139. return value;
  140. }
  141. AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
  142. MS_EXCEPTION_IF_NULL(conf);
  143. AnfNodePtr node = conf->node();
  144. AbstractBasePtr ret_abstract = nullptr;
  145. #ifdef DEBUG
  146. compute_conf_stack_.push_back(node);
  147. std::ostringstream buffer;
  148. buffer << "Compute Config Begin:";
  149. for (auto iter : compute_conf_stack_) {
  150. buffer << " -> " << iter->DebugString();
  151. }
  152. MS_LOG(DEBUG) << buffer.str();
  153. #endif
  154. MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString();
  155. MS_EXCEPTION_IF_NULL(node);
  156. if (node->abstract() != nullptr) {
  157. MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
  158. ret_abstract = node->abstract();
  159. } else if (node->isa<ValueNode>()) {
  160. auto value_node = node->cast<ValueNodePtr>();
  161. ret_abstract = EvalValueNode(value_node, conf);
  162. } else if (node->isa<CNode>()) {
  163. auto cnode = node->cast<CNodePtr>();
  164. trace::TraceInferCNodeEnter(conf);
  165. ret_abstract = InferCNode(cnode, conf);
  166. trace::TraceInferCNodeLeave();
  167. } else {
  168. MS_LOG(EXCEPTION) << "Illegal AnfNode for inferring, " << node->DebugString()
  169. << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  170. }
  171. #ifdef DEBUG
  172. compute_conf_stack_.pop_back();
  173. if (ret_abstract == nullptr) {
  174. MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
  175. << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  176. }
  177. #endif
  178. MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << ret_abstract->ToString();
  179. return ret_abstract;
  180. }
  181. AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
  182. MS_EXCEPTION_IF_NULL(conf);
  183. MS_EXCEPTION_IF_NULL(value_node);
  184. return ToAbstract(value_node->value(), conf->context(), conf);
  185. }
  186. AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
  187. MS_EXCEPTION_IF_NULL(conf);
  188. MS_EXCEPTION_IF_NULL(cnode);
  189. auto &inputs = cnode->inputs();
  190. if (inputs.empty()) {
  191. MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString();
  192. }
  193. AnfNodePtr func_node = inputs[0];
  194. MS_EXCEPTION_IF_NULL(func_node);
  195. MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
  196. AnalysisContextPtr context = conf->context();
  197. AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
  198. MS_EXCEPTION_IF_NULL(func_conf);
  199. // Keep it in a local variable, otherwise smart pointer will free it.
  200. AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue();
  201. if (maybe_func == nullptr) {
  202. MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
  203. << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
  204. }
  205. AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
  206. if (func == nullptr) {
  207. MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString()
  208. << ", func_conf: " << func_conf->ToString()
  209. << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
  210. }
  211. ConfigPtrList args_conf_list;
  212. // ignore the first node which is function name
  213. for (std::size_t i = 1; i < inputs.size(); i++) {
  214. const AnfNodePtr &node = inputs[i];
  215. args_conf_list.push_back(MakeConfig(node, context));
  216. MS_LOG(DEBUG) << "Current CNode args_conf_list[" << i << "] node: " << node->DebugString();
  217. }
  218. std::vector<EvaluatorPtr> infs;
  219. auto build_evaluator = [this, &infs, &cnode](const AbstractFuncAtomPtr &poss) {
  220. auto evaluator = this->GetEvaluatorFor(poss);
  221. evaluator->set_bound_node(cnode);
  222. infs.push_back(evaluator);
  223. };
  224. func->Visit(build_evaluator);
  225. return ExecuteEvaluators(infs, conf, args_conf_list);
  226. }
  227. AbstractBasePtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
  228. ConfigPtrList args_conf_list;
  229. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
  230. [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
  231. std::vector<EvaluatorPtr> infs;
  232. MS_EXCEPTION_IF_NULL(func);
  233. auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) {
  234. auto evaluator = this->GetEvaluatorFor(poss);
  235. infs.push_back(evaluator);
  236. };
  237. func->Visit(build_evaluator);
  238. return ExecuteEvaluators(infs, nullptr, args_conf_list);
  239. }
  240. void AnalysisEngine::ClearEvaluatorCache() {
  241. for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : constructors_) {
  242. EvaluatorPtr evaluator = element.second;
  243. MS_EXCEPTION_IF_NULL(evaluator);
  244. MS_EXCEPTION_IF_NULL(evaluator->cache());
  245. evaluator->cache()->clear();
  246. }
  247. for (auto &element : prim_constructors_) {
  248. EvaluatorPtr evaluator = element.second;
  249. MS_EXCEPTION_IF_NULL(evaluator);
  250. MS_EXCEPTION_IF_NULL(evaluator->cache());
  251. evaluator->cache()->clear();
  252. }
  253. for (auto &element : prim_py_evaluators_) {
  254. EvaluatorPtr evaluator = element.second;
  255. MS_EXCEPTION_IF_NULL(evaluator);
  256. MS_EXCEPTION_IF_NULL(evaluator->cache());
  257. evaluator->cache()->clear();
  258. }
  259. }
  260. void AnalysisEngine::Clear() {
  261. cache_.Clear();
  262. anfnode_config_map_.clear();
  263. eval_trace_.clear();
  264. constructors_.clear();
  265. }
  266. namespace {
  267. EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
  268. // Custom Primitive with python infer_shape, infer_type
  269. EvaluatorPtr evaluator = nullptr;
  270. MS_EXCEPTION_IF_NULL(prim);
  271. if (prim->isa<prim::DoSignaturePrimitive>()) {
  272. evaluator = std::make_shared<DoSignatureEvaluator>(prim);
  273. return evaluator;
  274. }
  275. if (prim->isa<prim::UnpackGraphPrimitive>()) {
  276. evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
  277. return evaluator;
  278. }
  279. if (prim->HasPyEvaluator()) {
  280. auto prim_py = dyn_cast<PrimitivePy>(prim);
  281. if (prim_py != nullptr) {
  282. if (engine == nullptr) {
  283. return std::make_shared<PythonPrimEvaluator>(prim_py);
  284. }
  285. const auto &iter = engine->prim_py_evaluators_.find(prim_py);
  286. if (iter != engine->prim_py_evaluators_.end()) {
  287. return iter->second;
  288. }
  289. evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
  290. engine->prim_py_evaluators_[prim_py] = evaluator;
  291. return evaluator;
  292. }
  293. MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
  294. }
  295. if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
  296. if (engine == nullptr) {
  297. (void)GetPrimEvaluatorConstructors();
  298. }
  299. // If a primitive may have attr, try to create a new evaluator.
  300. StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
  301. if (eval_impl != nullptr) {
  302. return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
  303. }
  304. }
  305. if (engine == nullptr) {
  306. // If engine is nullptr, get constructor from default.
  307. const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
  308. auto iter = prim_evaluator_map.find(prim);
  309. if (iter != prim_evaluator_map.end()) {
  310. evaluator = iter->second;
  311. }
  312. } else {
  313. // If engine is given, get constructor from engine resource.
  314. const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
  315. auto iter = prim_evaluator_map.find(prim);
  316. if (iter != prim_evaluator_map.end()) {
  317. evaluator = iter->second;
  318. }
  319. }
  320. if (evaluator == nullptr) {
  321. MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
  322. }
  323. return evaluator;
  324. }
  325. } // namespace
  326. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
  327. auto inf_pair = constructors_.find(func);
  328. if (inf_pair != constructors_.end()) {
  329. return inf_pair->second;
  330. }
  331. MS_EXCEPTION_IF_NULL(func);
  332. auto primitive = func->prim();
  333. auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
  334. constructors_[func] = evaluator;
  335. return evaluator;
  336. }
  337. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
  338. auto inf_pair = constructors_.find(func);
  339. if (inf_pair != constructors_.end()) {
  340. return inf_pair->second;
  341. }
  342. MS_EXCEPTION_IF_NULL(func);
  343. std::shared_ptr<FuncGraphEvaluator> func_graph_evaluator =
  344. std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
  345. constructors_[func] = func_graph_evaluator;
  346. return func_graph_evaluator;
  347. }
  348. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
  349. auto inf_pair = constructors_.find(func);
  350. if (inf_pair != constructors_.end()) {
  351. return inf_pair->second;
  352. }
  353. MS_EXCEPTION_IF_NULL(func);
  354. std::shared_ptr<MetaFuncGraphEvaluator> evaluator =
  355. std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->context(), func->GetScope());
  356. constructors_[func] = evaluator;
  357. return evaluator;
  358. }
  359. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &func) {
  360. MS_EXCEPTION_IF_NULL(func);
  361. AbstractFunctionPtr func_orig = func->fn();
  362. EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  363. auto jevaluator = std::make_shared<JEvaluator>(evaluator_orig, func_orig);
  364. return jevaluator;
  365. }
  366. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func) {
  367. MS_EXCEPTION_IF_NULL(func);
  368. std::shared_ptr<VirtualEvaluator> virtual_evaluator =
  369. std::make_shared<VirtualEvaluator>(func->args_spec_list(), func->output());
  370. return virtual_evaluator;
  371. }
  372. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &func) {
  373. MS_EXCEPTION_IF_NULL(func);
  374. AbstractFunctionPtr func_orig = func->fn();
  375. EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  376. std::shared_ptr<PartialAppEvaluator> partial_evaluator =
  377. std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args());
  378. return partial_evaluator;
  379. }
  380. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &) {
  381. MS_LOG(EXCEPTION) << "Should not be called ";
  382. }
  383. // Forward to specific subclass of FunctionWrapper.
  384. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
  385. MS_EXCEPTION_IF_NULL(func);
  386. EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this());
  387. return evaluator;
  388. }
  389. EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
  390. MS_LOG(DEBUG) << "The func value: " << func->ToString();
  391. if (func->tracking_id() != nullptr) {
  392. MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
  393. }
  394. MS_EXCEPTION_IF_NULL(func);
  395. if (func->tracking_id() == nullptr) {
  396. EvaluatorPtr evaluator = _GetEvaluatorFor(func);
  397. return evaluator;
  398. }
  399. auto inf_pair = constructors_.find(func);
  400. if (inf_pair != constructors_.end()) {
  401. return inf_pair->second;
  402. }
  403. AbstractFunctionPtr func_generic = func->Copy();
  404. func_generic->set_tracking_id(nullptr);
  405. EvaluatorPtr eval = _GetEvaluatorFor(func_generic);
  406. auto tracked_eval = std::make_shared<TrackedEvaluator>(eval);
  407. constructors_[func] = tracked_eval;
  408. return tracked_eval;
  409. }
  410. AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
  411. const AnfNodeConfigPtr &out_conf,
  412. const ConfigPtrList &args_conf_list) {
  413. if (evaluators.size() == 1) {
  414. EvaluatorPtr eval = evaluators[0];
  415. MS_EXCEPTION_IF_NULL(eval);
  416. return eval->Run(shared_from_this(), args_conf_list, out_conf);
  417. }
  418. return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
  419. }
  420. AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
  421. const AnfNodeConfigPtr &out_conf,
  422. const ConfigPtrList &args_conf_list) {
  423. AbstractBasePtrList out_specs;
  424. AbstractBasePtrList args_spec_list;
  425. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  426. [](const ConfigPtr &conf) -> AbstractBasePtr {
  427. MS_EXCEPTION_IF_NULL(conf);
  428. return conf->GetEvaluatedValue();
  429. });
  430. for (auto eval : evaluators) {
  431. auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
  432. if (fg_eval) {
  433. auto undetermined_fgs = fg_eval->func_graph()->recursive_graphs();
  434. if (undetermined_fgs) {
  435. for (auto undetermined_fg : *undetermined_fgs) {
  436. MS_LOG(DEBUG) << "Set graph undetermined: " << undetermined_fg->ToString();
  437. // As the current evaluator has multiple possibles, all the func_graphs which
  438. // are recursive with the current func_graph are undetermined in control flow.
  439. undetermined_fg->set_flags(kFuncGraphFlagUndetermined, true);
  440. }
  441. }
  442. }
  443. auto current_inf = std::make_pair(eval, args_spec_list);
  444. // If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring.
  445. auto it = std::find(eval_trace_.begin(), eval_trace_.end(), current_inf);
  446. if (it == eval_trace_.end()) {
  447. eval_trace_.push_back(current_inf);
  448. MS_EXCEPTION_IF_NULL(eval);
  449. auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf);
  450. MS_EXCEPTION_IF_NULL(out_spec);
  451. MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString();
  452. out_specs.push_back(out_spec);
  453. eval_trace_.pop_back();
  454. }
  455. }
  456. if (out_specs.size() == 0) {
  457. MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
  458. }
  459. if (out_specs.size() == 1) {
  460. MS_EXCEPTION_IF_NULL(out_specs[0]);
  461. // If only one result derived, then broaden it to avoid wrong constant propagation.
  462. return out_specs[0]->Broaden();
  463. }
  464. auto joined_spec = AbstractJoin(out_specs);
  465. MS_EXCEPTION_IF_NULL(joined_spec);
  466. MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
  467. return joined_spec;
  468. }
  469. AbstractBasePtr AnfNodeConfig::GetEvaluatedValue() {
  470. AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
  471. return engine_.lock()->GetEvaluatedValue(self);
  472. }
  473. AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
  474. if (value->isa<FuncGraph>()) {
  475. auto func_graph = value->cast<FuncGraphPtr>();
  476. return func_graph->MakeAbstractClosure(context);
  477. }
  478. AnfNodePtr anf_node = nullptr;
  479. if (conf != nullptr) {
  480. anf_node = conf->node();
  481. }
  482. if (value->isa<MetaFuncGraph>()) {
  483. auto meta_func_graph = value->cast<MetaFuncGraphPtr>();
  484. return meta_func_graph->MakeAbstractClosure(anf_node);
  485. }
  486. if (value->isa<Primitive>()) {
  487. auto prim = value->cast<PrimitivePtr>();
  488. return prim->ToPrimAbstract(anf_node);
  489. }
  490. return value->ToAbstract();
  491. }
  492. AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
  493. AbstractBasePtr a = ToAbstract(value, nullptr, nullptr);
  494. if (broaden) {
  495. a = a->Broaden();
  496. }
  497. return a;
  498. }
  499. AbstractBasePtr InferOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
  500. auto evaluator = GetPrimEvaluator(primitive, nullptr);
  501. MS_EXCEPTION_IF_NULL(evaluator);
  502. if (!evaluator->isa<TrivialPrimEvaluator>()) {
  503. MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but "
  504. << evaluator->ToString();
  505. }
  506. auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
  507. auto res_spec = trivial_evaluator->EvalPrim(nullptr, arg_specs);
  508. return res_spec;
  509. }
  510. } // namespace abstract
  511. } // namespace mindspore