You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

evaluator.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/static_analysis/evaluator.h"
  17. #include <algorithm>
  18. #include <unordered_set>
  19. #include "ir/func_graph_cloner.h"
  20. #include "pipeline/static_analysis/utils.h"
  21. #include "debug/trace.h"
  22. namespace mindspore {
  23. namespace abstract {
  24. namespace {
  25. void InferEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list,
  26. const AnfNodeConfigPtr &out_conf) {
  27. MS_EXCEPTION_IF_NULL(evaluator);
  28. if (out_conf != nullptr) {
  29. MS_LOG(DEBUG) << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
  30. }
  31. for (size_t i = 0; i < arg_spec_list.size(); i++) {
  32. MS_LOG(DEBUG) << evaluator->ToString() << " input[" << i << "] abstract value: " << arg_spec_list[i]->ToString();
  33. }
  34. }
  35. void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) {
  36. MS_EXCEPTION_IF_NULL(evaluator);
  37. if (out_conf != nullptr) {
  38. auto node = out_conf->node();
  39. if (IsValueNode<Primitive>(node)) {
  40. MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->fullname_with_scope()
  41. << ", with debug info: " << trace::GetDebugInfo(node->debug_info());
  42. } else {
  43. MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->DebugString()
  44. << ", with debug info: " << trace::GetDebugInfo(node->debug_info());
  45. }
  46. }
  47. }
  48. } // namespace
  49. AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine,
  50. const AbstractBasePtrList &args_spec_list) {
  51. AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list);
  52. FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list);
  53. MS_EXCEPTION_IF_NULL(parent_context_);
  54. AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list);
  55. return context;
  56. }
  57. static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
  58. auto ori_func_graph = ret_node->func_graph();
  59. MS_EXCEPTION_IF_NULL(ori_func_graph);
  60. std::vector<AnfNodePtr> sorted_nodes;
  61. std::unordered_set<AnfNodePtr> checked_cnodes;
  62. std::size_t index = 0;
  63. sorted_nodes.emplace_back(ret_node);
  64. while (index < sorted_nodes.size()) {
  65. auto current = sorted_nodes[index];
  66. index++;
  67. MS_EXCEPTION_IF_NULL(current);
  68. if (current->isa<CNode>()) {
  69. auto &inputs = current->cast<CNodePtr>()->inputs();
  70. for (auto it = inputs.begin(); it != inputs.end(); it++) {
  71. AnfNodePtr input = *it;
  72. if (input != nullptr && input->isa<CNode>() && checked_cnodes.find(input) == checked_cnodes.end() &&
  73. input->func_graph() == ori_func_graph) {
  74. sorted_nodes.emplace_back(input);
  75. (void)checked_cnodes.insert(input);
  76. }
  77. }
  78. }
  79. }
  80. return sorted_nodes;
  81. }
  82. AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
  83. FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
  84. MS_EXCEPTION_IF_NULL(fg);
  85. std::size_t nargs = fg->parameters().size();
  86. if (args_spec_list.size() != nargs) {
  87. MS_EXCEPTION(ValueError) << "Function " << fg->ToString() << ", The number of parameters of this function is "
  88. << fg->parameters().size() << ", but the number of provided arguments is "
  89. << args_spec_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info());
  90. }
  91. MS_EXCEPTION_IF_NULL(parent_context_);
  92. MS_EXCEPTION_IF_NULL(engine);
  93. graph_context_ = parent_context_->NewFuncGraphContext(fg, args_spec_list);
  94. const auto &parameters = fg->parameters();
  95. for (size_t i = 0; i < nargs; i++) {
  96. const auto &arg = args_spec_list[i];
  97. const auto &node = parameters[i];
  98. AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_);
  99. engine->cache().set_value(conf, arg);
  100. }
  101. const AnfNodePtr &func_node = fg->get_return();
  102. MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString()
  103. << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
  104. AbstractBasePtr ret_base = nullptr;
  105. std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
  106. for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
  107. const auto &node = *it;
  108. AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
  109. MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
  110. ret_base = engine->GetEvaluatedValue(node_conf);
  111. MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
  112. << ", abstract: " << ret_base->ToString();
  113. }
  114. MS_EXCEPTION_IF_NULL(ret_base);
  115. MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << ret_base->ToString();
  116. return ret_base;
  117. }
  118. AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  119. MS_EXCEPTION_IF_NULL(func_graph_);
  120. if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
  121. AbstractBasePtrList broaded_list;
  122. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
  123. [](const AbstractBasePtr &arg) -> AbstractBasePtr {
  124. MS_EXCEPTION_IF_NULL(arg);
  125. return arg->Broaden();
  126. });
  127. MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
  128. << ", broaded: " << mindspore::ToString(broaded_list);
  129. return broaded_list;
  130. }
  131. if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
  132. if (parent_context_) {
  133. MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString()
  134. << ", context: " << parent_context_->ToString();
  135. auto last_context = parent_context_->Filter(func_graph_);
  136. if (last_context && last_context->func_graph() == func_graph_) {
  137. MS_LOG(DEBUG) << "Find last infer context: " << last_context->ToString();
  138. MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
  139. MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list());
  140. // Join the last eval arguments and current arguments to check if there are loop variant.
  141. auto joined_args_spec_list = AbstractJoin(args_spec_list, last_context->args_spec_list());
  142. MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list);
  143. // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
  144. if (!(joined_args_spec_list == args_spec_list)) {
  145. func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
  146. }
  147. return joined_args_spec_list;
  148. }
  149. }
  150. }
  151. return args_spec_list;
  152. }
  153. FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
  154. auto iter = func_graph_cache_.find(args_spec_list);
  155. FuncGraphPtr ret = nullptr;
  156. if (iter == func_graph_cache_.end()) {
  157. auto fg = func_graph();
  158. MS_EXCEPTION_IF_NULL(fg);
  159. TraceManager::DebugTrace(std::make_shared<TraceEvaluatorGenGraph>(fg->debug_info()));
  160. FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list);
  161. TraceManager::EndTrace();
  162. func_graph_cache_[args_spec_list] = generated_graph;
  163. MS_EXCEPTION_IF_NULL(engine);
  164. engine->func_graph_manager()->AddFuncGraph(generated_graph);
  165. ret = generated_graph;
  166. } else {
  167. ret = iter->second;
  168. }
  169. // For the top graph, if it is replaced by generated graph, update the top graph to the new one.
  170. if (parse::Parser::GetTopFuncGraph() == func_graph()) {
  171. if (ret != func_graph()) {
  172. parse::Parser::UpdateTopFuncGraph(ret);
  173. }
  174. }
  175. return ret;
  176. }
  177. FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
  178. auto iter = func_graph_cache_.find(args_spec_list);
  179. if (iter != func_graph_cache_.end()) {
  180. return iter->second;
  181. }
  182. MS_EXCEPTION_IF_NULL(meta_func_graph_);
  183. FuncGraphPtr generated_func_graph = nullptr;
  184. if (this->bound_node() != nullptr) {
  185. TraceManager::DebugTrace(std::make_shared<TraceGenMetaFuncGraph>(bound_node()->debug_info()));
  186. generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
  187. TraceManager::EndTrace();
  188. } else {
  189. generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
  190. }
  191. FuncGraphPtr cloned_func_graph = BasicClone(generated_func_graph);
  192. func_graph_cache_[args_spec_list] = cloned_func_graph;
  193. MS_EXCEPTION_IF_NULL(engine);
  194. engine->func_graph_manager()->AddFuncGraph(cloned_func_graph);
  195. return cloned_func_graph;
  196. }
  197. AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  198. AnfNodeConfigPtr out_conf) {
  199. const std::string &evaluator_name = ToString();
  200. AbstractBasePtrList args_spec_list;
  201. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  202. [](const ConfigPtr &conf) -> AbstractBasePtr {
  203. MS_EXCEPTION_IF_NULL(conf);
  204. return conf->GetEvaluatedValue();
  205. });
  206. args_spec_list = NormalizeArgs(args_spec_list);
  207. trace::TraceGraphInferEnter(shared_from_base<Evaluator>(), out_conf);
  208. InferEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
  209. MS_EXCEPTION_IF_NULL(cache_);
  210. auto iter = cache_->find(args_spec_list);
  211. if (iter == cache_->end()) {
  212. MS_LOG(DEBUG) << evaluator_name << " cache miss, call Infer().";
  213. AbstractBasePtr ret = Infer(engine, args_spec_list);
  214. if (ret == nullptr) {
  215. InferFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
  216. MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
  217. }
  218. MS_EXCEPTION_IF_NULL(ret);
  219. MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << ".";
  220. (*cache_)[args_spec_list] = ret;
  221. trace::TraceGraphInferLeave(shared_from_base<Evaluator>());
  222. return ret;
  223. } else {
  224. MS_EXCEPTION_IF_NULL(iter->second);
  225. MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << ".";
  226. trace::TraceGraphInferLeave(shared_from_base<Evaluator>());
  227. return iter->second;
  228. }
  229. }
  230. AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  231. AnfNodeConfigPtr) {
  232. AbstractBasePtrList args_spec_list;
  233. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  234. [](const ConfigPtr &conf) -> AbstractBasePtr {
  235. MS_EXCEPTION_IF_NULL(conf);
  236. return conf->GetEvaluatedValue();
  237. });
  238. AbstractBasePtr ret = EvalPrim(engine, args_spec_list);
  239. return ret;
  240. }
  241. AbstractBasePtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  242. AnfNodeConfigPtr out_conf) {
  243. AbstractBasePtrList args_spec_list;
  244. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  245. [](const ConfigPtr &conf) -> AbstractBasePtr {
  246. MS_EXCEPTION_IF_NULL(conf);
  247. return conf->GetEvaluatedValue();
  248. });
  249. if (args_conf_list.size() == 0) {
  250. MS_LOG(EXCEPTION) << "Size should greater than 0";
  251. }
  252. AbstractBasePtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
  253. // No need to cache.
  254. return ret;
  255. }
  256. AbstractBasePtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
  257. AbstractBasePtr ret = EvalPrim(args_conf_list);
  258. return ret;
  259. }
  260. AbstractBasePtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  261. AnfNodeConfigPtr out_conf) {
  262. AbstractBasePtrList args_spec_list;
  263. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  264. [](const ConfigPtr &conf) -> AbstractBasePtr {
  265. MS_EXCEPTION_IF_NULL(conf);
  266. return conf->GetEvaluatedValue();
  267. });
  268. AbstractBasePtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
  269. // Don't lookup from cache, as different out_conf with same node but different context
  270. // may add different entry to anfnode_config_map_, like getattr primitive.
  271. (*cache_)[args_spec_list] = ret;
  272. return ret;
  273. }
  274. AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  275. AnfNodeConfigPtr out_conf) {
  276. AbstractBasePtrList args_spec_list;
  277. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  278. [](const ConfigPtr &conf) -> AbstractBasePtr {
  279. MS_EXCEPTION_IF_NULL(conf);
  280. return conf->GetEvaluatedValue();
  281. });
  282. MS_EXCEPTION_IF_NULL(cache_);
  283. auto iter = cache_->find(args_spec_list);
  284. if (iter != cache_->end()) {
  285. return iter->second;
  286. }
  287. ConfigPtrList partial_args_conf_list;
  288. // Join arguments in partial and the rest arguments from args_conf_list.
  289. (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(partial_args_conf_list),
  290. [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
  291. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
  292. [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
  293. AbstractBasePtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
  294. (*cache_)[args_spec_list] = ret;
  295. return ret;
  296. }
  297. AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
  298. AbstractBasePtrList args_spec_list;
  299. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  300. [](const ConfigPtr &conf) -> AbstractBasePtr {
  301. MS_EXCEPTION_IF_NULL(conf);
  302. return conf->GetEvaluatedValue();
  303. });
  304. MS_EXCEPTION_IF_NULL(cache_);
  305. auto iter = cache_->find(args_spec_list);
  306. if (iter != cache_->end()) {
  307. return iter->second;
  308. }
  309. // Call the original evaluator, get the result: y = f(x)
  310. AbstractBasePtr result = evaluator_->Run(engine, args_conf_list, nullptr);
  311. // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
  312. // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
  313. AbstractBasePtrList bparams;
  314. bparams.push_back(SensitivityTransform(orig_func_));
  315. (void)std::transform(
  316. args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
  317. [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
  318. AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
  319. AbstractFunctionPtr bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result), bparams_final);
  320. // J(f)(J(x)) return a tuple (y, bprop_f)
  321. AbstractBasePtrList jargs = {result, bprop};
  322. AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
  323. (*cache_)[args_spec_list] = jtuple;
  324. return jtuple;
  325. }
  326. AbstractBasePtr VirtualEvaluator::Infer(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
  327. if (args_spec_list.size() != args_spec_list_.size()) {
  328. MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
  329. << ", arguments no: " << args_spec_list.size();
  330. }
  331. // Check each parameter and argument match;
  332. for (std::size_t i = 0; i < args_spec_list.size(); i++) {
  333. MS_EXCEPTION_IF_NULL(args_spec_list[i]);
  334. (void)args_spec_list[i]->Join(args_spec_list_[i]);
  335. }
  336. return output_;
  337. }
  338. } // namespace abstract
  339. } // namespace mindspore