You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

prim.cc 61 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/static_analysis/prim.h"
  19. #include <algorithm>
  20. #include <limits>
  21. #include <mutex>
  22. #include <string>
  23. #include <utility>
  24. #include <unordered_set>
  25. #include "frontend/operator/cc_implementations.h"
  26. #include "frontend/operator/ops.h"
  27. #include "frontend/operator/composite/do_signature.h"
  28. #include "frontend/operator/prim_to_function.h"
  29. #include "abstract/utils.h"
  30. #include "utils/symbolic.h"
  31. #include "pipeline/jit/resource.h"
  32. #include "pipeline/jit/parse/resolve.h"
  33. #include "utils/convert_utils.h"
  34. #include "utils/convert_utils_py.h"
  35. #include "utils/ms_context.h"
  36. #include "pipeline/jit/parse/data_converter.h"
  37. #include "abstract/primitive_infer_map.h"
  38. #include "abstract/param_validator.h"
  39. #include "utils/ms_utils.h"
  40. #include "utils/shape_utils.h"
  41. namespace mindspore {
  42. namespace abstract {
  43. using mindspore::parse::PyObjectWrapper;
  44. std::unordered_set<std::string> prims_to_skip_undetermined_infer{
  45. "make_tuple", "make_list", "switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
  46. EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  47. AnfNodeConfigPtr out_conf) {
  48. AbstractBasePtrList args_spec_list;
  49. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  50. [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
  51. auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
  52. auto &func = do_signature->function();
  53. if (func->isa<Primitive>()) {
  54. auto sig_prim = func->cast<PrimitivePtr>();
  55. if (prims_to_skip_undetermined_infer.find(sig_prim->name()) == prims_to_skip_undetermined_infer.end()) {
  56. auto ret_abstract = AbstractEval(args_spec_list);
  57. if (ret_abstract != nullptr) {
  58. MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined";
  59. return ret_abstract;
  60. }
  61. }
  62. }
  63. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  64. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  65. }
  66. auto out_node = dyn_cast<CNode>(out_conf->node());
  67. const auto &out_node_inputs = out_node->inputs();
  68. if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  69. MS_LOG(EXCEPTION) << "Op: " << do_signature->function()->ToString()
  70. << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
  71. << ", inputs size " << out_node_inputs.size();
  72. }
  73. AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
  74. ScopePtr scope = kDefaultScope;
  75. if (out_conf != nullptr) {
  76. scope = out_conf->node()->scope();
  77. }
  78. ScopeGuard scope_guard(scope);
  79. AnfNodePtr new_cnode = nullptr;
  80. if (bound_node() != nullptr) {
  81. TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
  82. new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
  83. args_inputs);
  84. } else {
  85. new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
  86. args_inputs);
  87. }
  88. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
  89. return engine->ForwardConfig(out_conf, fn_conf);
  90. }
  91. static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
  92. // arg[0] is the func graph to unpack, ignore it
  93. AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
  94. AbstractBasePtrList graph_specialize_args;
  95. if (need_unpack) {
  96. for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
  97. MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
  98. if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
  99. auto arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
  100. std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
  101. std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
  102. } else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
  103. auto arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
  104. auto dict_elems = arg_dict->elements();
  105. (void)std::transform(
  106. dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
  107. [](const AbstractAttribute &item) { return std::make_shared<AbstractKeywordArg>(item.first, item.second); });
  108. } else {
  109. MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got "
  110. << specialize_args_before_unpack[index]->ToString();
  111. }
  112. }
  113. } else {
  114. graph_specialize_args = specialize_args_before_unpack;
  115. }
  116. return graph_specialize_args;
  117. }
  118. EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  119. AnfNodeConfigPtr out_conf) {
  120. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  121. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  122. }
  123. auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
  124. auto out_node = out_conf->node()->cast<CNodePtr>();
  125. const auto &out_node_inputs = out_node->inputs();
  126. if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  127. MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
  128. << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
  129. << ", inputs size " << out_node_inputs.size();
  130. }
  131. AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
  132. AbstractBasePtrList args_spec_list;
  133. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  134. [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
  135. // get the forward graph
  136. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  137. auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
  138. if (fn == nullptr) {
  139. MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
  140. }
  141. auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
  142. MS_EXCEPTION_IF_NULL(real_fn);
  143. FuncGraphPtr forward_graph = real_fn->func_graph();
  144. MS_EXCEPTION_IF_NULL(forward_graph);
  145. AbstractBasePtrList graph_specialize_args =
  146. GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
  147. AbstractBasePtrList graph_specialize_args_without_sens;
  148. if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
  149. MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
  150. }
  151. (void)std::transform(graph_specialize_args.begin(),
  152. graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
  153. std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
  154. auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens);
  155. engine->func_graph_manager()->AddFuncGraph(new_graph);
  156. ScopePtr scope = kDefaultScope;
  157. if (out_conf != nullptr) {
  158. scope = out_conf->node()->scope();
  159. }
  160. ScopeGuard scope_guard(scope);
  161. AnfNodePtr new_vnode = NewValueNode(new_graph);
  162. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context());
  163. return engine->ForwardConfig(out_conf, fn_conf);
  164. }
  165. AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const AbstractBasePtr &node_type,
  166. const AnfNodePtr &target_type, const FuncGraphPtr &func_graph) {
  167. AnfNodePtr target_node = source_node;
  168. if (node_type->isa<AbstractTensor>()) {
  169. auto x = node_type->cast<AbstractTensorPtr>();
  170. if (x->element()->BuildType()->isa<Float>()) {
  171. auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
  172. MS_EXCEPTION_IF_NULL(cast);
  173. target_node = func_graph->NewCNodeAfter(source_node, {NewValueNode(cast), source_node, target_type});
  174. }
  175. } else if (node_type->isa<AbstractTuple>()) {
  176. auto x = node_type->cast<AbstractTuplePtr>();
  177. auto &items = x->elements();
  178. std::vector<AnfNodePtr> nodes;
  179. nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  180. int64_t idx = 0;
  181. for (const auto &item : items) {
  182. AnfNodePtr tuple_node =
  183. func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)});
  184. AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, item, target_type, func_graph);
  185. nodes.emplace_back(node);
  186. ++idx;
  187. }
  188. target_node = func_graph->NewCNode(nodes);
  189. } else if (node_type->isa<AbstractDictionary>()) {
  190. auto x = node_type->cast<AbstractDictionaryPtr>();
  191. auto &items = x->elements();
  192. std::vector<AnfNodePtr> dict_key_nodes;
  193. std::vector<AnfNodePtr> dict_value_nodes;
  194. dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  195. dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  196. for (const auto &item : items) {
  197. AnfNodePtr dict_value_node =
  198. func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)});
  199. AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph);
  200. dict_key_nodes.emplace_back(NewValueNode(item.first));
  201. dict_value_nodes.emplace_back(node);
  202. }
  203. target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes),
  204. func_graph->NewCNode(dict_value_nodes)});
  205. } else if (node_type->isa<AbstractKeywordArg>()) {
  206. auto x = node_type->cast<AbstractKeywordArgPtr>();
  207. std::string kwarg_key = x->get_key();
  208. AnfNodePtr kwarg_value_node =
  209. func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node});
  210. AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph);
  211. target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node});
  212. }
  213. return target_node;
  214. }
  215. EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  216. AnfNodeConfigPtr out_conf) {
  217. AbstractBasePtrList args_spec_list;
  218. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  219. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  220. }
  221. auto out_node = out_conf->node()->cast<CNodePtr>();
  222. const auto &out_node_inputs = out_node->inputs();
  223. if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  224. MS_LOG(EXCEPTION) << "MixedPrecisionCast"
  225. << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
  226. << ", inputs size " << out_node_inputs.size();
  227. }
  228. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  229. [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
  230. ScopePtr scope = kDefaultScope;
  231. if (out_conf != nullptr) {
  232. scope = out_conf->node()->scope();
  233. }
  234. ScopeGuard scope_guard(scope);
  235. FuncGraphPtr func_graph = out_conf->node()->func_graph();
  236. AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph);
  237. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
  238. return engine->ForwardConfig(out_conf, fn_conf);
  239. }
  240. namespace {
  241. py::object BuildValue(const ValuePtr &value_ptr) {
  242. if (value_ptr == nullptr) {
  243. return py::none();
  244. } else {
  245. return ValuePtrToPyData(value_ptr);
  246. }
  247. }
  248. py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
  249. auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
  250. size_t len = arg_tuple->size();
  251. py::tuple shape_tuple(len);
  252. py::tuple dtype_tuple(len);
  253. py::tuple value_tuple(len);
  254. py::tuple min_value_tuple(len);
  255. py::tuple max_value_tuple(len);
  256. py::tuple min_shape_tuple(len);
  257. py::tuple max_shape_tuple(len);
  258. auto dic = py::dict();
  259. bool dyn_shape = false;
  260. bool dyn_value = false;
  261. for (size_t i = 0; i < len; i++) {
  262. auto arg = arg_tuple->elements()[i];
  263. py::dict out = ConvertAbstractToPython(arg);
  264. shape_tuple[i] = out[ATTR_SHAPE];
  265. dtype_tuple[i] = out[ATTR_DTYPE];
  266. value_tuple[i] = out[ATTR_VALUE];
  267. // Elements in tuple is tensor shape value.
  268. if (out.contains(py::str(ATTR_MIN_VALUE)) && out.contains(py::str(ATTR_MAX_VALUE))) {
  269. min_value_tuple[i] = out[ATTR_MIN_VALUE];
  270. max_value_tuple[i] = out[ATTR_MAX_VALUE];
  271. dyn_value = true;
  272. }
  273. // Elements in tuple is tensor, which shape is dynamic.
  274. if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
  275. min_shape_tuple[i] = out[ATTR_MIN_SHAPE];
  276. max_shape_tuple[i] = out[ATTR_MAX_SHAPE];
  277. dyn_shape = true;
  278. }
  279. }
  280. dic[ATTR_SHAPE] = shape_tuple;
  281. dic[ATTR_DTYPE] = dtype_tuple;
  282. if (arg_tuple->BuildValue()->isa<AnyValue>()) {
  283. dic[ATTR_VALUE] = py::none();
  284. } else {
  285. dic[ATTR_VALUE] = value_tuple;
  286. }
  287. if (dyn_value) {
  288. dic[ATTR_MIN_VALUE] = min_value_tuple;
  289. dic[ATTR_MAX_VALUE] = max_value_tuple;
  290. }
  291. if (dyn_shape) {
  292. dic[ATTR_MIN_SHAPE] = min_shape_tuple;
  293. dic[ATTR_MAX_SHAPE] = max_shape_tuple;
  294. }
  295. return dic;
  296. }
  297. py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
  298. auto arg_list = dyn_cast<AbstractList>(abs_base);
  299. size_t len = arg_list->size();
  300. py::list shape_list(len);
  301. py::list dtype_list(len);
  302. py::list value_list(len);
  303. py::list min_shape_list(len);
  304. py::list max_shape_list(len);
  305. auto dic = py::dict();
  306. bool dyn_shape = false;
  307. for (size_t i = 0; i < len; i++) {
  308. py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
  309. shape_list[i] = out[ATTR_SHAPE];
  310. dtype_list[i] = out[ATTR_DTYPE];
  311. value_list[i] = out[ATTR_VALUE];
  312. // Elements in list is tensor, which shape is dynamic.
  313. if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
  314. min_shape_list[i] = out[ATTR_MIN_SHAPE];
  315. max_shape_list[i] = out[ATTR_MAX_SHAPE];
  316. dyn_shape = true;
  317. }
  318. }
  319. dic[ATTR_SHAPE] = shape_list;
  320. dic[ATTR_DTYPE] = dtype_list;
  321. if (arg_list->BuildValue()->isa<AnyValue>()) {
  322. dic[ATTR_VALUE] = py::none();
  323. } else {
  324. dic[ATTR_VALUE] = value_list;
  325. }
  326. if (dyn_shape) {
  327. dic[ATTR_MIN_SHAPE] = min_shape_list;
  328. dic[ATTR_MAX_SHAPE] = max_shape_list;
  329. }
  330. return dic;
  331. }
  332. } // end anonymous namespace
  333. py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
  334. MS_EXCEPTION_IF_NULL(abs_base);
  335. auto dic = py::dict();
  336. if (abs_base->isa<AbstractTensor>()) {
  337. auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
  338. dic[ATTR_SHAPE] = arg_tensor->shape()->shape();
  339. if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  340. const auto &min_shape = arg_tensor->shape()->min_shape();
  341. const auto &max_shape = arg_tensor->shape()->max_shape();
  342. if (!min_shape.empty() && !max_shape.empty()) {
  343. dic[ATTR_MIN_SHAPE] = min_shape;
  344. dic[ATTR_MAX_SHAPE] = max_shape;
  345. }
  346. }
  347. auto min_value = arg_tensor->get_min_value();
  348. auto max_value = arg_tensor->get_max_value();
  349. if (min_value != nullptr && max_value != nullptr) {
  350. dic[ATTR_MIN_VALUE] = BuildValue(min_value);
  351. dic[ATTR_MAX_VALUE] = BuildValue(max_value);
  352. }
  353. dic[ATTR_DTYPE] = arg_tensor->BuildType();
  354. dic[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
  355. } else if (abs_base->isa<AbstractRowTensor>()) {
  356. auto arg = dyn_cast<AbstractRowTensor>(abs_base);
  357. dic[ATTR_SHAPE] = arg->shape()->shape();
  358. dic[ATTR_DTYPE] = arg->BuildType();
  359. dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  360. } else if (abs_base->isa<AbstractSparseTensor>()) {
  361. auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
  362. dic[ATTR_SHAPE] = arg->shape()->shape();
  363. dic[ATTR_DTYPE] = arg->BuildType();
  364. dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  365. } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
  366. ShapeVector shape;
  367. dic[ATTR_SHAPE] = shape;
  368. dic[ATTR_DTYPE] = abs_base->BuildType();
  369. dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
  370. } else if (abs_base->isa<AbstractSlice>()) {
  371. auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
  372. ShapeVector shape;
  373. dic[ATTR_SHAPE] = shape;
  374. dic[ATTR_DTYPE] = arg_slice->BuildType();
  375. dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
  376. } else if (abs_base->isa<AbstractEllipsis>()) {
  377. dic[ATTR_SHAPE] = py::none();
  378. dic[ATTR_DTYPE] = py::ellipsis();
  379. dic[ATTR_VALUE] = py::ellipsis();
  380. } else if (abs_base->isa<AbstractTuple>()) {
  381. return AbstractTupleToPython(abs_base);
  382. } else if (abs_base->isa<AbstractList>()) {
  383. return AbstractListToPython(abs_base);
  384. } else if (abs_base->isa<AbstractNone>()) {
  385. dic[ATTR_SHAPE] = py::none();
  386. dic[ATTR_DTYPE] = py::none();
  387. dic[ATTR_VALUE] = py::none();
  388. } else if (abs_base->isa<AbstractFunction>()) {
  389. dic[ATTR_SHAPE] = py::none();
  390. dic[ATTR_DTYPE] = abs_base->BuildType();
  391. dic[ATTR_VALUE] = py::none();
  392. if (abs_base->isa<PartialAbstractClosure>()) {
  393. AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
  394. if (!args.empty()) {
  395. auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
  396. if (value != nullptr) {
  397. dic[ATTR_DTYPE] = std::make_shared<TypeType>();
  398. dic[ATTR_VALUE] = value->obj();
  399. }
  400. }
  401. }
  402. } else if (abs_base->isa<AbstractUndetermined>()) {
  403. auto arg = dyn_cast<AbstractUndetermined>(abs_base);
  404. dic[ATTR_SHAPE] = py::none();
  405. dic[ATTR_DTYPE] = arg->BuildType();
  406. dic[ATTR_VALUE] = py::none();
  407. } else if (abs_base->isa<AbstractMonad>()) {
  408. dic[ATTR_SHAPE] = py::none();
  409. dic[ATTR_DTYPE] = abs_base->BuildType();
  410. dic[ATTR_VALUE] = py::none();
  411. } else {
  412. auto value = abs_base->BuildValue();
  413. if ((*value == *kAnyValue)) {
  414. auto value_desc = abs_base->value_desc();
  415. MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
  416. << " for python primitive." << abs_base->ToString();
  417. }
  418. MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
  419. << value->ToString();
  420. }
  421. return dic;
  422. }
  423. namespace {
  424. py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrList &args) {
  425. const AbstractBasePtrList *args_ptr;
  426. if (prim_py->is_tuple_input_) {
  427. if (args.empty()) {
  428. MS_LOG(EXCEPTION) << "Primitive args is empty";
  429. }
  430. if (args[0] == nullptr || !args[0]->isa<AbstractTuple>()) {
  431. MS_LOG(EXCEPTION) << "Custom Primitive inputs should be packed into a Tuple after converting"
  432. "prim convert pass for GE.";
  433. }
  434. args_ptr = &(args[0]->cast<AbstractTuplePtr>()->elements());
  435. } else {
  436. args_ptr = &args;
  437. }
  438. // The monad parameter is defined at the end of the parameter and needs to be ignored
  439. std::size_t size_args = args_ptr->size() - GetAbstractMonadNum(*args_ptr);
  440. py::tuple py_args(size_args);
  441. for (size_t i = 0; i < size_args; i++) {
  442. auto arg_i = (*args_ptr)[i];
  443. py_args[i] = ConvertAbstractToPython(arg_i);
  444. }
  445. return py_args;
  446. }
  447. AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
  448. // Convert to AbstractValue based on type and shape
  449. auto out_dtype = output[ATTR_DTYPE];
  450. if (output[ATTR_VALUE].is_none()) {
  451. auto out_shape = output[ATTR_SHAPE];
  452. return PyListDtype2AbstractTensor(out_shape, out_dtype, output);
  453. }
  454. // Convert pyobject to Value, then to AbstractValue
  455. ValuePtr converted_ret = nullptr;
  456. TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
  457. bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
  458. if (!converted) {
  459. MS_LOG(EXCEPTION) << "Convert data failed";
  460. }
  461. auto res_spec = FromValue(converted_ret);
  462. MS_EXCEPTION_IF_NULL(res_spec);
  463. if (res_spec->isa<AbstractTensor>()) {
  464. // Replace to tensor constant node in specialize
  465. auto res_tensor = res_spec->cast<AbstractTensorPtr>();
  466. res_tensor->set_value(converted_ret);
  467. SetValueRange(res_tensor, output);
  468. }
  469. if (prim_py->IsCustomPrim()) {
  470. // Raise error if output_num is not match the infer result.
  471. int64_t output_num = GetValue<int64_t>(prim_py->GetAttr("output_num"));
  472. if (res_spec->isa<AbstractTensor>() && output_num != 1) {
  473. MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
  474. << " not matches the infer result.";
  475. } else if (res_spec->isa<AbstractTuple>() &&
  476. (res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
  477. MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
  478. << " not matches the infer result.";
  479. }
  480. }
  481. return res_spec;
  482. }
  483. } // end anonymous namespace
  484. EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
  485. auto prim_py = dyn_cast<PrimitivePy>(prim_);
  486. if (prim_py == nullptr) {
  487. MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
  488. }
  489. // Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
  490. MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
  491. auto py_args = PreparePyInputs(prim_py, args);
  492. prim_py->RunCheck(py_args);
  493. prim_->BeginRecordAddAttr();
  494. AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
  495. prim_->EndRecordAddAttr();
  496. auto added_attrs = prim_->evaluate_added_attrs();
  497. if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
  498. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  499. }
  500. // Call method 'infer_value' for primitive with this method for constant propagation
  501. py::tuple py_vals(py_args.size());
  502. for (size_t i = 0; i < py_args.size(); ++i) {
  503. py_vals[i] = py_args[i][ATTR_VALUE];
  504. }
  505. py::object py_ret = prim_py->RunInferValue(py_vals);
  506. if (py::isinstance<py::none>(py_ret)) {
  507. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  508. }
  509. // Convert pyobject to Value, then to AbstractValue
  510. ValuePtr converted_ret = nullptr;
  511. TypePtr dtype = abs_base->BuildType();
  512. bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
  513. if (!converted) {
  514. MS_LOG(EXCEPTION) << "Convert data failed";
  515. }
  516. auto res_spec = FromValue(converted_ret);
  517. MS_EXCEPTION_IF_NULL(res_spec);
  518. if (res_spec->isa<AbstractTensor>()) {
  519. // Replace to tensor constant node in specialize
  520. auto res_tensor = res_spec->cast<AbstractTensorPtr>();
  521. res_tensor->set_value(converted_ret);
  522. }
  523. return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
  524. }
  525. EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
  526. if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
  527. auto ret_abstract = AbstractEval(args);
  528. if (ret_abstract != nullptr) {
  529. MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
  530. return ret_abstract;
  531. }
  532. }
  533. if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
  534. return EvalPyCheckPrim(engine, args);
  535. }
  536. prim_->BeginRecordAddAttr();
  537. AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
  538. prim_->EndRecordAddAttr();
  539. auto added_attrs = prim_->evaluate_added_attrs();
  540. auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  541. return eval_result;
  542. }
  543. EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  544. auto ret_abstract = AbstractEval(args);
  545. if (ret_abstract != nullptr) {
  546. MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
  547. return ret_abstract;
  548. }
  549. MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
  550. const auto &iter = cache_->find(args);
  551. if (iter != cache_->end()) {
  552. return iter->second;
  553. }
  554. auto py_args = PreparePyInputs(prim_py_, args);
  555. prim_py_->BeginRecordAddAttr();
  556. py::dict output = prim_py_->RunInfer(py_args);
  557. prim_py_->EndRecordAddAttr();
  558. auto added_attrs = prim_py_->evaluate_added_attrs();
  559. MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
  560. auto res_spec = PyInferRes2Abstract(prim_py_, output);
  561. MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
  562. auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
  563. (*cache_)[args] = infer_result;
  564. return infer_result;
  565. }
  566. EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  567. auto ret_abstract = AbstractEval(args);
  568. if (ret_abstract != nullptr) {
  569. MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
  570. return ret_abstract;
  571. }
  572. // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
  573. if (nargs_ != args.size()) {
  574. MS_LOG(EXCEPTION) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
  575. }
  576. TypePtr ret_value_type = return_value_type_;
  577. ValuePtrList value_list;
  578. for (const auto &arg : args) {
  579. // Check if all arguments are scalar type.
  580. MS_EXCEPTION_IF_NULL(arg);
  581. if (arg->isa<AbstractScalar>()) {
  582. auto arg_scalar = dyn_cast<AbstractScalar>(arg);
  583. auto arg_value = arg_scalar->GetValueTrack();
  584. value_list.push_back(arg_value);
  585. } else {
  586. // Raise TypeError Expected Scalar.
  587. MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives.";
  588. }
  589. }
  590. for (const auto &item : type_map_) {
  591. TypePtrList selections;
  592. MS_EXCEPTION_IF_NULL(item.second);
  593. (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
  594. [&args](size_t arg_idx) -> TypePtr { return args[arg_idx]->GetTypeTrack(); });
  595. TypePtr res = CheckTypeList(item.first, selections);
  596. if (*return_value_type_ == *(item.first)) {
  597. ret_value_type = res;
  598. }
  599. }
  600. ValuePtr evaluated_value = RunImpl(value_list);
  601. if (!(*evaluated_value == *kAnyValue)) {
  602. ret_value_type = evaluated_value->type();
  603. }
  604. // for comparison primitives , return type shall have be specified to be bool.
  605. if (specify_out_type_ != nullptr) {
  606. ret_value_type = specify_out_type_;
  607. }
  608. AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
  609. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
  610. }
  611. ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
  612. if (!eval_value_) {
  613. return kAnyValue;
  614. } else {
  615. if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) {
  616. MS_EXCEPTION_IF_NULL(arg);
  617. return arg->isa<AnyValue>();
  618. })) {
  619. return kAnyValue;
  620. }
  621. return impl_(args);
  622. }
  623. }
// Primitive implementation.
// File-local helpers: evaluator factories and the getattr/resolve (StaticGetter) lookup logic.
  626. namespace {
  627. EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) {
  628. EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
  629. return prim_evaluator;
  630. }
  631. EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value,
  632. const TypePtr &specify_out_type) {
  633. FunctionPtr func = nullptr;
  634. (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func);
  635. MS_EXCEPTION_IF_NULL(func);
  636. EvaluatorPtr uniform_primitive_evaluator =
  637. std::make_shared<UniformPrimEvaluator>(func, prim_impl, eval_value, specify_out_type);
  638. return uniform_primitive_evaluator;
  639. }
  640. const int64_t kResolveCaseUserDefineClass = 1;
  641. const int64_t kResolveCaseBuiltInType = 2;
  642. const int64_t kResolveCaseFunction = 3;
  643. int64_t GetResolveCase(const TypePtr &data_type) {
  644. MS_EXCEPTION_IF_NULL(data_type);
  645. if (data_type->type_id() == kObjectTypeClass) {
  646. return kResolveCaseUserDefineClass;
  647. }
  648. // try method map, if not in method map, the data_type should be External type.
  649. if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
  650. return kResolveCaseBuiltInType;
  651. }
  652. return kResolveCaseFunction;
  653. }
  654. FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) {
  655. MS_EXCEPTION_IF_NULL(engine);
  656. MS_EXCEPTION_IF_NULL(method);
  657. if (!method->isa<parse::PyObjectWrapper>()) {
  658. MS_LOG(EXCEPTION) << "Method type error: " << method->ToString();
  659. }
  660. std::shared_ptr<PyObjectWrapper> obj = method->cast<std::shared_ptr<PyObjectWrapper>>();
  661. FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj());
  662. if (func_graph == nullptr) {
  663. MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed";
  664. }
  665. FuncGraphManagerPtr manager = engine->func_graph_manager();
  666. manager->AddFuncGraph(func_graph);
  667. return func_graph;
  668. }
  669. inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) {
  670. MS_EXCEPTION_IF_NULL(engine);
  671. FuncGraphManagerPtr manager = engine->func_graph_manager();
  672. manager->AddFuncGraph(func_graph);
  673. }
  674. enum REQUIRE_TYPE { ATTR, METHOD };
  675. EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
  676. REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
  677. MS_EXCEPTION_IF_NULL(old_conf);
  678. AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
  679. AbstractFunctionPtr abs_func = dyn_cast<abstract::AbstractFunction>(abs_ptr);
  680. MS_EXCEPTION_IF_NULL(abs_func);
  681. // Create new cnode
  682. std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
  683. auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abs_func);
  684. if (func_graph_func != nullptr) {
  685. FuncGraphPtr fg = func_graph_func->func_graph();
  686. input.push_back(NewValueNode(fg));
  687. } else {
  688. auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abs_func);
  689. MS_EXCEPTION_IF_NULL(prim_func);
  690. PrimitivePtr prim = prim_func->prim();
  691. input.push_back(NewValueNode(prim));
  692. }
  693. AnfNodeConfigPtr conf = dyn_cast<abstract::AnfNodeConfig>(data_conf);
  694. MS_EXCEPTION_IF_NULL(conf);
  695. input.push_back(conf->node());
  696. MS_EXCEPTION_IF_NULL(old_conf);
  697. FuncGraphPtr func_graph = old_conf->node()->func_graph();
  698. CNodePtr new_cnode = func_graph->NewCNode(input);
  699. if (require_type == REQUIRE_TYPE::ATTR) {
  700. new_cnode = func_graph->NewCNode({new_cnode});
  701. }
  702. AnalysisEnginePtr eng = old_conf->engine();
  703. AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context());
  704. return eng->ForwardConfig(old_conf, fn_conf);
  705. }
  706. EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
  707. const AbstractBasePtrList &args_spec_list,
  708. const AnfNodeConfigPtr &out_conf) {
  709. // args_spec_list: same as StaticGetter
  710. if (args_spec_list.size() < 2) {
  711. MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
  712. }
  713. MS_EXCEPTION_IF_NULL(out_conf);
  714. // An external type.
  715. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  716. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  717. MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
  718. MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
  719. auto data_v = args_spec_list[0]->BuildValue();
  720. if (!data_v->isa<parse::NameSpace>()) {
  721. MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString();
  722. }
  723. auto item_v = args_spec_list[1]->BuildValue();
  724. if (item_v->isa<StringImm>()) {
  725. item_v = std::make_shared<parse::Symbol>(item_v->cast<StringImmPtr>()->value());
  726. }
  727. if (!item_v->isa<parse::Symbol>()) {
  728. MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_v->ToString();
  729. }
  730. // item_name to func addr from obj_map
  731. parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
  732. parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
  733. auto out_node = out_conf->node();
  734. FuncGraphPtr func_graph = out_node->func_graph();
  735. auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_node);
  736. if (new_node == nullptr) {
  737. MS_LOG(EXCEPTION) << "Resolve node failed";
  738. }
  739. // Replace old node with the resolved new node in order list.
  740. func_graph->ReplaceInOrder(out_node, new_node);
  741. AnalysisEnginePtr eng = out_conf->engine();
  742. AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context());
  743. return eng->ForwardConfig(out_conf, fn_conf);
  744. }
  745. EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
  746. const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
  747. const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
  748. if (args_spec_list.empty()) {
  749. MS_LOG(EXCEPTION) << "args_spec_list is empty";
  750. }
  751. AbstractClassPtr cls = CheckArg<AbstractClass>("__FUNC__", args_spec_list, 0);
  752. // If item_v is an attribute, get abstract value from AbstractClass
  753. MS_EXCEPTION_IF_NULL(item_v);
  754. if (!item_v->isa<StringImm>()) {
  755. MS_LOG(EXCEPTION) << "Attribute type error";
  756. }
  757. std::string item_name = item_v->cast<StringImmPtr>()->value();
  758. MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
  759. MS_LOG(DEBUG) << "Resolve item: " << item_name;
  760. AbstractBasePtr attr = cls->GetAttribute(item_name);
  761. if (attr != nullptr) {
  762. return std::make_shared<EvalResult>(attr, nullptr);
  763. }
  764. ValuePtr method = cls->GetMethod(item_name);
  765. if (method->isa<AnyValue>()) {
  766. MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
  767. << ", item value: " << item_v->ToString();
  768. }
  769. // Infer class method
  770. ValuePtr converted_v = PyObjToGraph(engine, method);
  771. return StaticGetterInferred(converted_v, data_conf, out_conf);
  772. }
  773. EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
  774. const TypePtr &data_type, const ConfigPtr &data_conf,
  775. const AnfNodeConfigPtr &out_conf) {
  776. MS_EXCEPTION_IF_NULL(item_v);
  777. MS_EXCEPTION_IF_NULL(data_type);
  778. // The method maybe a Primitive or Composite
  779. if (!item_v->isa<StringImm>()) {
  780. MS_LOG(EXCEPTION) << "Error item is not string";
  781. }
  782. std::string item_name = item_v->cast<StringImmPtr>()->value();
  783. REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
  784. Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
  785. if (require.empty()) {
  786. require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
  787. if (require.empty()) {
  788. MS_LOG(EXCEPTION) << "The object of type: " << data_type->ToString() << " has no method or attr: " << item_name;
  789. }
  790. require_type = REQUIRE_TYPE::ATTR;
  791. }
  792. ValuePtr converted_v = nullptr;
  793. if (require.is<std::string>()) {
  794. // composite registered in standard_method_map go to this branch
  795. converted_v = prim::GetPythonOps(require.cast<std::string>());
  796. if (!converted_v->isa<Primitive>()) {
  797. AddToManager(engine, converted_v->cast<FuncGraphPtr>());
  798. }
  799. } else if (require.is<PrimitivePtr>()) {
  800. converted_v = require.cast<PrimitivePtr>();
  801. } else {
  802. MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
  803. }
  804. return StaticGetterInferred(converted_v, data_conf, out_conf, require_type);
  805. }
  806. EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  807. const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
  808. // Inputs: namespace and its static function; or class and its member function
  809. CheckArgsSize("StaticGetter", args_spec_list, 2);
  810. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  811. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  812. TypePtr data_type = args_spec_list[0]->BuildType();
  813. ValuePtr item_value = args_spec_list[1]->BuildValue();
  814. ScopePtr scope = kDefaultScope;
  815. if (out_conf != nullptr) {
  816. scope = out_conf->node()->scope();
  817. }
  818. ScopeGuard scope_guard(scope);
  819. if (item_value->isa<AnyValue>()) {
  820. MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
  821. }
  822. int64_t case_v = GetResolveCase(data_type);
  823. if (case_v == kResolveCaseUserDefineClass) {
  824. return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
  825. } else if (case_v == kResolveCaseBuiltInType) {
  826. return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
  827. } else {
  828. return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
  829. }
  830. }
  831. } // end anonymous namespace
// File-local evaluator classes and the primitive-to-evaluator registry.
  833. namespace {
  834. class EmbedEvaluator : public SymbolicPrimEvaluator {
  835. public:
  836. EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
  837. ~EmbedEvaluator() override = default;
  838. MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
  839. EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
  840. // arg: free variable to be embedded
  841. if (args_conf_list.size() != 1) {
  842. MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
  843. }
  844. AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
  845. MS_EXCEPTION_IF_NULL(node_conf);
  846. AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract();
  847. x = SensitivityTransform(x);
  848. SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
  849. AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
  850. return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
  851. }
  852. };
  853. static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
  854. auto root_g_set = manager->roots();
  855. if (root_g_set.size() != 1) {
  856. return nullptr;
  857. }
  858. const FuncGraphPtr &root_g = root_g_set.back();
  859. for (auto &param_node : root_g->parameters()) {
  860. auto param = param_node->cast<ParameterPtr>();
  861. if (param && name == param->name()) {
  862. return param;
  863. }
  864. }
  865. return nullptr;
  866. }
  867. class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
  868. public:
  869. RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
  870. ~RefToEmbedEvaluator() override = default;
  871. MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
  872. EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
  873. if (args_conf_list.size() != 1) {
  874. MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
  875. return nullptr;
  876. }
  877. static TypePtr type = std::make_shared<SymbolicKeyType>();
  878. auto node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
  879. if (node_conf == nullptr) {
  880. MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
  881. return nullptr;
  882. }
  883. AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
  884. AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
  885. if (ref_abs == nullptr) {
  886. MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
  887. return nullptr;
  888. }
  889. auto key_abs = ref_abs->ref_key();
  890. if (key_abs == nullptr) {
  891. MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr.";
  892. return nullptr;
  893. }
  894. auto key_value = key_abs->BuildValue();
  895. if (key_value == nullptr) {
  896. MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
  897. return nullptr;
  898. }
  899. auto refkey = key_value->cast<RefKeyPtr>();
  900. if (refkey == nullptr) {
  901. auto ret = std::make_shared<AbstractScalar>(type);
  902. auto ref_value = ref_abs->ref();
  903. MS_EXCEPTION_IF_NULL(ref_value);
  904. return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  905. }
  906. std::string name = refkey->tag();
  907. const auto &manager = node_conf->node()->func_graph()->manager();
  908. auto node = FindParameterNodeByString(manager, name);
  909. if (node == nullptr) {
  910. MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
  911. return nullptr;
  912. }
  913. AbstractBasePtr x = ref_abs->ref();
  914. x = SensitivityTransform(x);
  915. std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
  916. std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
  917. return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
  918. }
  919. };
  920. class GetAttrEvaluator : public TransitionPrimEvaluator {
  921. public:
  922. GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
  923. ~GetAttrEvaluator() override = default;
  924. MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
  925. EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  926. const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
  927. auto ret_abstract = AbstractEval(args_spec_list);
  928. if (ret_abstract != nullptr) {
  929. MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
  930. return ret_abstract;
  931. }
  932. // Inputs: data, item
  933. if (args_spec_list.size() != 2) {
  934. MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
  935. }
  936. EvalResultPtr ret = nullptr;
  937. if (bound_node() != nullptr) {
  938. TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
  939. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  940. } else {
  941. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  942. }
  943. // don't lookup from cache, as different out_conf with same node but different context
  944. // may add different entry to anfnode_config_map, like getattr primitive;
  945. (*cache_)[args_spec_list] = ret;
  946. return ret;
  947. }
  948. };
  949. class ResolveEvaluator : public TransitionPrimEvaluator {
  950. public:
  951. ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
  952. ~ResolveEvaluator() override = default;
  953. MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
  954. EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  955. const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
  956. // Inputs: namespace, symbol
  957. if (args_spec_list.size() != 2) {
  958. MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
  959. }
  960. EvalResultPtr ret = nullptr;
  961. if (bound_node() != nullptr) {
  962. TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
  963. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  964. } else {
  965. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  966. }
  967. return ret;
  968. }
  969. };
  970. class CreateInstanceEvaluator : public TransitionPrimEvaluator {
  971. public:
  972. CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
  973. ~CreateInstanceEvaluator() override = default;
  974. MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
  975. EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
  976. const AnfNodeConfigPtr &out_conf) override {
  977. if (args_spec_list.empty()) {
  978. MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
  979. }
  980. // get the type parameter
  981. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  982. TypePtr type = args_spec_list[0]->GetTypeTrack();
  983. if (type->type_id() != kMetaTypeTypeType) {
  984. MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
  985. << type->ToString();
  986. }
  987. ValuePtr value_track = args_spec_list[0]->GetValueTrack();
  988. MS_EXCEPTION_IF_NULL(value_track);
  989. std::shared_ptr<parse::PyObjectWrapper> type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
  990. if (type_obj == nullptr) {
  991. MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
  992. }
  993. if (!type_obj->isa<parse::ClassType>()) {
  994. MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got "
  995. << type_obj->ToString() << ".";
  996. }
  997. auto class_type = type_obj->obj();
  998. MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
  999. // get the create instance obj's parameters
  1000. pybind11::tuple params = GetParameters(args_spec_list);
  1001. // create class instance
  1002. auto obj = parse::data_converter::CreatePythonObject(class_type, params);
  1003. if (py::isinstance<py::none>(obj)) {
  1004. MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type)
  1005. << " failed, only support create Cell or Primitive object.";
  1006. }
  1007. // process the object
  1008. ValuePtr converted_ret = nullptr;
  1009. bool converted = parse::ConvertData(obj, &converted_ret, true);
  1010. if (!converted) {
  1011. MS_LOG(EXCEPTION) << "Convert the python object failed";
  1012. }
  1013. MS_EXCEPTION_IF_NULL(converted_ret);
  1014. if (converted_ret->isa<FuncGraph>()) {
  1015. AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
  1016. }
  1017. AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
  1018. auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  1019. (*cache_)[args_spec_list] = infer_result;
  1020. return infer_result;
  1021. }
  1022. pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
  1023. // Exclude class type by minus 1;
  1024. std::size_t params_size = args_spec_list.size() - 1;
  1025. auto params = py::tuple(params_size);
  1026. if (params_size > 0) {
  1027. for (size_t i = 0; i < params_size; i++) {
  1028. // Only support the Scalar parameters type. Bypass class type by offset with 1.
  1029. auto arg = args_spec_list[i + 1];
  1030. MS_EXCEPTION_IF_NULL(arg);
  1031. // Because the Tensor's AbstractTensor can't get value from GetValueTrack.
  1032. ValuePtr param_value = arg->BuildValue();
  1033. py::object param = ValuePtrToPyData(param_value);
  1034. params[i] = param;
  1035. }
  1036. }
  1037. return params;
  1038. }
  1039. };
  1040. class PartialEvaluator : public Evaluator {
  1041. public:
  1042. PartialEvaluator() : Evaluator("PartialEvaluator") {}
  1043. ~PartialEvaluator() override = default;
  1044. EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  1045. AnfNodeConfigPtr out_conf = nullptr) override {
  1046. if (args_conf_list.size() == 0) {
  1047. MS_LOG(EXCEPTION) << "Args size should be greater than 0";
  1048. }
  1049. MS_EXCEPTION_IF_NULL(out_conf);
  1050. MS_EXCEPTION_IF_NULL(out_conf->node());
  1051. auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract();
  1052. AbstractBasePtrList args_spec_list{arg0_value};
  1053. // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
  1054. if (arg0_value->isa<AbstractError>()) {
  1055. auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
  1056. MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
  1057. << " as func is: " << arg0_value->ToString();
  1058. auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  1059. (*cache_)[args_spec_list] = eval_result;
  1060. return eval_result;
  1061. }
  1062. auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
  1063. // Sometimes, node[0] in out_conf becomes phi0;
  1064. if (func->isa<PrimitiveAbstractClosure>()) {
  1065. auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
  1066. if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
  1067. prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(prim_func->prim());
  1068. return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
  1069. }
  1070. }
  1071. std::vector<EvalResultPtr> eval_result_list;
  1072. (void)std::transform(args_conf_list.cbegin() + 1, args_conf_list.cend(), std::back_inserter(eval_result_list),
  1073. [](const ConfigPtr &config) -> EvalResultPtr { return config->GetEvaluatedValue(); });
  1074. (void)std::transform(eval_result_list.cbegin(), eval_result_list.cend(), std::back_inserter(args_spec_list),
  1075. [](const EvalResultPtr &eval_result) -> AbstractBasePtr { return eval_result->abstract(); });
  1076. AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
  1077. auto cnode = out_conf->node()->cast<CNodePtr>();
  1078. MS_EXCEPTION_IF_NULL(cnode);
  1079. if (cnode->size() != (args_conf_list.size() + 1)) {
  1080. MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString()
  1081. << ", args_conf_list: " << mindspore::ToString(args_conf_list);
  1082. }
  1083. auto flag = std::any_of(eval_result_list.cbegin(), eval_result_list.cend(), [](const EvalResultPtr &eval_result) {
  1084. MS_LOG(DEBUG) << "Propagate isolate nodes flag from: " << eval_result->abstract()->ToString()
  1085. << ", flag: " << eval_result->HasIsolateNodesPropagateCNodeFlag();
  1086. return eval_result->HasIsolateNodesPropagateCNodeFlag();
  1087. });
  1088. AbstractFuncAtomPtrList partial_funcs_list;
  1089. auto build_partial = [args, cnode, flag, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
  1090. auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
  1091. partial_funcs_list.push_back(new_func);
  1092. if (atom_func->HasIsolateNodesFlag() || flag) {
  1093. new_func->SetIsolateNodesFlag(true);
  1094. }
  1095. };
  1096. func->Visit(build_partial);
  1097. auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
  1098. auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  1099. (*cache_)[args_spec_list] = eval_result;
  1100. return eval_result;
  1101. }
  1102. EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  1103. MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
  1104. }
  1105. EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
  1106. const AnfNodeConfigPtr &out_conf = nullptr) const {
  1107. MS_EXCEPTION_IF_NULL(out_conf);
  1108. MS_EXCEPTION_IF_NULL(out_conf->node());
  1109. auto cnode = out_conf->node()->cast<CNodePtr>();
  1110. if (cnode == nullptr) {
  1111. MS_LOG(EXCEPTION) << "Cnode is nullptr";
  1112. }
  1113. std::vector<AnfNodePtr> new_nodes_inputs = cnode->inputs();
  1114. auto new_signature_value = std::make_shared<prim::DoSignatureMetaFuncGraph>("signature", signature_value);
  1115. new_nodes_inputs[1] = NewValueNode(new_signature_value);
  1116. FuncGraphPtr func_graph = cnode->func_graph();
  1117. ScopePtr scope = out_conf->node()->scope();
  1118. ScopeGuard scope_guard(scope);
  1119. CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
  1120. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
  1121. return engine->ForwardConfig(out_conf, fn_conf);
  1122. }
  1123. };
  1124. struct PrimitiveImplInferValue {
  1125. PrimitiveImpl impl_; // implement function of primitive
  1126. bool eval_value_; // whether evaluate value
  1127. TypePtr specify_out_type_; // whether specify return type
  1128. bool in_white_list_; // true if this Primitive in white list, else false.
  1129. };
  1130. using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
  1131. PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
  1132. static PrimitiveToImplMap uniform_prim_implement_map = {
  1133. {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
  1134. {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
  1135. {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
  1136. {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
  1137. {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
  1138. {prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}},
  1139. {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}},
  1140. {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
  1141. {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
  1142. {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
  1143. {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared<Bool>(), true}},
  1144. {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared<Bool>(), true}},
  1145. {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared<Bool>(), true}},
  1146. {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared<Bool>(), true}},
  1147. {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared<Bool>(), true}},
  1148. {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared<Bool>(), true}},
  1149. {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared<Bool>(), true}},
  1150. {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
  1151. {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
  1152. {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
  1153. };
  1154. return uniform_prim_implement_map;
  1155. }
  1156. PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
  1157. std::mutex PrimEvaluatorConstructorMutex;
  1158. void InitPrimEvaluatorConstructors() {
  1159. PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
  1160. for (const auto &iter : GetPrimitiveToEvalImplMap()) {
  1161. constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_);
  1162. }
  1163. for (const auto &iter : GetUniformPrimitiveToImplMap()) {
  1164. constructor[iter.first] =
  1165. InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
  1166. }
  1167. constructor[prim::kPrimEmbed] = std::make_shared<EmbedEvaluator>();
  1168. constructor[prim::kPrimRefToEmbed] = std::make_shared<RefToEmbedEvaluator>();
  1169. constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
  1170. constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
  1171. constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
  1172. constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
  1173. }
  1174. } // namespace
  1175. void ClearPrimEvaluatorMap() {
  1176. PrimEvaluatorConstructors.clear();
  1177. GetPrimitiveToEvalImplMap().clear();
  1178. GetUniformPrimitiveToImplMap().clear();
  1179. }
  1180. bool IsInWhiteList(const PrimitivePtr &primitive) {
  1181. MS_EXCEPTION_IF_NULL(primitive);
  1182. auto iter = GetPrimitiveToEvalImplMap().find(primitive);
  1183. if (iter != GetPrimitiveToEvalImplMap().end()) {
  1184. return iter->second.in_white_list_;
  1185. }
  1186. auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
  1187. if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
  1188. return uni_iter->second.in_white_list_;
  1189. }
  1190. return false;
  1191. }
  1192. StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
  1193. MS_EXCEPTION_IF_NULL(primitive);
  1194. auto iter = GetPrimitiveToEvalImplMap().find(primitive);
  1195. if (iter == GetPrimitiveToEvalImplMap().end()) {
  1196. return nullptr;
  1197. }
  1198. return iter->second.impl_;
  1199. }
  1200. PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
  1201. PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
  1202. if (!constructor.empty()) {
  1203. return constructor;
  1204. }
  1205. std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
  1206. if (constructor.empty()) {
  1207. InitPrimEvaluatorConstructors();
  1208. }
  1209. return constructor;
  1210. }
  1211. namespace {
  1212. bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
  1213. MS_EXCEPTION_IF_NULL(x);
  1214. MS_EXCEPTION_IF_NULL(model);
  1215. auto x_tuple = dyn_cast<AbstractTuple>(x);
  1216. auto model_tuple = dyn_cast<Tuple>(model);
  1217. if (x_tuple == nullptr || model_tuple == nullptr) {
  1218. return false;
  1219. }
  1220. if (model->IsGeneric()) {
  1221. return true;
  1222. }
  1223. if (x_tuple->size() != model_tuple->size()) {
  1224. return false;
  1225. }
  1226. for (size_t i = 0; i < x_tuple->size(); i++) {
  1227. bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]);
  1228. if (!is_subtype) {
  1229. return false;
  1230. }
  1231. }
  1232. return true;
  1233. }
  1234. bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
  1235. MS_EXCEPTION_IF_NULL(x);
  1236. MS_EXCEPTION_IF_NULL(model);
  1237. auto x_tensor = dyn_cast<AbstractTensor>(x);
  1238. auto model_tensor = dyn_cast<TensorType>(model);
  1239. if (x_tensor == nullptr || model_tensor == nullptr) {
  1240. return false;
  1241. }
  1242. if (model->IsGeneric()) {
  1243. return true;
  1244. }
  1245. return IsSubtype(x_tensor->element(), model_tensor->element());
  1246. }
  1247. bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
  1248. MS_EXCEPTION_IF_NULL(x);
  1249. MS_EXCEPTION_IF_NULL(model);
  1250. auto x_list = dyn_cast<AbstractList>(x);
  1251. auto model_list = dyn_cast<List>(model);
  1252. if (x_list == nullptr || model_list == nullptr) {
  1253. return false;
  1254. }
  1255. if (model->IsGeneric()) {
  1256. return true;
  1257. }
  1258. if (x_list->size() != model_list->size()) {
  1259. return false;
  1260. }
  1261. bool is_subtype = true;
  1262. for (size_t i = 0; i < x_list->size(); i++) {
  1263. is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]);
  1264. if (!is_subtype) {
  1265. return false;
  1266. }
  1267. }
  1268. return is_subtype;
  1269. }
  1270. bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) {
  1271. MS_EXCEPTION_IF_NULL(x);
  1272. MS_EXCEPTION_IF_NULL(model);
  1273. auto x_class = dyn_cast<AbstractClass>(x);
  1274. auto model_class = dyn_cast<Class>(model);
  1275. if (x_class == nullptr) {
  1276. return false;
  1277. }
  1278. if (model->IsGeneric()) {
  1279. return true;
  1280. }
  1281. if (x_class->tag() == model_class->tag()) {
  1282. auto m_attributes = model_class->GetAttributes();
  1283. auto x_attributes = x_class->attributes();
  1284. if (m_attributes.size() != x_attributes.size()) {
  1285. return false;
  1286. }
  1287. for (size_t i = 0; i < m_attributes.size(); i++) {
  1288. if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) {
  1289. return false;
  1290. }
  1291. }
  1292. return true;
  1293. }
  1294. return false;
  1295. }
  1296. inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
  1297. MS_EXCEPTION_IF_NULL(x);
  1298. MS_EXCEPTION_IF_NULL(model);
  1299. if (dyn_cast<AbstractScalar>(x) == nullptr) {
  1300. return false;
  1301. }
  1302. TypePtr x_type = x->GetTypeTrack();
  1303. return IsSubType(x_type, model);
  1304. }
  1305. } // namespace
  1306. bool IsSubtype(const AbstractBasePtr x, const TypePtr model) {
  1307. MS_EXCEPTION_IF_NULL(x);
  1308. MS_EXCEPTION_IF_NULL(model);
  1309. TypeId model_typeid = model->type_id();
  1310. switch (model_typeid) {
  1311. case kMetaTypeObject:
  1312. return true;
  1313. case kObjectTypeTuple:
  1314. return IsSubtypeTuple(x, model);
  1315. case kObjectTypeTensorType:
  1316. return IsSubtypeArray(x, model);
  1317. case kObjectTypeList:
  1318. return IsSubtypeList(x, model);
  1319. case kObjectTypeClass:
  1320. return IsSubtypeClass(x, model);
  1321. default:
  1322. if (IsSubType(model, std::make_shared<Number>())) {
  1323. return IsSubtypeScalar(x, model);
  1324. }
  1325. MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << ".";
  1326. }
  1327. }
  1328. } // namespace abstract
  1329. } // namespace mindspore