You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

static_analysis.cc 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/static_analysis/static_analysis.h"
  19. #include <algorithm>
  20. #include <set>
  21. #include "pipeline/static_analysis/utils.h"
  22. #include "pipeline/static_analysis/prim.h"
  23. #include "operator/ops.h"
  24. #include "utils/symbolic.h"
  25. #include "ir/tensor.h"
  26. #include "ir/func_graph_cloner.h"
  27. #include "./common.h"
  28. #include "pipeline/parse/data_converter.h"
  29. #include "debug/draw.h"
  30. #include "pipeline/static_analysis/evaluator.h"
  31. #include "debug/trace.h"
  32. namespace mindspore {
  33. namespace abstract {
  34. bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
  35. if (dyn_cast<AbstractScalar>(arg_spec)) {
  36. auto v = arg_spec->GetValueTrack();
  37. if (v->isa<SymbolicKeyInstance>()) {
  38. return true;
  39. } else {
  40. return false;
  41. }
  42. } else {
  43. return false;
  44. }
  45. }
  46. AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
  47. if (dyn_cast<AbstractScalar>(arg1) && dyn_cast<AbstractScalar>(arg2)) {
  48. return arg1->Join(arg2);
  49. }
  50. return nullptr;
  51. }
  52. void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
  53. MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
  54. << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString()
  55. << ", Pointer: " << result->abstract().get();
  56. cache_[conf] = result;
  57. // Set intermediate abstract value.
  58. if (IsIntermediateAbstract(result->abstract())) {
  59. if (conf->node()->intermediate_abstract() == nullptr) {
  60. conf->node()->set_intermediate_abstract(result->abstract());
  61. MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
  62. } else {
  63. auto old_spec = conf->node()->intermediate_abstract();
  64. auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
  65. conf->node()->set_intermediate_abstract(joined_spec);
  66. MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
  67. << result->abstract()->ToString() << "\njoined_spec:\t"
  68. << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
  69. }
  70. }
  71. }
  72. EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
  73. auto value = cache_.find(conf);
  74. if (value == cache_.end()) {
  75. return nullptr;
  76. }
  77. return value->second;
  78. }
  79. std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const {
  80. MS_EXCEPTION_IF_NULL(conf);
  81. MS_EXCEPTION_IF_NULL(conf->node());
  82. std::size_t hash_value = conf->node()->hash();
  83. if (!conf->context()->IsDummyContext()) {
  84. hash_value = hash_combine(hash_value, std::hash<AnalysisContext *>{}(conf->context().get()));
  85. }
  86. if (conf->context() != nullptr && conf->context()->func_graph() != nullptr) {
  87. MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString()
  88. << ", Graph: " << conf->context()->func_graph()->ToString() << " ### , hash value: " << hash_value;
  89. } else {
  90. MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() << " ### , hash value: " << hash_value;
  91. }
  92. return hash_value;
  93. }
  94. bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const {
  95. if (lhs == nullptr || rhs == nullptr) {
  96. return false;
  97. }
  98. if (lhs == rhs) {
  99. return true;
  100. }
  101. return (*lhs == *rhs);
  102. }
  103. AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) {
  104. ConfigPtrList args_conf_list;
  105. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
  106. [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
  107. MS_EXCEPTION_IF_NULL(func_graph_manager_);
  108. func_graph_manager_->AddFuncGraph(func_graph);
  109. AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
  110. // Running the analyzer.
  111. AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
  112. MS_EXCEPTION_IF_NULL(root_context);
  113. MS_EXCEPTION_IF_NULL(root_context->func_graph());
  114. AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context);
  115. MS_EXCEPTION_IF_NULL(func_graph);
  116. MS_LOG(INFO) << func_graph->ToString() << ": Run finished.";
  117. AnalysisResult result;
  118. MS_EXCEPTION_IF_NULL(output_conf);
  119. result.inferred = output_conf->GetEvaluatedValue();
  120. result.context = root_context;
  121. return result;
  122. }
  123. AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
  124. const ConfigPtrList &args_conf_list) {
  125. std::shared_ptr<FuncGraphEvaluator> eval = std::make_shared<FuncGraphEvaluator>(func_graph, context);
  126. (void)eval->Run(shared_from_this(), args_conf_list, nullptr);
  127. return eval->graph_context();
  128. }
  129. EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
  130. MS_EXCEPTION_IF_NULL(conf);
  131. auto value = cache_.GetValue(conf);
  132. if (value != nullptr) {
  133. MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get()
  134. << ", " << value->abstract()->ToString();
  135. return value;
  136. }
  137. MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
  138. value = Eval(conf);
  139. if (value == nullptr) {
  140. MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr";
  141. }
  142. cache_.set_value(conf, value);
  143. return value;
  144. }
  145. EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
  146. MS_EXCEPTION_IF_NULL(conf);
  147. AnfNodePtr node = conf->node();
  148. EvalResultPtr eval_result = nullptr;
  149. #ifdef DEBUG
  150. compute_conf_stack_.push_back(node);
  151. std::ostringstream buffer;
  152. buffer << "Compute Config Begin:";
  153. for (auto iter : compute_conf_stack_) {
  154. buffer << " -> " << iter->DebugString();
  155. }
  156. MS_LOG(DEBUG) << buffer.str();
  157. #endif
  158. MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString();
  159. MS_EXCEPTION_IF_NULL(node);
  160. if (node->abstract() != nullptr) {
  161. MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
  162. eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>());
  163. } else if (node->isa<ValueNode>()) {
  164. auto value_node = node->cast<ValueNodePtr>();
  165. eval_result = std::make_shared<EvalResult>(EvalValueNode(value_node, conf), nullptr);
  166. } else if (node->isa<CNode>()) {
  167. auto cnode = node->cast<CNodePtr>();
  168. trace::TraceEvalCNodeEnter(conf);
  169. eval_result = EvalCNode(cnode, conf);
  170. trace::TraceEvalCNodeLeave();
  171. } else {
  172. MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
  173. << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  174. }
  175. #ifdef DEBUG
  176. compute_conf_stack_.pop_back();
  177. if (eval_result == nullptr) {
  178. MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
  179. << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  180. }
  181. #endif
  182. MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
  183. return eval_result;
  184. }
  185. AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
  186. MS_EXCEPTION_IF_NULL(conf);
  187. MS_EXCEPTION_IF_NULL(value_node);
  188. return ToAbstract(value_node->value(), conf->context(), conf);
  189. }
  190. EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
  191. MS_EXCEPTION_IF_NULL(conf);
  192. MS_EXCEPTION_IF_NULL(cnode);
  193. auto &inputs = cnode->inputs();
  194. if (inputs.empty()) {
  195. MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString();
  196. }
  197. AnfNodePtr func_node = inputs[0];
  198. MS_EXCEPTION_IF_NULL(func_node);
  199. MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
  200. AnalysisContextPtr context = conf->context();
  201. AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
  202. MS_EXCEPTION_IF_NULL(func_conf);
  203. // Keep it in a local variable, otherwise smart pointer will free it.
  204. AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract();
  205. if (maybe_func == nullptr) {
  206. MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
  207. << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
  208. }
  209. if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
  210. MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
  211. return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>());
  212. }
  213. AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
  214. if (func == nullptr) {
  215. MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString()
  216. << ", func_conf: " << func_conf->ToString()
  217. << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
  218. }
  219. ConfigPtrList args_conf_list;
  220. // ignore the first node which is function name
  221. for (std::size_t i = 1; i < inputs.size(); i++) {
  222. const AnfNodePtr &node = inputs[i];
  223. args_conf_list.push_back(MakeConfig(node, context));
  224. }
  225. std::vector<EvaluatorPtr> infs;
  226. auto build_evaluator = [this, &infs, &cnode](const AbstractFuncAtomPtr &poss) {
  227. auto evaluator = this->GetEvaluatorFor(poss);
  228. evaluator->set_bound_node(cnode);
  229. infs.push_back(evaluator);
  230. };
  231. func->Visit(build_evaluator);
  232. return ExecuteEvaluators(infs, conf, args_conf_list);
  233. }
  234. EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
  235. ConfigPtrList args_conf_list;
  236. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
  237. [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
  238. std::vector<EvaluatorPtr> infs;
  239. MS_EXCEPTION_IF_NULL(func);
  240. auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) {
  241. auto evaluator = this->GetEvaluatorFor(poss);
  242. infs.push_back(evaluator);
  243. };
  244. func->Visit(build_evaluator);
  245. return ExecuteEvaluators(infs, nullptr, args_conf_list);
  246. }
  247. void AnalysisEngine::ClearEvaluatorCache() {
  248. for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : constructors_) {
  249. EvaluatorPtr evaluator = element.second;
  250. MS_EXCEPTION_IF_NULL(evaluator);
  251. MS_EXCEPTION_IF_NULL(evaluator->cache());
  252. evaluator->cache()->clear();
  253. }
  254. for (auto &element : prim_constructors_) {
  255. EvaluatorPtr evaluator = element.second;
  256. MS_EXCEPTION_IF_NULL(evaluator);
  257. MS_EXCEPTION_IF_NULL(evaluator->cache());
  258. evaluator->cache()->clear();
  259. }
  260. for (auto &element : prim_py_evaluators_) {
  261. EvaluatorPtr evaluator = element.second;
  262. MS_EXCEPTION_IF_NULL(evaluator);
  263. MS_EXCEPTION_IF_NULL(evaluator->cache());
  264. evaluator->cache()->clear();
  265. }
  266. }
  267. void AnalysisEngine::Clear() {
  268. cache_.Clear();
  269. anfnode_config_map_.clear();
  270. eval_trace_.clear();
  271. constructors_.clear();
  272. }
  273. namespace {
  274. EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
  275. // Custom Primitive with python infer_shape, infer_type
  276. EvaluatorPtr evaluator = nullptr;
  277. MS_EXCEPTION_IF_NULL(prim);
  278. if (prim->isa<prim::DoSignaturePrimitive>()) {
  279. evaluator = std::make_shared<DoSignatureEvaluator>(prim);
  280. return evaluator;
  281. }
  282. if (prim->isa<prim::UnpackGraphPrimitive>()) {
  283. evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
  284. return evaluator;
  285. }
  286. if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
  287. evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
  288. return evaluator;
  289. }
  290. if (prim->HasPyEvaluator()) {
  291. auto prim_py = dyn_cast<PrimitivePy>(prim);
  292. if (prim_py != nullptr) {
  293. if (engine == nullptr) {
  294. return std::make_shared<PythonPrimEvaluator>(prim_py);
  295. }
  296. const auto &iter = engine->prim_py_evaluators_.find(prim_py);
  297. if (iter != engine->prim_py_evaluators_.end()) {
  298. return iter->second;
  299. }
  300. evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
  301. engine->prim_py_evaluators_[prim_py] = evaluator;
  302. return evaluator;
  303. }
  304. MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
  305. }
  306. if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
  307. if (engine == nullptr) {
  308. (void)GetPrimEvaluatorConstructors();
  309. }
  310. // If a primitive may have attr, try to create a new evaluator.
  311. StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
  312. if (eval_impl != nullptr) {
  313. return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
  314. }
  315. }
  316. if (engine == nullptr) {
  317. // If engine is nullptr, get constructor from default.
  318. const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
  319. auto iter = prim_evaluator_map.find(prim);
  320. if (iter != prim_evaluator_map.end()) {
  321. evaluator = iter->second;
  322. }
  323. } else {
  324. // If engine is given, get constructor from engine resource.
  325. const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
  326. auto iter = prim_evaluator_map.find(prim);
  327. if (iter != prim_evaluator_map.end()) {
  328. evaluator = iter->second;
  329. }
  330. }
  331. if (evaluator == nullptr) {
  332. MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
  333. }
  334. return evaluator;
  335. }
  336. } // namespace
  337. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
  338. auto inf_pair = constructors_.find(func);
  339. if (inf_pair != constructors_.end()) {
  340. return inf_pair->second;
  341. }
  342. MS_EXCEPTION_IF_NULL(func);
  343. auto primitive = func->prim();
  344. auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
  345. constructors_[func] = evaluator;
  346. return evaluator;
  347. }
  348. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
  349. auto inf_pair = constructors_.find(func);
  350. if (inf_pair != constructors_.end()) {
  351. return inf_pair->second;
  352. }
  353. MS_EXCEPTION_IF_NULL(func);
  354. std::shared_ptr<FuncGraphEvaluator> func_graph_evaluator =
  355. std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
  356. constructors_[func] = func_graph_evaluator;
  357. return func_graph_evaluator;
  358. }
  359. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
  360. auto inf_pair = constructors_.find(func);
  361. if (inf_pair != constructors_.end()) {
  362. return inf_pair->second;
  363. }
  364. MS_EXCEPTION_IF_NULL(func);
  365. std::shared_ptr<MetaFuncGraphEvaluator> evaluator =
  366. std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->context(), func->GetScope());
  367. constructors_[func] = evaluator;
  368. return evaluator;
  369. }
  370. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &func) {
  371. MS_EXCEPTION_IF_NULL(func);
  372. AbstractFunctionPtr func_orig = func->fn();
  373. EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  374. auto jevaluator = std::make_shared<JEvaluator>(evaluator_orig, func_orig);
  375. return jevaluator;
  376. }
  377. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &func) {
  378. MS_EXCEPTION_IF_NULL(func);
  379. std::shared_ptr<VirtualEvaluator> virtual_evaluator =
  380. std::make_shared<VirtualEvaluator>(func->args_spec_list(), func->output());
  381. return virtual_evaluator;
  382. }
  383. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &func) {
  384. MS_EXCEPTION_IF_NULL(func);
  385. AbstractFunctionPtr func_orig = func->fn();
  386. EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  387. std::shared_ptr<PartialAppEvaluator> partial_evaluator =
  388. std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args());
  389. return partial_evaluator;
  390. }
  391. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &) {
  392. MS_LOG(EXCEPTION) << "Should not be called ";
  393. }
  394. // Forward to specific subclass of FunctionWrapper.
  395. EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
  396. MS_EXCEPTION_IF_NULL(func);
  397. EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this());
  398. return evaluator;
  399. }
  400. EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
  401. MS_LOG(DEBUG) << "The func value: " << func->ToString();
  402. if (func->tracking_id() != nullptr) {
  403. MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
  404. }
  405. MS_EXCEPTION_IF_NULL(func);
  406. if (func->tracking_id() == nullptr) {
  407. EvaluatorPtr evaluator = _GetEvaluatorFor(func);
  408. return evaluator;
  409. }
  410. auto inf_pair = constructors_.find(func);
  411. if (inf_pair != constructors_.end()) {
  412. return inf_pair->second;
  413. }
  414. AbstractFunctionPtr func_generic = func->Copy();
  415. func_generic->set_tracking_id(nullptr);
  416. EvaluatorPtr eval = _GetEvaluatorFor(func_generic);
  417. auto tracked_eval = std::make_shared<TrackedEvaluator>(eval);
  418. constructors_[func] = tracked_eval;
  419. return tracked_eval;
  420. }
  421. EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
  422. const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) {
  423. if (evaluators.size() == 1) {
  424. EvaluatorPtr eval = evaluators[0];
  425. MS_EXCEPTION_IF_NULL(eval);
  426. return eval->Run(shared_from_this(), args_conf_list, out_conf);
  427. }
  428. return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
  429. }
  430. void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
  431. auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
  432. if (fg_eval == nullptr) {
  433. return;
  434. }
  435. auto fg = fg_eval->func_graph();
  436. MS_EXCEPTION_IF_NULL(fg);
  437. auto undetermined_fgs = fg->recursive_graphs();
  438. if (undetermined_fgs) {
  439. auto fg_parent = fg->parent();
  440. MS_EXCEPTION_IF_NULL(fg_parent);
  441. fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
  442. MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
  443. }
  444. }
  445. EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
  446. const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list,
  447. const EvalTraceRevIter &it, bool *continue_flag) {
  448. *continue_flag = false;
  449. // Find latest entry function to handle nested recursion.
  450. EvaluatorPtr latest_entry = eval;
  451. auto latest_entry_iter = eval_trace_.rbegin();
  452. for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
  453. auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first);
  454. if (it_temp != evaluators.end()) {
  455. latest_entry = *it_temp;
  456. latest_entry_iter = r_it;
  457. break;
  458. }
  459. latest_entry_iter = ++r_it;
  460. }
  461. if (latest_entry != eval) {
  462. MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
  463. *continue_flag = true;
  464. return latest_entry;
  465. }
  466. bool has_undetermined = false;
  467. // Check whether sub loop has untraced undetermined evaluator.
  468. std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals;
  469. for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
  470. undetermined_evals.insert(*r_it);
  471. }
  472. MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
  473. for (auto u_eval : undetermined_evals) {
  474. MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
  475. if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
  476. MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
  477. has_undetermined = true;
  478. break;
  479. }
  480. }
  481. if (has_undetermined == false) {
  482. MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
  483. *continue_flag = true;
  484. return latest_entry;
  485. }
  486. return latest_entry;
  487. }
  488. EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) {
  489. if (out_specs.size() == 0) {
  490. MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
  491. }
  492. if (out_specs.size() == 1) {
  493. MS_EXCEPTION_IF_NULL(out_specs[0]);
  494. // If only one result derived, then broaden it to avoid wrong constant propagation.
  495. return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
  496. }
  497. auto joined_spec = AbstractJoin(out_specs);
  498. MS_EXCEPTION_IF_NULL(joined_spec);
  499. MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
  500. return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
  501. }
  502. EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
  503. const AnfNodeConfigPtr &out_conf,
  504. const ConfigPtrList &args_conf_list) {
  505. AbstractBasePtrList out_specs;
  506. if (!multi_poss_.count(evaluators[0])) {
  507. multi_poss_[evaluators[0]] = evaluators[1];
  508. multi_poss_[evaluators[1]] = evaluators[0];
  509. }
  510. AbstractBasePtrList args_spec_list;
  511. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  512. [](const ConfigPtr &conf) -> AbstractBasePtr {
  513. MS_EXCEPTION_IF_NULL(conf);
  514. return conf->GetEvaluatedValue()->abstract();
  515. });
  516. for (auto eval : evaluators) {
  517. SetUndeterminedFlag(eval);
  518. auto current_inf = std::make_pair(eval, args_spec_list);
  519. MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
  520. // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
  521. auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
  522. if (it == eval_trace_.rend()) {
  523. eval_trace_.push_back(current_inf);
  524. MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
  525. MS_EXCEPTION_IF_NULL(eval);
  526. auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
  527. MS_EXCEPTION_IF_NULL(eval_result->abstract());
  528. MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
  529. out_specs.push_back(eval_result->abstract());
  530. eval_trace_.pop_back();
  531. if (eval_trace_.empty()) {
  532. multi_poss_.clear();
  533. }
  534. } else if (it != eval_trace_.rbegin()) {
  535. bool continue_flag = false;
  536. auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
  537. if (continue_flag) {
  538. continue;
  539. }
  540. // Try to travel the latest undetermined.
  541. if (latest_entry != eval_trace_.rbegin()->first) {
  542. MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
  543. auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
  544. MS_EXCEPTION_IF_NULL(eval_result->abstract());
  545. MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString()
  546. << " return out_spec: " << eval_result->abstract()->ToString();
  547. return eval_result;
  548. }
  549. }
  550. }
  551. return ProcessEvalResults(out_specs);
  552. }
  553. EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {
  554. AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
  555. return engine_.lock()->GetEvaluatedValue(self);
  556. }
  557. AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
  558. if (value->isa<FuncGraph>()) {
  559. auto func_graph = value->cast<FuncGraphPtr>();
  560. return func_graph->MakeAbstractClosure(context);
  561. }
  562. AnfNodePtr anf_node = nullptr;
  563. if (conf != nullptr) {
  564. anf_node = conf->node();
  565. }
  566. if (value->isa<MetaFuncGraph>()) {
  567. auto meta_func_graph = value->cast<MetaFuncGraphPtr>();
  568. return meta_func_graph->MakeAbstractClosure(anf_node);
  569. }
  570. if (value->isa<Primitive>()) {
  571. auto prim = value->cast<PrimitivePtr>();
  572. return prim->ToPrimAbstract(anf_node);
  573. }
  574. return value->ToAbstract();
  575. }
  576. AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
  577. AbstractBasePtr a = ToAbstract(value, nullptr, nullptr);
  578. if (broaden) {
  579. a = a->Broaden();
  580. }
  581. return a;
  582. }
  583. EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
  584. auto evaluator = GetPrimEvaluator(primitive, nullptr);
  585. MS_EXCEPTION_IF_NULL(evaluator);
  586. if (!evaluator->isa<TrivialPrimEvaluator>()) {
  587. MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but "
  588. << evaluator->ToString();
  589. }
  590. auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
  591. auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
  592. return eval_result;
  593. }
  594. } // namespace abstract
  595. } // namespace mindspore