You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

static_analysis.cc 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/static_analysis/static_analysis.h"
  19. #include <algorithm>
  20. #include "pipeline/static_analysis/utils.h"
  21. #include "pipeline/static_analysis/prim.h"
  22. #include "operator/ops.h"
  23. #include "utils/symbolic.h"
  24. #include "ir/meta_tensor.h"
  25. #include "ir/func_graph_cloner.h"
  26. #include "./common.h"
  27. #include "pipeline/parse/data_converter.h"
  28. #include "debug/draw.h"
  29. #include "pipeline/static_analysis/evaluator.h"
  30. #include "debug/trace.h"
  31. namespace mindspore {
  32. namespace abstract {
  33. bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
  34. if (dyn_cast<AbstractScalar>(arg_spec)) {
  35. auto v = arg_spec->GetValueTrack();
  36. if (v->isa<SymbolicKeyInstance>()) {
  37. return true;
  38. } else {
  39. return false;
  40. }
  41. } else {
  42. return false;
  43. }
  44. }
  45. AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
  46. if (dyn_cast<AbstractScalar>(arg1) && dyn_cast<AbstractScalar>(arg2)) {
  47. return arg1->Join(arg2);
  48. }
  49. return nullptr;
  50. }
  51. void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
  52. MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
  53. << ", Context: " << conf->context()->ToString() << ", Value: " << arg->ToString()
  54. << ", Pointer: " << arg.get();
  55. cache_[conf] = arg;
  56. // Set intermediate abstract value.
  57. if (IsIntermediateAbstract(arg)) {
  58. if (conf->node()->intermediate_abstract() == nullptr) {
  59. conf->node()->set_intermediate_abstract(arg);
  60. MS_LOG(DEBUG) << "Set intermediate abstract: " << arg->ToString();
  61. } else {
  62. auto old_spec = conf->node()->intermediate_abstract();
  63. auto joined_spec = IntermediateJoin(arg, old_spec);
  64. conf->node()->set_intermediate_abstract(joined_spec);
  65. MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
  66. << arg->ToString() << "\njoined_spec:\t"
  67. << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
  68. }
  69. }
  70. }
  71. AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
  72. auto value = cache_.find(conf);
  73. if (value == cache_.end()) {
  74. return nullptr;
  75. }
  76. return value->second;
  77. }
  78. std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const {
  79. MS_EXCEPTION_IF_NULL(conf);
  80. MS_EXCEPTION_IF_NULL(conf->node());
  81. std::size_t hash_value = hash_combine(conf->node()->hash(), conf->context()->hash());
  82. if (conf->context() != nullptr && conf->context()->func_graph() != nullptr) {
  83. MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString()
  84. << ", Graph: " << conf->context()->func_graph()->ToString() << " ### , hash value: " << hash_value;
  85. } else {
  86. MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() << " ### , hash value: " << hash_value;
  87. }
  88. return hash_value;
  89. }
  90. bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const {
  91. if (lhs == nullptr || rhs == nullptr) {
  92. return false;
  93. }
  94. if (lhs == rhs) {
  95. return true;
  96. }
  97. return (*lhs == *rhs);
  98. }
  99. AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) {
  100. ConfigPtrList args_conf_list;
  101. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
  102. [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
  103. MS_EXCEPTION_IF_NULL(func_graph_manager_);
  104. func_graph_manager_->AddFuncGraph(func_graph);
  105. AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
  106. // Running the analyzer.
  107. AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
  108. MS_EXCEPTION_IF_NULL(root_context);
  109. MS_EXCEPTION_IF_NULL(root_context->func_graph());
  110. AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context);
  111. MS_EXCEPTION_IF_NULL(func_graph);
  112. MS_LOG(INFO) << "" << func_graph->ToString() << ": Run finished.";
  113. AnalysisResult result;
  114. MS_EXCEPTION_IF_NULL(output_conf);
  115. result.inferred = output_conf->GetEvaluatedValue();
  116. result.context = root_context;
  117. return result;
  118. }
  119. AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
  120. const ConfigPtrList &args_conf_list) {
  121. std::shared_ptr<FuncGraphEvaluator> eval = std::make_shared<FuncGraphEvaluator>(func_graph, context);
  122. (void)eval->Run(shared_from_this(), args_conf_list, nullptr);
  123. return eval->graph_context();
  124. }
  125. AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
  126. MS_EXCEPTION_IF_NULL(conf);
  127. auto value = cache_.GetValue(conf);
  128. if (value != nullptr) {
  129. MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value.get() << ", "
  130. << value->ToString();
  131. return value;
  132. }
  133. MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
  134. value = Eval(conf);
  135. if (value == nullptr) {
  136. MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr";
  137. }
  138. cache_.set_value(conf, value);
  139. return value;
  140. }
  141. AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
  142. MS_EXCEPTION_IF_NULL(conf);
  143. AnfNodePtr node = conf->node();
  144. AbstractBasePtr ret_abstract = nullptr;
  145. #ifdef DEBUG
  146. compute_conf_stack_.push_back(node);
  147. std::ostringstream buffer;
  148. buffer << "Compute Config Begin:";
  149. for (auto iter : compute_conf_stack_) {
  150. buffer << " -> " << iter->DebugString();
  151. }
  152. MS_LOG(DEBUG) << "" << buffer.str();
  153. #endif
  154. MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString();
  155. MS_EXCEPTION_IF_NULL(node);
  156. if (node->abstract() != nullptr) {
  157. MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
  158. ret_abstract = node->abstract();
  159. } else if (node->isa<ValueNode>()) {
  160. auto value_node = node->cast<ValueNodePtr>();
  161. ret_abstract = EvalValueNode(value_node, conf);
  162. } else if (node->isa<CNode>()) {
  163. auto cnode = node->cast<CNodePtr>();
  164. trace::TraceInferCNodeEnter(conf);
  165. ret_abstract = InferCNode(cnode, conf);
  166. trace::TraceInferCNodeLeave();
  167. } else {
  168. MS_LOG(EXCEPTION) << "Illegal AnfNode for inferring, " << node->DebugString()
  169. << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  170. }
  171. #ifdef DEBUG
  172. compute_conf_stack_.pop_back();
  173. if (ret_abstract == nullptr) {
  174. MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
  175. << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  176. }
  177. #endif
  178. MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << ret_abstract->ToString();
  179. return ret_abstract;
  180. }
  181. AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
  182. MS_EXCEPTION_IF_NULL(conf);
  183. MS_EXCEPTION_IF_NULL(value_node);
  184. return ToAbstract(value_node->value(), conf->context(), conf);
  185. }
  186. AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
  187. MS_EXCEPTION_IF_NULL(conf);
  188. MS_EXCEPTION_IF_NULL(cnode);
  189. auto &inputs = cnode->inputs();
  190. if (inputs.empty()) {
  191. MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString();
  192. }
  193. AnfNodePtr func_node = inputs[0];
  194. MS_EXCEPTION_IF_NULL(func_node);
  195. MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
  196. AnalysisContextPtr context = conf->context();
  197. AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
  198. MS_EXCEPTION_IF_NULL(func_conf);
  199. // Keep it in a local variable, otherwise smart pointer will free it.
  200. AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue();
  201. if (maybe_func == nullptr) {
  202. MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
  203. << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
  204. }
  205. AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
  206. if (func == nullptr) {
  207. MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString()
  208. << ", func_conf: " << func_conf->ToString()
  209. << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
  210. }
  211. ConfigPtrList args_conf_list;
  212. // ignore the first node which is function name
  213. for (std::size_t i = 1; i < inputs.size(); i++) {
  214. const AnfNodePtr &node = inputs[i];
  215. args_conf_list.push_back(MakeConfig(node, context));
  216. MS_LOG(DEBUG) << "Current CNode args_conf_list[" << i << "] node: " << node->DebugString();
  217. }
  218. std::vector<EvaluatorPtr> infs;
  219. auto build_evaluator = [this, &infs, &cnode](const AbstractFuncAtomPtr &poss) {
  220. auto evaluator = this->GetEvaluatorFor(poss);
  221. evaluator->set_bound_node(cnode);
  222. infs.push_back(evaluator);
  223. };
  224. func->Visit(build_evaluator);
  225. return ExecuteEvaluators(infs, conf, args_conf_list);
  226. }
  227. AbstractBasePtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
  228. ConfigPtrList args_conf_list;
  229. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
  230. [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
  231. std::vector<EvaluatorPtr> infs;
  232. MS_EXCEPTION_IF_NULL(func);
  233. auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) {
  234. auto evaluator = this->GetEvaluatorFor(poss);
  235. infs.push_back(evaluator);
  236. };
  237. func->Visit(build_evaluator);
  238. return ExecuteEvaluators(infs, nullptr, args_conf_list);
  239. }
  240. void AnalysisEngine::ClearEvaluatorCache() {
  241. for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : constructors_) {
  242. EvaluatorPtr evaluator = element.second;
  243. MS_EXCEPTION_IF_NULL(evaluator);
  244. MS_EXCEPTION_IF_NULL(evaluator->cache());
  245. evaluator->cache()->clear();
  246. }
  247. }
  248. void AnalysisEngine::Clear() {
  249. cache_.Clear();
  250. anfnode_config_map_.clear();
  251. eval_trace_.clear();
  252. constructors_.clear();
  253. }
  254. namespace {
  255. EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
  256. // Custom Primitive with python infer_shape, infer_type
  257. EvaluatorPtr evaluator = nullptr;
  258. MS_EXCEPTION_IF_NULL(prim);
  259. if (prim->isa<prim::DoSignaturePrimitive>()) {
  260. evaluator = std::make_shared<DoSignatureEvaluator>(prim);
  261. return evaluator;
  262. }
  263. if (prim->isa<prim::UnpackGraphPrimitive>()) {
  264. evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
  265. return evaluator;
  266. }
  267. if (prim->HasPyEvaluator()) {
  268. auto prim_py = dyn_cast<PrimitivePy>(prim);
  269. if (prim_py != nullptr) {
  270. return std::make_shared<PythonPrimEvaluator>(prim_py);
  271. }
  272. MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
  273. }
  274. if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
  275. if (engine == nullptr) {
  276. (void)GetPrimEvaluatorConstructors();
  277. }
  278. // If a primitive may have attr, try to create a new evaluator.
  279. StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
  280. if (eval_impl != nullptr) {
  281. return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
  282. }
  283. }
  284. if (engine == nullptr) {
  285. // If engine is nullptr, get constructor from default.
  286. const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
  287. auto iter = prim_evaluator_map.find(prim);
  288. if (iter != prim_evaluator_map.end()) {
  289. evaluator = iter->second;
  290. }
  291. } else {
  292. // If engine is given, get constructor from engine resource.
  293. const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
  294. auto iter = prim_evaluator_map.find(prim);
  295. if (iter != prim_evaluator_map.end()) {
  296. evaluator = iter->second;
  297. }
  298. }
  299. if (evaluator == nullptr) {
  300. MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
  301. }
  302. return evaluator;
  303. }
  304. } // namespace
  305. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
  306. auto inf_pair = constructors_.find(func);
  307. if (inf_pair != constructors_.end()) {
  308. return inf_pair->second;
  309. }
  310. MS_EXCEPTION_IF_NULL(func);
  311. auto primitive = func->prim();
  312. auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
  313. constructors_[func] = evaluator;
  314. return evaluator;
  315. }
  316. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
  317. auto inf_pair = constructors_.find(func);
  318. if (inf_pair != constructors_.end()) {
  319. return inf_pair->second;
  320. }
  321. MS_EXCEPTION_IF_NULL(func);
  322. std::shared_ptr<FuncGraphEvaluator> func_graph_evaluator =
  323. std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
  324. constructors_[func] = func_graph_evaluator;
  325. return func_graph_evaluator;
  326. }
  327. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
  328. auto inf_pair = constructors_.find(func);
  329. if (inf_pair != constructors_.end()) {
  330. return inf_pair->second;
  331. }
  332. MS_EXCEPTION_IF_NULL(func);
  333. std::shared_ptr<MetaFuncGraphEvaluator> evaluator =
  334. std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->context(), func->GetScope());
  335. constructors_[func] = evaluator;
  336. return evaluator;
  337. }
  338. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &func) {
  339. MS_EXCEPTION_IF_NULL(func);
  340. AbstractFunctionPtr func_orig = func->fn();
  341. EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  342. auto jevaluator = std::make_shared<JEvaluator>(evaluator_orig, func_orig);
  343. return jevaluator;
  344. }
  345. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func) {
  346. MS_EXCEPTION_IF_NULL(func);
  347. std::shared_ptr<VirtualEvaluator> virtual_evaluator =
  348. std::make_shared<VirtualEvaluator>(func->args_spec_list(), func->output());
  349. return virtual_evaluator;
  350. }
  351. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &func) {
  352. MS_EXCEPTION_IF_NULL(func);
  353. AbstractFunctionPtr func_orig = func->fn();
  354. EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  355. std::shared_ptr<PartialAppEvaluator> partial_evaluator =
  356. std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args());
  357. return partial_evaluator;
  358. }
  359. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &) {
  360. MS_LOG(EXCEPTION) << "Should not be called ";
  361. }
  362. // Forward to specific subclass of FunctionWrapper.
  363. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
  364. MS_EXCEPTION_IF_NULL(func);
  365. EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this());
  366. return evaluator;
  367. }
  368. EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
  369. MS_LOG(DEBUG) << "The func value: " << func->ToString();
  370. if (func->tracking_id() != nullptr) {
  371. MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
  372. }
  373. MS_EXCEPTION_IF_NULL(func);
  374. if (func->tracking_id() == nullptr) {
  375. EvaluatorPtr evaluator = _GetEvaluatorFor(func);
  376. return evaluator;
  377. }
  378. auto inf_pair = constructors_.find(func);
  379. if (inf_pair != constructors_.end()) {
  380. return inf_pair->second;
  381. }
  382. AbstractFunctionPtr func_generic = func->Copy();
  383. func_generic->set_tracking_id(nullptr);
  384. EvaluatorPtr eval = _GetEvaluatorFor(func_generic);
  385. auto tracked_eval = std::make_shared<TrackedEvaluator>(eval);
  386. constructors_[func] = tracked_eval;
  387. return tracked_eval;
  388. }
  389. AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
  390. const AnfNodeConfigPtr &out_conf,
  391. const ConfigPtrList &args_conf_list) {
  392. if (evaluators.size() == 1) {
  393. EvaluatorPtr eval = evaluators[0];
  394. MS_EXCEPTION_IF_NULL(eval);
  395. return eval->Run(shared_from_this(), args_conf_list, out_conf);
  396. }
  397. return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
  398. }
  399. AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
  400. const AnfNodeConfigPtr &out_conf,
  401. const ConfigPtrList &args_conf_list) {
  402. AbstractBasePtrList out_specs;
  403. AbstractBasePtrList args_spec_list;
  404. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  405. [](const ConfigPtr &conf) -> AbstractBasePtr {
  406. MS_EXCEPTION_IF_NULL(conf);
  407. return conf->GetEvaluatedValue();
  408. });
  409. for (auto eval : evaluators) {
  410. auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
  411. if (fg_eval) {
  412. auto undetermined_fgs = fg_eval->func_graph()->recursive_graphs();
  413. if (undetermined_fgs) {
  414. for (auto undetermined_fg : *undetermined_fgs) {
  415. MS_LOG(DEBUG) << "Set graph undetermined: " << undetermined_fg->ToString();
  416. // As the current evaluator has multiple possibles, all the func_graphs which
  417. // are recursive with the current func_graph are undetermined in control flow.
  418. undetermined_fg->set_flags(kFuncGraphFlagUndetermined, true);
  419. }
  420. }
  421. }
  422. auto current_inf = std::make_pair(eval, args_spec_list);
  423. // If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring.
  424. auto it = std::find(eval_trace_.begin(), eval_trace_.end(), current_inf);
  425. if (it == eval_trace_.end()) {
  426. eval_trace_.push_back(current_inf);
  427. MS_EXCEPTION_IF_NULL(eval);
  428. auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf);
  429. MS_EXCEPTION_IF_NULL(out_spec);
  430. MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString();
  431. out_specs.push_back(out_spec);
  432. eval_trace_.pop_back();
  433. }
  434. }
  435. if (out_specs.size() == 0) {
  436. MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
  437. }
  438. if (out_specs.size() == 1) {
  439. MS_EXCEPTION_IF_NULL(out_specs[0]);
  440. // If only one result derived, then broaden it to avoid wrong constant propagation.
  441. return out_specs[0]->Broaden();
  442. }
  443. auto joined_spec = AbstractJoin(out_specs);
  444. MS_EXCEPTION_IF_NULL(joined_spec);
  445. MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
  446. return joined_spec;
  447. }
  448. AbstractBasePtr AnfNodeConfig::GetEvaluatedValue() {
  449. AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
  450. return engine_.lock()->GetEvaluatedValue(self);
  451. }
  452. AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
  453. if (value->isa<FuncGraph>()) {
  454. auto func_graph = value->cast<FuncGraphPtr>();
  455. return func_graph->MakeAbstractClosure(context);
  456. }
  457. AnfNodePtr anf_node = nullptr;
  458. if (conf != nullptr) {
  459. anf_node = conf->node();
  460. }
  461. if (value->isa<MetaFuncGraph>()) {
  462. auto meta_func_graph = value->cast<MetaFuncGraphPtr>();
  463. return meta_func_graph->MakeAbstractClosure(anf_node);
  464. }
  465. if (value->isa<Primitive>()) {
  466. auto prim = value->cast<PrimitivePtr>();
  467. return prim->ToPrimAbstract(anf_node);
  468. }
  469. return value->ToAbstract();
  470. }
  471. AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
  472. AbstractBasePtr a = ToAbstract(value, nullptr, nullptr);
  473. if (broaden) {
  474. a = a->Broaden();
  475. }
  476. return a;
  477. }
  478. AbstractBasePtr InferOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
  479. auto evaluator = GetPrimEvaluator(primitive, nullptr);
  480. MS_EXCEPTION_IF_NULL(evaluator);
  481. if (!evaluator->isa<TrivialPrimEvaluator>()) {
  482. MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but "
  483. << evaluator->ToString();
  484. }
  485. auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
  486. auto res_spec = trivial_evaluator->EvalPrim(nullptr, arg_specs);
  487. return res_spec;
  488. }
  489. } // namespace abstract
  490. } // namespace mindspore