You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim.cc 72 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/static_analysis/prim.h"
  19. #include <algorithm>
  20. #include <limits>
  21. #include <mutex>
  22. #include <string>
  23. #include <utility>
  24. #include "utils/hash_set.h"
  25. #include "frontend/operator/cc_implementations.h"
  26. #include "frontend/operator/ops.h"
  27. #include "frontend/operator/composite/do_signature.h"
  28. #include "frontend/operator/prim_to_function.h"
  29. #include "abstract/utils.h"
  30. #include "utils/symbolic.h"
  31. #include "pipeline/jit/resource.h"
  32. #include "pipeline/jit/parse/resolve.h"
  33. #include "pipeline/jit/pipeline.h"
  34. #include "utils/convert_utils.h"
  35. #include "utils/convert_utils_py.h"
  36. #include "utils/ms_context.h"
  37. #include "pipeline/jit/parse/data_converter.h"
  38. #include "abstract/primitive_infer_map.h"
  39. #include "abstract/param_validator.h"
  40. #include "utils/ms_utils.h"
  41. #include "utils/shape_utils.h"
  42. #include "utils/parallel_node_check.h"
  43. #include "frontend/operator/ops_front_infer_function.h"
  44. namespace mindspore {
  45. namespace abstract {
  46. using mindspore::parse::PyObjectWrapper;
  47. mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{
  48. "MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
  49. EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  50. const AnfNodeConfigPtr &out_conf) {
  51. MS_EXCEPTION_IF_NULL(engine);
  52. MS_EXCEPTION_IF_NULL(out_conf);
  53. AbstractBasePtrList args_spec_list;
  54. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  55. [](const ConfigPtr &ref) -> AbstractBasePtr {
  56. MS_EXCEPTION_IF_NULL(ref);
  57. MS_EXCEPTION_IF_NULL(ref->ObtainEvalResult());
  58. return ref->ObtainEvalResult()->abstract();
  59. });
  60. auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
  61. MS_EXCEPTION_IF_NULL(do_signature);
  62. auto &func = do_signature->function();
  63. if (func->isa<Primitive>()) {
  64. auto sig_prim = func->cast<PrimitivePtr>();
  65. if (prims_to_skip_undetermined_infer.find(sig_prim->name()) == prims_to_skip_undetermined_infer.end()) {
  66. auto ret_abstract = AbstractEval(args_spec_list);
  67. if (ret_abstract != nullptr) {
  68. MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined";
  69. return ret_abstract;
  70. }
  71. }
  72. }
  73. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  74. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  75. }
  76. auto out_node = dyn_cast<CNode>(out_conf->node());
  77. MS_EXCEPTION_IF_NULL(out_node);
  78. const auto &out_node_inputs = out_node->inputs();
  79. if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  80. MS_LOG(EXCEPTION) << "Op: " << func->ToString() << " args size should equal to inputs size minus 1, but args size "
  81. << args_conf_list.size() << ", inputs size " << out_node_inputs.size();
  82. }
  83. AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
  84. ScopePtr scope = kDefaultScope;
  85. if (out_conf != nullptr) {
  86. scope = out_conf->node()->scope();
  87. }
  88. ScopeGuard scope_guard(scope);
  89. AnfNodePtr new_node = nullptr;
  90. if (bound_node() != nullptr) {
  91. TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
  92. new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
  93. } else {
  94. new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
  95. }
  96. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
  97. if (out_node->isa<CNode>()) {
  98. auto out_cnode = out_node->cast<CNodePtr>();
  99. auto new_cnode = new_node->cast<CNodePtr>();
  100. new_cnode->CloneCNodeInfo(out_cnode);
  101. }
  102. return engine->ForwardConfig(out_conf, fn_conf);
  103. }
  104. static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
  105. // arg[0] is the func graph to unpack, ignore it
  106. AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
  107. AbstractBasePtrList graph_specialize_args;
  108. if (need_unpack) {
  109. for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
  110. MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
  111. if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
  112. auto arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
  113. std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
  114. std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
  115. } else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
  116. auto arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
  117. auto dict_elems = arg_dict->elements();
  118. (void)std::transform(
  119. dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
  120. [](const AbstractAttribute &item) { return std::make_shared<AbstractKeywordArg>(item.first, item.second); });
  121. } else {
  122. MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got "
  123. << specialize_args_before_unpack[index]->ToString();
  124. }
  125. }
  126. } else {
  127. graph_specialize_args = specialize_args_before_unpack;
  128. }
  129. return graph_specialize_args;
  130. }
  131. EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  132. const AnfNodeConfigPtr &out_conf) {
  133. MS_EXCEPTION_IF_NULL(engine);
  134. MS_EXCEPTION_IF_NULL(out_conf);
  135. MS_EXCEPTION_IF_NULL(out_conf->node());
  136. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  137. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  138. }
  139. auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
  140. MS_EXCEPTION_IF_NULL(unpack_graph);
  141. auto out_node = out_conf->node()->cast<CNodePtr>();
  142. MS_EXCEPTION_IF_NULL(out_node);
  143. const auto &out_node_inputs = out_node->inputs();
  144. if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  145. MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
  146. << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
  147. << ", inputs size " << out_node_inputs.size();
  148. }
  149. AbstractBasePtrList args_spec_list;
  150. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  151. [](const ConfigPtr &ref) -> AbstractBasePtr {
  152. MS_EXCEPTION_IF_NULL(ref);
  153. MS_EXCEPTION_IF_NULL(ref->ObtainEvalResult());
  154. return ref->ObtainEvalResult()->abstract();
  155. });
  156. // get the forward graph
  157. if (args_spec_list.empty()) {
  158. MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
  159. }
  160. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  161. auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
  162. if (fn == nullptr) {
  163. MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
  164. }
  165. auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
  166. MS_EXCEPTION_IF_NULL(real_fn);
  167. FuncGraphPtr forward_graph = real_fn->func_graph();
  168. MS_EXCEPTION_IF_NULL(forward_graph);
  169. AbstractBasePtrList graph_specialize_args =
  170. GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
  171. AbstractBasePtrList graph_specialize_args_without_sens;
  172. if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
  173. MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
  174. }
  175. (void)std::transform(graph_specialize_args.begin(),
  176. graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
  177. std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
  178. auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens);
  179. engine->func_graph_manager()->AddFuncGraph(new_graph);
  180. ScopePtr scope = kDefaultScope;
  181. if (out_conf != nullptr) {
  182. scope = out_conf->node()->scope();
  183. }
  184. ScopeGuard scope_guard(scope);
  185. AnfNodePtr new_vnode = NewValueNode(new_graph);
  186. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context(), out_conf->func_graph());
  187. return engine->ForwardConfig(out_conf, fn_conf);
  188. }
  189. AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const AbstractBasePtr &node_type,
  190. const AnfNodePtr &target_type, const FuncGraphPtr &func_graph) {
  191. MS_EXCEPTION_IF_NULL(node_type);
  192. MS_EXCEPTION_IF_NULL(func_graph);
  193. AnfNodePtr target_node = source_node;
  194. if (node_type->isa<AbstractTensor>()) {
  195. auto x = node_type->cast<AbstractTensorPtr>();
  196. if (x->element()->BuildType()->isa<Float>()) {
  197. auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
  198. MS_EXCEPTION_IF_NULL(cast);
  199. target_node = func_graph->NewCNodeAfter(source_node, {NewValueNode(cast), source_node, target_type});
  200. }
  201. } else if (node_type->isa<AbstractTuple>()) {
  202. auto x = node_type->cast<AbstractTuplePtr>();
  203. auto &items = x->elements();
  204. std::vector<AnfNodePtr> nodes;
  205. nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  206. int64_t idx = 0;
  207. for (const auto &item : items) {
  208. AnfNodePtr tuple_node =
  209. func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)});
  210. AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, item, target_type, func_graph);
  211. nodes.emplace_back(node);
  212. ++idx;
  213. }
  214. target_node = func_graph->NewCNode(nodes);
  215. } else if (node_type->isa<AbstractDictionary>()) {
  216. auto x = node_type->cast<AbstractDictionaryPtr>();
  217. auto &items = x->elements();
  218. std::vector<AnfNodePtr> dict_key_nodes;
  219. std::vector<AnfNodePtr> dict_value_nodes;
  220. dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  221. dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  222. for (const auto &item : items) {
  223. AnfNodePtr dict_value_node =
  224. func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)});
  225. AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph);
  226. dict_key_nodes.emplace_back(NewValueNode(item.first));
  227. dict_value_nodes.emplace_back(node);
  228. }
  229. target_node =
  230. func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(std::move(dict_key_nodes)),
  231. func_graph->NewCNode(std::move(dict_value_nodes))});
  232. } else if (node_type->isa<AbstractKeywordArg>()) {
  233. auto x = node_type->cast<AbstractKeywordArgPtr>();
  234. std::string kwarg_key = x->get_key();
  235. AnfNodePtr kwarg_value_node =
  236. func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node});
  237. AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph);
  238. target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node});
  239. }
  240. return target_node;
  241. }
  242. EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  243. const AnfNodeConfigPtr &out_conf) {
  244. MS_EXCEPTION_IF_NULL(engine);
  245. AbstractBasePtrList args_spec_list;
  246. MS_EXCEPTION_IF_NULL(out_conf);
  247. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  248. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  249. }
  250. auto out_node = out_conf->node()->cast<CNodePtr>();
  251. MS_EXCEPTION_IF_NULL(out_node);
  252. const auto &out_node_inputs = out_node->inputs();
  253. if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  254. MS_LOG(EXCEPTION) << "MixedPrecisionCast"
  255. << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
  256. << ", inputs size " << out_node_inputs.size();
  257. }
  258. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  259. [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
  260. ScopeGuard scope_guard(out_conf->node()->scope());
  261. TraceGuard trace_guard(std::make_shared<TraceMixedPrecision>(out_conf->node()->debug_info()));
  262. FuncGraphPtr func_graph = out_node->func_graph();
  263. constexpr size_t source_node_index = 2;
  264. if (out_node_inputs.size() <= source_node_index) {
  265. MS_LOG(EXCEPTION) << "Input size:" << out_node_inputs.size() << " should bigger than 2.";
  266. }
  267. AnfNodePtr new_node =
  268. MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1], out_node_inputs[1], func_graph);
  269. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
  270. if (new_node->isa<CNode>()) {
  271. auto new_cnode = new_node->cast<CNodePtr>();
  272. new_cnode->CloneCNodeInfo(out_node);
  273. }
  274. return engine->ForwardConfig(out_conf, fn_conf);
  275. }
  276. namespace {
  277. py::object BuildValue(const ValuePtr &value_ptr) {
  278. if (value_ptr == nullptr) {
  279. return py::none();
  280. } else {
  281. return ValueToPyData(value_ptr);
  282. }
  283. }
  284. py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
  285. auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
  286. MS_EXCEPTION_IF_NULL(arg_tuple);
  287. size_t len = arg_tuple->size();
  288. py::tuple shape_tuple(len);
  289. py::tuple dtype_tuple(len);
  290. py::tuple value_tuple(len);
  291. py::tuple min_value_tuple(len);
  292. py::tuple max_value_tuple(len);
  293. py::tuple min_shape_tuple(len);
  294. py::tuple max_shape_tuple(len);
  295. bool dyn_shape = false;
  296. bool dyn_value = false;
  297. for (size_t i = 0; i < len; i++) {
  298. auto arg = arg_tuple->elements()[i];
  299. py::dict out = ConvertAbstractToPython(arg);
  300. shape_tuple[i] = out[ATTR_SHAPE];
  301. dtype_tuple[i] = out[ATTR_DTYPE];
  302. value_tuple[i] = out[ATTR_VALUE];
  303. // Elements in tuple is tensor shape value.
  304. if (out.contains(py::str(ATTR_MIN_VALUE)) && out.contains(py::str(ATTR_MAX_VALUE))) {
  305. min_value_tuple[i] = out[ATTR_MIN_VALUE];
  306. max_value_tuple[i] = out[ATTR_MAX_VALUE];
  307. dyn_value = true;
  308. }
  309. // Elements in tuple is tensor, which shape is dynamic.
  310. if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
  311. min_shape_tuple[i] = out[ATTR_MIN_SHAPE];
  312. max_shape_tuple[i] = out[ATTR_MAX_SHAPE];
  313. dyn_shape = true;
  314. }
  315. }
  316. auto dic = py::dict();
  317. dic[ATTR_SHAPE] = shape_tuple;
  318. dic[ATTR_DTYPE] = dtype_tuple;
  319. MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
  320. if (arg_tuple->BuildValue()->isa<AnyValue>()) {
  321. dic[ATTR_VALUE] = py::none();
  322. } else {
  323. dic[ATTR_VALUE] = value_tuple;
  324. }
  325. if (dyn_value) {
  326. dic[ATTR_MIN_VALUE] = min_value_tuple;
  327. dic[ATTR_MAX_VALUE] = max_value_tuple;
  328. }
  329. if (dyn_shape) {
  330. dic[ATTR_MIN_SHAPE] = min_shape_tuple;
  331. dic[ATTR_MAX_SHAPE] = max_shape_tuple;
  332. }
  333. return dic;
  334. }
  335. py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
  336. auto arg_list = dyn_cast<AbstractList>(abs_base);
  337. MS_EXCEPTION_IF_NULL(arg_list);
  338. size_t len = arg_list->size();
  339. py::list shape_list(len);
  340. py::list dtype_list(len);
  341. py::list value_list(len);
  342. py::list min_shape_list(len);
  343. py::list max_shape_list(len);
  344. bool dyn_shape = false;
  345. for (size_t i = 0; i < len; i++) {
  346. py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
  347. shape_list[i] = out[ATTR_SHAPE];
  348. dtype_list[i] = out[ATTR_DTYPE];
  349. value_list[i] = out[ATTR_VALUE];
  350. // Elements in list is tensor, which shape is dynamic.
  351. if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
  352. min_shape_list[i] = out[ATTR_MIN_SHAPE];
  353. max_shape_list[i] = out[ATTR_MAX_SHAPE];
  354. dyn_shape = true;
  355. }
  356. }
  357. auto dic = py::dict();
  358. dic[ATTR_SHAPE] = shape_list;
  359. dic[ATTR_DTYPE] = dtype_list;
  360. MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
  361. if (arg_list->BuildValue()->isa<AnyValue>()) {
  362. dic[ATTR_VALUE] = py::none();
  363. } else {
  364. dic[ATTR_VALUE] = value_list;
  365. }
  366. if (dyn_shape) {
  367. dic[ATTR_MIN_SHAPE] = min_shape_list;
  368. dic[ATTR_MAX_SHAPE] = max_shape_list;
  369. }
  370. return dic;
  371. }
  372. void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
  373. auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
  374. MS_EXCEPTION_IF_NULL(dic);
  375. MS_EXCEPTION_IF_NULL(arg_tensor);
  376. MS_EXCEPTION_IF_NULL(arg_tensor->shape());
  377. (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
  378. const auto &min_shape = arg_tensor->shape()->min_shape();
  379. const auto &max_shape = arg_tensor->shape()->max_shape();
  380. if (!min_shape.empty() && !max_shape.empty()) {
  381. (*dic)[ATTR_MIN_SHAPE] = min_shape;
  382. (*dic)[ATTR_MAX_SHAPE] = max_shape;
  383. }
  384. auto min_value = arg_tensor->get_min_value();
  385. auto max_value = arg_tensor->get_max_value();
  386. if (min_value != nullptr && max_value != nullptr) {
  387. (*dic)[ATTR_MIN_VALUE] = BuildValue(min_value);
  388. (*dic)[ATTR_MAX_VALUE] = BuildValue(max_value);
  389. }
  390. (*dic)[ATTR_DTYPE] = arg_tensor->BuildType();
  391. (*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
  392. }
  393. void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
  394. MS_EXCEPTION_IF_NULL(dic);
  395. MS_EXCEPTION_IF_NULL(abs_base);
  396. (*dic)[ATTR_SHAPE] = py::none();
  397. (*dic)[ATTR_DTYPE] = abs_base->BuildType();
  398. (*dic)[ATTR_VALUE] = py::none();
  399. if (abs_base->isa<PartialAbstractClosure>()) {
  400. AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
  401. if (!args.empty()) {
  402. MS_EXCEPTION_IF_NULL(args[0]->BuildValue());
  403. auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
  404. if (value != nullptr) {
  405. (*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
  406. (*dic)[ATTR_VALUE] = value->obj();
  407. }
  408. }
  409. }
  410. }
  411. bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
  412. // As x and predicate both are mindspore type statically, here we only to judge whether
  413. // x is predicate or is a subclass of predicate.
  414. return IsIdentidityOrSubclass(x, expected_type);
  415. }
  416. // Join all types in args_type_list;
  417. TypePtr TypeJoin(const TypePtrList &args_type_list) {
  418. if (args_type_list.empty()) {
  419. MS_LOG(EXCEPTION) << "args_type_list is empty";
  420. }
  421. TypePtr type_tmp = args_type_list[0];
  422. for (std::size_t i = 1; i < args_type_list.size(); i++) {
  423. type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
  424. }
  425. return type_tmp;
  426. }
  427. TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
  428. MS_EXCEPTION_IF_NULL(predicate);
  429. for (const auto &arg_type : args_type_list) {
  430. MS_EXCEPTION_IF_NULL(arg_type);
  431. if (!CheckType(predicate, arg_type)) {
  432. MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
  433. }
  434. }
  435. return TypeJoin(args_type_list);
  436. }
  437. } // end anonymous namespace
  438. py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
  439. MS_EXCEPTION_IF_NULL(abs_base);
  440. auto dic = py::dict();
  441. if (abs_base->isa<AbstractTensor>()) {
  442. ConvertAbstractTensorToPython(abs_base, &dic);
  443. } else if (abs_base->isa<AbstractRowTensor>()) {
  444. auto arg = dyn_cast<AbstractRowTensor>(abs_base);
  445. dic[ATTR_SHAPE] = arg->shape()->shape();
  446. dic[ATTR_DTYPE] = arg->BuildType();
  447. dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  448. } else if (abs_base->isa<AbstractSparseTensor>()) {
  449. auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
  450. dic[ATTR_SHAPE] = arg->shape()->shape();
  451. dic[ATTR_DTYPE] = arg->BuildType();
  452. dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  453. } else if (abs_base->isa<AbstractCSRTensor>()) {
  454. auto arg = dyn_cast<AbstractCSRTensor>(abs_base);
  455. dic[ATTR_SHAPE] = arg->shape()->shape();
  456. dic[ATTR_DTYPE] = arg->BuildType();
  457. dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  458. } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
  459. ShapeVector shape;
  460. dic[ATTR_SHAPE] = shape;
  461. dic[ATTR_DTYPE] = abs_base->BuildType();
  462. dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
  463. } else if (abs_base->isa<AbstractSlice>()) {
  464. auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
  465. ShapeVector shape;
  466. dic[ATTR_SHAPE] = shape;
  467. dic[ATTR_DTYPE] = arg_slice->BuildType();
  468. dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
  469. } else if (abs_base->isa<AbstractEllipsis>()) {
  470. dic[ATTR_SHAPE] = py::none();
  471. dic[ATTR_DTYPE] = py::ellipsis();
  472. dic[ATTR_VALUE] = py::ellipsis();
  473. } else if (abs_base->isa<AbstractTuple>()) {
  474. return AbstractTupleToPython(abs_base);
  475. } else if (abs_base->isa<AbstractList>()) {
  476. return AbstractListToPython(abs_base);
  477. } else if (abs_base->isa<AbstractNone>()) {
  478. dic[ATTR_SHAPE] = py::none();
  479. dic[ATTR_DTYPE] = py::none();
  480. dic[ATTR_VALUE] = py::none();
  481. } else if (abs_base->isa<AbstractFunction>()) {
  482. ConvertAbstractFunctionToPython(abs_base, &dic);
  483. } else if (abs_base->isa<AbstractUndetermined>()) {
  484. auto arg = dyn_cast<AbstractUndetermined>(abs_base);
  485. dic[ATTR_SHAPE] = py::none();
  486. dic[ATTR_DTYPE] = arg->BuildType();
  487. dic[ATTR_VALUE] = py::none();
  488. } else if (abs_base->isa<AbstractMonad>()) {
  489. dic[ATTR_SHAPE] = py::none();
  490. dic[ATTR_DTYPE] = abs_base->BuildType();
  491. dic[ATTR_VALUE] = py::none();
  492. } else {
  493. auto value = abs_base->BuildValue();
  494. MS_EXCEPTION_IF_NULL(value);
  495. if ((*value == *kAnyValue)) {
  496. auto value_desc = abs_base->value_desc();
  497. MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
  498. << " for python primitive." << abs_base->ToString();
  499. }
  500. MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
  501. << value->ToString();
  502. }
  503. return dic;
  504. }
  505. namespace {
  506. py::tuple PreparePyInputs(const PrimitivePyPtr &, const AbstractBasePtrList &args) {
  507. // The monad parameter is defined at the end of the parameter and needs to be ignored
  508. std::size_t size_args = args.size() - GetAbstractMonadNum(args);
  509. py::tuple py_args(size_args);
  510. for (size_t i = 0; i < size_args; i++) {
  511. auto arg_i = (args)[i];
  512. py_args[i] = ConvertAbstractToPython(arg_i);
  513. }
  514. return py_args;
  515. }
  516. void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
  517. MS_EXCEPTION_IF_NULL(prim);
  518. MS_EXCEPTION_IF_NULL(res_spec);
  519. const string kOutputNum = "output_num";
  520. if (prim->IsCustomPrim()) {
  521. // Raise error if output_num is not match the infer result.
  522. auto output_num_value = prim->GetAttr(kOutputNum);
  523. if (output_num_value == nullptr) {
  524. MS_LOG(DEBUG) << "The output num may no need to check";
  525. return;
  526. }
  527. int64_t output_num = GetValue<int64_t>(output_num_value);
  528. if (res_spec->isa<AbstractTensor>() && output_num != 1) {
  529. MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
  530. << "]'s attribute[output_num]:" << output_num << " not matches the infer result "
  531. << res_spec->ToString();
  532. } else if (res_spec->isa<AbstractTuple>() &&
  533. (res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
  534. MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
  535. << "]'s attribute[output_num]:" << output_num << " not matches the infer result "
  536. << res_spec->ToString();
  537. }
  538. }
  539. }
  540. AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
  541. // Convert to AbstractValue based on type and shape
  542. auto out_dtype = output[ATTR_DTYPE];
  543. if (output[ATTR_VALUE].is_none()) {
  544. auto out_shape = output[ATTR_SHAPE];
  545. return MakePyInferRes2Abstract(out_shape, out_dtype, output);
  546. }
  547. // Convert pyobject to Value, then to AbstractValue
  548. ValuePtr converted_ret = nullptr;
  549. TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
  550. bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
  551. if (!converted) {
  552. MS_LOG(EXCEPTION) << "Convert data failed";
  553. }
  554. auto res_spec = FromValue(converted_ret);
  555. MS_EXCEPTION_IF_NULL(res_spec);
  556. if (res_spec->isa<AbstractTensor>()) {
  557. // Replace to tensor constant node in specialize
  558. auto res_tensor = res_spec->cast<AbstractTensorPtr>();
  559. res_tensor->set_value(converted_ret);
  560. SetValueRange(res_tensor, output);
  561. }
  562. CheckCustomPrimOutputInferResult(prim_py, res_spec);
  563. return res_spec;
  564. }
  565. } // end anonymous namespace
  566. EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base,
  567. const AbstractBasePtrList &args) {
  568. auto prim_py = dyn_cast<PrimitivePy>(prim_);
  569. if (prim_py == nullptr) {
  570. MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
  571. }
  572. // Call checking method 'infer_value' for python primitive
  573. MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
  574. auto py_args = PreparePyInputs(prim_py, args);
  575. py::tuple py_vals(py_args.size());
  576. auto added_attrs = prim_->evaluate_added_attrs();
  577. for (size_t i = 0; i < py_args.size(); ++i) {
  578. py_vals[i] = py_args[i][ATTR_VALUE];
  579. }
  580. py::object py_ret = prim_py->RunInferValue(py_vals);
  581. if (py::isinstance<py::none>(py_ret)) {
  582. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  583. }
  584. // Convert pyobject to Value, then to AbstractValue
  585. ValuePtr converted_ret = nullptr;
  586. TypePtr dtype = abs_base->BuildType();
  587. bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
  588. if (!converted) {
  589. MS_LOG(EXCEPTION) << "Convert data failed";
  590. }
  591. auto res_spec = FromValue(converted_ret);
  592. MS_EXCEPTION_IF_NULL(res_spec);
  593. if (res_spec->isa<AbstractTensor>()) {
  594. // Replace to tensor constant node in specialize
  595. auto res_tensor = res_spec->cast<AbstractTensorPtr>();
  596. res_tensor->set_value(converted_ret);
  597. }
  598. return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
  599. }
  600. EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
  601. auto prim_py = dyn_cast<PrimitivePy>(prim_);
  602. if (prim_py == nullptr) {
  603. MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
  604. }
  605. // Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
  606. MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
  607. auto py_args = PreparePyInputs(prim_py, args);
  608. prim_py->RunCheck(py_args);
  609. prim_->BeginRecordAddAttr();
  610. AbstractBasePtr abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
  611. prim_->EndRecordAddAttr();
  612. auto added_attrs = prim_->evaluate_added_attrs();
  613. if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
  614. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  615. }
  616. // Call method 'infer_value' for primitive with this method for constant propagation
  617. return RunPyInferValue(engine, abs_base, args);
  618. }
  619. EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
  620. if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
  621. auto ret_abstract = AbstractEval(args);
  622. if (ret_abstract != nullptr) {
  623. MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
  624. return ret_abstract;
  625. }
  626. }
  627. if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
  628. return EvalPyCheckPrim(engine, args);
  629. }
  630. auto context = MsContext::GetInstance();
  631. MS_EXCEPTION_IF_NULL(context);
  632. bool need_infer_value = !eval_impl_.in_white_list_;
  633. if (need_infer_value == false) {
  634. need_infer_value = ((context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
  635. std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
  636. MS_EXCEPTION_IF_NULL(abs);
  637. auto value = abs->BuildValue();
  638. return (value != nullptr && !value->isa<AnyValue>() && !value->isa<None>() &&
  639. !value->isa<Monad>() && !value->isa<FuncGraph>());
  640. });
  641. }
  642. AbstractBasePtr abs_base = nullptr;
  643. ValuePtr value = nullptr;
  644. prim_->BeginRecordAddAttr();
  645. if (need_infer_value && eval_impl_.infer_value_impl_ != nullptr) {
  646. value = eval_impl_.infer_value_impl_(prim_, args);
  647. if (value != nullptr) {
  648. abs_base = value->ToAbstract();
  649. prim_->EndRecordAddAttr();
  650. auto added_attrs = prim_->evaluate_added_attrs();
  651. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  652. }
  653. }
  654. abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
  655. prim_->EndRecordAddAttr();
  656. auto added_attrs = prim_->evaluate_added_attrs();
  657. auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  658. return eval_result;
  659. }
  660. EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  661. auto ret_abstract = AbstractEval(args);
  662. if (ret_abstract != nullptr) {
  663. MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
  664. return ret_abstract;
  665. }
  666. MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
  667. const auto eval_result = evaluator_cache_mgr_->GetValue(args);
  668. if (eval_result != nullptr) {
  669. auto abs = eval_result->abstract()->Clone();
  670. auto attr = eval_result->attribute();
  671. return std::make_shared<EvalResult>(abs, attr);
  672. }
  673. auto py_args = PreparePyInputs(prim_py_, args);
  674. prim_py_->BeginRecordAddAttr();
  675. py::dict output = prim_py_->RunInfer(py_args);
  676. prim_py_->EndRecordAddAttr();
  677. auto added_attrs = prim_py_->evaluate_added_attrs();
  678. MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
  679. auto res_spec = PyInferRes2Abstract(prim_py_, output);
  680. MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
  681. auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
  682. evaluator_cache_mgr_->SetValue(args, infer_result);
  683. return infer_result;
  684. }
  685. EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  686. auto ret_abstract = AbstractEval(args);
  687. if (ret_abstract != nullptr) {
  688. MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
  689. return ret_abstract;
  690. }
  691. // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
  692. if (nargs_ != args.size()) {
  693. MS_LOG(EXCEPTION) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
  694. }
  695. TypePtr ret_value_type = return_value_type_;
  696. ValuePtrList value_list;
  697. for (const auto &arg : args) {
  698. // Check if all arguments are scalar type.
  699. MS_EXCEPTION_IF_NULL(arg);
  700. if (arg->isa<AbstractScalar>()) {
  701. auto arg_scalar = dyn_cast<AbstractScalar>(arg);
  702. auto arg_value = arg_scalar->GetValueTrack();
  703. value_list.push_back(arg_value);
  704. } else {
  705. // Raise TypeError Expected Scalar.
  706. MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives.";
  707. }
  708. }
  709. for (const auto &item : type_map_) {
  710. TypePtrList selections;
  711. MS_EXCEPTION_IF_NULL(item.second);
  712. (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
  713. [&args](size_t arg_idx) -> TypePtr {
  714. if (arg_idx >= args.size()) {
  715. MS_LOG(EXCEPTION) << "Index:" << arg_idx << " out of range:" << args.size();
  716. }
  717. MS_EXCEPTION_IF_NULL(args[arg_idx]);
  718. return args[arg_idx]->GetTypeTrack();
  719. });
  720. TypePtr res = CheckTypeList(item.first, selections);
  721. MS_EXCEPTION_IF_NULL(return_value_type_);
  722. MS_EXCEPTION_IF_NULL(item.first);
  723. if (*return_value_type_ == *(item.first)) {
  724. ret_value_type = res;
  725. }
  726. }
  727. ValuePtr evaluated_value = RunImpl(value_list);
  728. if (!(*evaluated_value == *kAnyValue)) {
  729. ret_value_type = evaluated_value->type();
  730. }
  731. // for comparison primitives , return type shall have be specified to be bool.
  732. if (specify_out_type_ != nullptr) {
  733. ret_value_type = specify_out_type_;
  734. }
  735. AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
  736. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
  737. }
  738. ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
  739. if (!eval_value_) {
  740. return kAnyValue;
  741. } else {
  742. if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) {
  743. MS_EXCEPTION_IF_NULL(arg);
  744. return arg->isa<AnyValue>();
  745. })) {
  746. return kAnyValue;
  747. }
  748. return impl_(args);
  749. }
  750. }
  751. // Primitive implementation
  752. // static function start
  753. namespace {
  754. EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveImplReg eval_impl) {
  755. EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
  756. return prim_evaluator;
  757. }
  758. EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value,
  759. const TypePtr &specify_out_type) {
  760. FunctionPtr func = nullptr;
  761. (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func);
  762. MS_EXCEPTION_IF_NULL(func);
  763. EvaluatorPtr uniform_primitive_evaluator =
  764. std::make_shared<UniformPrimEvaluator>(func, prim_impl, eval_value, specify_out_type);
  765. return uniform_primitive_evaluator;
  766. }
  767. FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) {
  768. MS_EXCEPTION_IF_NULL(engine);
  769. MS_EXCEPTION_IF_NULL(method);
  770. if (!method->isa<parse::PyObjectWrapper>()) {
  771. MS_LOG(EXCEPTION) << "Method type error: " << method->ToString();
  772. }
  773. std::shared_ptr<PyObjectWrapper> obj = method->cast<std::shared_ptr<PyObjectWrapper>>();
  774. FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj());
  775. if (func_graph == nullptr) {
  776. MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed";
  777. }
  778. FuncGraphManagerPtr manager = engine->func_graph_manager();
  779. manager->AddFuncGraph(func_graph);
  780. return func_graph;
  781. }
  782. inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) {
  783. MS_EXCEPTION_IF_NULL(engine);
  784. FuncGraphManagerPtr manager = engine->func_graph_manager();
  785. manager->AddFuncGraph(func_graph);
  786. }
  787. enum class REQUIRE_TYPE { ATTR, METHOD };
  788. EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
  789. REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
  790. MS_EXCEPTION_IF_NULL(old_conf);
  791. AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
  792. AbstractFunctionPtr abs_func = dyn_cast<abstract::AbstractFunction>(abstract);
  793. MS_EXCEPTION_IF_NULL(abs_func);
  794. // Create new cnode
  795. std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
  796. auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abs_func);
  797. if (func_graph_func != nullptr) {
  798. FuncGraphPtr fg = func_graph_func->func_graph();
  799. input.push_back(NewValueNode(fg));
  800. } else {
  801. auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abs_func);
  802. MS_EXCEPTION_IF_NULL(prim_func);
  803. PrimitivePtr prim = prim_func->prim();
  804. input.push_back(NewValueNode(prim));
  805. }
  806. AnfNodeConfigPtr conf = dyn_cast<abstract::AnfNodeConfig>(data_conf);
  807. MS_EXCEPTION_IF_NULL(conf);
  808. input.push_back(conf->node());
  809. MS_EXCEPTION_IF_NULL(old_conf);
  810. FuncGraphPtr func_graph = old_conf->node()->func_graph();
  811. MS_EXCEPTION_IF_NULL(func_graph);
  812. CNodePtr new_cnode = func_graph->NewCNode(input);
  813. if (require_type == REQUIRE_TYPE::ATTR) {
  814. new_cnode = func_graph->NewCNode({new_cnode});
  815. }
  816. AnalysisEnginePtr eng = old_conf->engine();
  817. AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context(), old_conf->func_graph());
  818. return eng->ForwardConfig(old_conf, fn_conf);
  819. }
  820. EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
  821. const AnfNodeConfigPtr &out_conf) {
  822. // args_spec_list: same as StaticGetter
  823. if (args_spec_list.size() < 2) {
  824. MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
  825. }
  826. MS_EXCEPTION_IF_NULL(out_conf);
  827. // An external type.
  828. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  829. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  830. auto data_value = args_spec_list[0]->BuildValue();
  831. MS_EXCEPTION_IF_NULL(data_value);
  832. if (!data_value->isa<parse::NameSpace>()) {
  833. MS_EXCEPTION(TypeError) << "Not supported to get attribute for " << data_value->ToString()
  834. << "\nThe first argument should be a NameSpace, but got " << args_spec_list[0]->ToString();
  835. }
  836. auto item_value = args_spec_list[1]->BuildValue();
  837. MS_EXCEPTION_IF_NULL(item_value);
  838. if (item_value->isa<StringImm>()) {
  839. item_value = std::make_shared<parse::Symbol>(item_value->cast<StringImmPtr>()->value());
  840. }
  841. if (!item_value->isa<parse::Symbol>()) {
  842. MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
  843. }
  844. // item_name to func addr from obj_map
  845. parse::SymbolPtr symbol = item_value->cast<parse::SymbolPtr>();
  846. parse::NameSpacePtr name_space = data_value->cast<parse::NameSpacePtr>();
  847. MS_EXCEPTION_IF_NULL(out_conf);
  848. auto out_node = out_conf->node();
  849. FuncGraphPtr func_graph = out_node->func_graph();
  850. MS_EXCEPTION_IF_NULL(func_graph);
  851. auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_node);
  852. if (new_node == nullptr) {
  853. MS_LOG(EXCEPTION) << "Resolve node failed";
  854. }
  855. if (pipeline::GetJitLevel() == "o0" && IsValueNode<FuncGraph>(new_node)) {
  856. UpdateDebugInfo(GetValueNode<FuncGraphPtr>(new_node), out_node->scope(), out_node->debug_info());
  857. }
  858. // Replace old node with the resolved new node in order list.
  859. func_graph->ReplaceInOrder(out_node, new_node);
  860. AnalysisEnginePtr eng = out_conf->engine();
  861. MS_EXCEPTION_IF_NULL(eng);
  862. AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
  863. return eng->ForwardConfig(out_conf, fn_conf);
  864. }
  865. EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
  866. const AbstractBasePtrList &args_spec_list,
  867. const ValuePtr &item_value, const ConfigPtr &data_conf,
  868. const AnfNodeConfigPtr &out_conf) {
  869. if (args_spec_list.empty()) {
  870. MS_LOG(EXCEPTION) << "args_spec_list is empty";
  871. }
  872. AbstractClassPtr cls = CheckArg<AbstractClass>("__FUNC__", args_spec_list, 0);
  873. // If item_value is an attribute, get abstract value from AbstractClass
  874. MS_EXCEPTION_IF_NULL(item_value);
  875. if (!item_value->isa<StringImm>()) {
  876. MS_LOG(EXCEPTION) << "Attribute type error";
  877. }
  878. std::string item_name = item_value->cast<StringImmPtr>()->value();
  879. MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
  880. MS_LOG(DEBUG) << "Resolve item: " << item_name;
  881. MS_EXCEPTION_IF_NULL(cls);
  882. AbstractBasePtr attr = cls->GetAttribute(item_name);
  883. if (attr != nullptr) {
  884. return std::make_shared<EvalResult>(attr, nullptr);
  885. }
  886. ValuePtr method = cls->GetMethod(item_name);
  887. if (method->isa<AnyValue>()) {
  888. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  889. MS_EXCEPTION_IF_NULL(args_spec_list[0]->BuildType());
  890. MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
  891. << ", item value: " << item_value->ToString();
  892. }
  893. // Infer class method
  894. ValuePtr converted_value = PyObjToGraph(engine, method);
  895. return StaticGetterInferred(converted_value, data_conf, out_conf);
  896. }
  897. EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
  898. const TypePtr &data_type, const ConfigPtr &data_conf,
  899. const AnfNodeConfigPtr &out_conf) {
  900. MS_EXCEPTION_IF_NULL(item_value);
  901. MS_EXCEPTION_IF_NULL(data_type);
  902. // The method maybe a Primitive or Composite
  903. if (!item_value->isa<StringImm>()) {
  904. MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
  905. }
  906. std::string item_name = item_value->cast<StringImmPtr>()->value();
  907. REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
  908. Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
  909. if (require.empty()) {
  910. require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
  911. if (require.empty()) {
  912. MS_LOG(EXCEPTION) << "Not supported to get attribute item name:\'" << item_name << "\' of a type["
  913. << data_type->ToString() << "]";
  914. }
  915. require_type = REQUIRE_TYPE::ATTR;
  916. }
  917. ValuePtr converted_value = nullptr;
  918. if (require.is<std::string>()) {
  919. // composite registered in standard_method_map go to this branch
  920. converted_value = prim::GetPythonOps(require.cast<std::string>());
  921. MS_EXCEPTION_IF_NULL(converted_value);
  922. if (pipeline::GetJitLevel() == "o0" && converted_value->isa<FuncGraph>()) {
  923. UpdateDebugInfo(converted_value->cast<FuncGraphPtr>(), out_conf->node()->scope(), out_conf->node()->debug_info());
  924. }
  925. if (!converted_value->isa<Primitive>()) {
  926. AddToManager(engine, converted_value->cast<FuncGraphPtr>());
  927. }
  928. } else if (require.is<PrimitivePtr>()) {
  929. converted_value = require.cast<PrimitivePtr>();
  930. } else {
  931. MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
  932. }
  933. return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
  934. }
  935. enum ResolveType : int64_t {
  936. kResolveTypeUserDefineClass = 1,
  937. kResolveTypeBuiltInType,
  938. kResolveTypeFunction,
  939. };
  940. int64_t GetResolveType(const TypePtr &data_type) {
  941. MS_EXCEPTION_IF_NULL(data_type);
  942. if (data_type->type_id() == kObjectTypeClass) {
  943. return kResolveTypeUserDefineClass;
  944. }
  945. // Try to search method map, if not found, the data_type should be External type.
  946. if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
  947. return kResolveTypeBuiltInType;
  948. }
  949. return kResolveTypeFunction;
  950. }
  951. EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  952. const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
  953. // Inputs: namespace and its static function; or class and its member function
  954. CheckArgsSize("StaticGetter", args_spec_list, 2);
  955. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  956. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  957. MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
  958. MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
  959. TypePtr data_type = args_spec_list[0]->BuildType();
  960. ValuePtr item_value = args_spec_list[1]->BuildValue();
  961. ScopePtr scope = kDefaultScope;
  962. if (out_conf != nullptr) {
  963. scope = out_conf->node()->scope();
  964. }
  965. ScopeGuard scope_guard(scope);
  966. MS_EXCEPTION_IF_NULL(item_value);
  967. if (item_value->isa<AnyValue>()) {
  968. MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
  969. }
  970. int64_t resolve_type = GetResolveType(data_type);
  971. if (resolve_type == kResolveTypeUserDefineClass) {
  972. return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
  973. } else if (resolve_type == kResolveTypeBuiltInType) {
  974. return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
  975. } else {
  976. return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
  977. }
  978. }
  979. } // end anonymous namespace
  980. namespace {
  981. class EmbedEvaluator : public SymbolicPrimEvaluator {
  982. public:
  983. EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
  984. ~EmbedEvaluator() override = default;
  985. MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
  986. EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
  987. // arg: free variable to be embedded
  988. if (args_conf_list.size() != 1) {
  989. MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
  990. }
  991. AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
  992. MS_EXCEPTION_IF_NULL(node_conf);
  993. MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
  994. AbstractBasePtr x = node_conf->ObtainEvalResult()->abstract();
  995. x = SensitivityTransform(x);
  996. SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
  997. AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
  998. return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
  999. }
  1000. };
  1001. static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
  1002. MS_EXCEPTION_IF_NULL(manager);
  1003. auto root_g_set = manager->roots();
  1004. if (root_g_set.size() != 1) {
  1005. return nullptr;
  1006. }
  1007. const FuncGraphPtr &root_g = root_g_set.back();
  1008. for (auto &param_node : root_g->parameters()) {
  1009. auto param = param_node->cast<ParameterPtr>();
  1010. if (param && name == param->name()) {
  1011. return param;
  1012. }
  1013. }
  1014. return nullptr;
  1015. }
  1016. class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
  1017. public:
  1018. RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
  1019. ~RefToEmbedEvaluator() override = default;
  1020. MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
  1021. EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
  1022. if (args_conf_list.size() != 1) {
  1023. MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
  1024. return nullptr;
  1025. }
  1026. static TypePtr type = std::make_shared<SymbolicKeyType>();
  1027. auto node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
  1028. if (node_conf == nullptr) {
  1029. MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
  1030. return nullptr;
  1031. }
  1032. MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
  1033. AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract();
  1034. MS_EXCEPTION_IF_NULL(abs);
  1035. AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
  1036. if (ref_abs == nullptr) {
  1037. MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
  1038. return nullptr;
  1039. }
  1040. auto key_abs = ref_abs->ref_key();
  1041. if (key_abs == nullptr) {
  1042. MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr.";
  1043. return nullptr;
  1044. }
  1045. auto key_value = key_abs->BuildValue();
  1046. if (key_value == nullptr) {
  1047. MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
  1048. return nullptr;
  1049. }
  1050. auto refkey = key_value->cast<RefKeyPtr>();
  1051. if (refkey == nullptr) {
  1052. auto ret = std::make_shared<AbstractScalar>(type);
  1053. auto ref_value = ref_abs->ref();
  1054. MS_EXCEPTION_IF_NULL(ref_value);
  1055. return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  1056. }
  1057. std::string name = refkey->tag();
  1058. MS_EXCEPTION_IF_NULL(node_conf->node());
  1059. if (node_conf->node()->func_graph() == nullptr) {
  1060. MS_LOG(EXCEPTION) << "Should not evaluate a ValueNode, node: " << node_conf->node()->DebugString();
  1061. }
  1062. const auto &manager = node_conf->node()->func_graph()->manager();
  1063. auto node = FindParameterNodeByString(manager, name);
  1064. if (node == nullptr) {
  1065. MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
  1066. return nullptr;
  1067. }
  1068. AbstractBasePtr x = ref_abs->ref();
  1069. x = SensitivityTransform(x);
  1070. std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
  1071. std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
  1072. return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
  1073. }
  1074. };
  1075. class GetAttrEvaluator : public TransitionPrimEvaluator {
  1076. public:
  1077. GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
  1078. ~GetAttrEvaluator() override = default;
  1079. MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
  1080. EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  1081. const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
  1082. constexpr auto kGetAttrArgSize = 2;
  1083. auto ret_abstract = AbstractEval(args_spec_list);
  1084. if (ret_abstract != nullptr) {
  1085. MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
  1086. return ret_abstract;
  1087. }
  1088. // Inputs: data, item
  1089. if (args_spec_list.size() != kGetAttrArgSize) {
  1090. MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
  1091. }
  1092. EvalResultPtr ret = nullptr;
  1093. if (bound_node() != nullptr) {
  1094. TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
  1095. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  1096. } else {
  1097. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  1098. }
  1099. // don't lookup from cache, as different out_conf with same node but different context
  1100. // may add different entry to anfnode_config_map, like getattr primitive;
  1101. evaluator_cache_mgr_->SetValue(args_spec_list, ret);
  1102. return ret;
  1103. }
  1104. };
  1105. class ResolveEvaluator : public TransitionPrimEvaluator {
  1106. public:
  1107. ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
  1108. ~ResolveEvaluator() override = default;
  1109. MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
  1110. EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  1111. const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
  1112. constexpr auto kResolveArgSize = 2;
  1113. // Inputs: namespace, symbol
  1114. if (args_spec_list.size() != kResolveArgSize) {
  1115. MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
  1116. }
  1117. EvalResultPtr ret = nullptr;
  1118. if (bound_node() != nullptr) {
  1119. TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
  1120. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  1121. } else {
  1122. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  1123. }
  1124. return ret;
  1125. }
  1126. };
  1127. class CreateInstanceEvaluator : public TransitionPrimEvaluator {
  1128. public:
  1129. CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
  1130. ~CreateInstanceEvaluator() override = default;
  1131. MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
  1132. EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
  1133. const AnfNodeConfigPtr &out_conf) override {
  1134. if (args_spec_list.empty()) {
  1135. MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
  1136. }
  1137. // Get the type parameter.
  1138. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  1139. TypePtr type = args_spec_list[0]->GetTypeTrack();
  1140. MS_EXCEPTION_IF_NULL(type);
  1141. if (type->type_id() != kMetaTypeTypeType) {
  1142. MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
  1143. << type->ToString();
  1144. }
  1145. ValuePtr value_track = args_spec_list[0]->GetValueTrack();
  1146. MS_EXCEPTION_IF_NULL(value_track);
  1147. std::shared_ptr<parse::PyObjectWrapper> type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
  1148. if (type_obj == nullptr) {
  1149. MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
  1150. }
  1151. if (!type_obj->isa<parse::ClassType>()) {
  1152. MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got "
  1153. << type_obj->ToString() << ".";
  1154. }
  1155. auto class_type = type_obj->obj();
  1156. MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
  1157. // Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
  1158. py::tuple params = GetParameters(args_spec_list);
  1159. // Create class instance.
  1160. auto obj = parse::data_converter::CreatePythonObject(class_type, params);
  1161. if (py::isinstance<py::none>(obj)) {
  1162. MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type)
  1163. << "` failed, only support to create \'Cell\' or \'Primitive\' object.";
  1164. }
  1165. // Process the object.
  1166. ValuePtr converted_ret = nullptr;
  1167. bool converted = parse::ConvertData(obj, &converted_ret, true);
  1168. if (!converted) {
  1169. MS_LOG(EXCEPTION) << "Convert the python object failed";
  1170. }
  1171. MS_EXCEPTION_IF_NULL(converted_ret);
  1172. if (converted_ret->isa<FuncGraph>()) {
  1173. AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
  1174. }
  1175. AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
  1176. auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  1177. evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
  1178. return infer_result;
  1179. }
  1180. py::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
  1181. // Exclude class type by minus 1;
  1182. std::size_t params_size = args_spec_list.size() - 1;
  1183. auto params = py::tuple(params_size);
  1184. if (params_size > params.size()) {
  1185. MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size << ", params.size():" << params.size();
  1186. }
  1187. if (params_size > 0) {
  1188. for (size_t i = 0; i < params_size; i++) {
  1189. // Only support the Scalar parameters type. Bypass class type by offset with 1.
  1190. auto arg = args_spec_list[i + 1];
  1191. MS_EXCEPTION_IF_NULL(arg);
  1192. // Because the Tensor's AbstractTensor can't get value from GetValueTrack.
  1193. ValuePtr param_value = arg->BuildValue();
  1194. py::object param = ValueToPyData(param_value);
  1195. params[i] = param;
  1196. }
  1197. }
  1198. return params;
  1199. }
  1200. };
  1201. class PyInterpretEvaluator : public TransitionPrimEvaluator {
  1202. public:
  1203. PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
  1204. ~PyInterpretEvaluator() override = default;
  1205. MS_DECLARE_PARENT(PyInterpretEvaluator, TransitionPrimEvaluator);
  1206. EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
  1207. const AnfNodeConfigPtr &out_conf) override {
  1208. if (args_spec_list.empty()) {
  1209. MS_LOG(ERROR) << "'args_spec_list' should not be empty";
  1210. }
  1211. // Get the type parameter.
  1212. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  1213. ValuePtr value_track = args_spec_list[0]->GetValueTrack();
  1214. MS_EXCEPTION_IF_NULL(value_track);
  1215. std::shared_ptr<parse::Script> script_obj = dyn_cast<parse::Script>(value_track);
  1216. if (script_obj == nullptr) {
  1217. MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
  1218. }
  1219. // Make global and local parameters.
  1220. py::tuple params = MakeParameters(args_spec_list);
  1221. // Call python script string.
  1222. MS_LOG(DEBUG) << "Call script: " << script_obj->script() << ", params: " << py::str(params);
  1223. auto obj = parse::data_converter::CallPythonScript(py::str(script_obj->script()), params);
  1224. if (py::isinstance<py::none>(obj)) {
  1225. MS_LOG(EXCEPTION) << "Failed to call python script: `" << script_obj->script() << "`";
  1226. }
  1227. ValuePtr converted_val = nullptr;
  1228. bool converted = parse::ConvertData(obj, &converted_val, true);
  1229. if (!converted) {
  1230. MS_LOG(EXCEPTION) << "Convert the python object failed";
  1231. }
  1232. MS_EXCEPTION_IF_NULL(converted_val);
  1233. AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
  1234. auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
  1235. evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
  1236. return infer_result;
  1237. }
  1238. py::tuple MakeParameters(const AbstractBasePtrList &args_spec_list) const {
  1239. constexpr int params_size = 3;
  1240. if (params_size != args_spec_list.size()) {
  1241. MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size
  1242. << ", not equal to arguments.size:" << args_spec_list.size();
  1243. }
  1244. // The first argument is script string, ignore it.
  1245. auto params = py::tuple(params_size - 1);
  1246. // Make the global parameters.
  1247. auto global_dict = dyn_cast<AbstractDictionary>(args_spec_list[1]); // Global parameters dict.
  1248. MS_EXCEPTION_IF_NULL(global_dict);
  1249. auto filtered_global_dict = FilterParameters(global_dict);
  1250. MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString()
  1251. << ", filtered_global_dict: " << filtered_global_dict->ToString();
  1252. ValuePtr global_dict_value = filtered_global_dict->BuildValue();
  1253. py::object global_params_dict = ValueToPyData(global_dict_value);
  1254. MS_LOG(DEBUG) << "arg_1, python global_params_dict: " << global_dict_value->ToString() << " -> "
  1255. << py::str(global_params_dict);
  1256. params[0] = global_params_dict;
  1257. // Make the local parameters.
  1258. auto local_dict = dyn_cast<AbstractDictionary>(args_spec_list[2]); // Local parameters dict.
  1259. MS_EXCEPTION_IF_NULL(local_dict);
  1260. auto filtered_local_dict = FilterParameters(local_dict);
  1261. MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString()
  1262. << ", filtered_local_dict:" << filtered_local_dict->ToString();
  1263. ValuePtr local_dict_value = filtered_local_dict->BuildValue();
  1264. py::object local_params_dict = ValueToPyData(local_dict_value);
  1265. MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> "
  1266. << py::str(local_params_dict);
  1267. params[1] = local_params_dict;
  1268. return params;
  1269. }
  1270. AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
  1271. std::vector<AbstractAttribute> kv;
  1272. const auto &keys_values = abstract_dict->elements();
  1273. // Filter out the element of Function type.
  1274. (void)std::copy_if(keys_values.cbegin(), keys_values.cend(), std::back_inserter(kv),
  1275. [](const AbstractAttribute &item) {
  1276. MS_EXCEPTION_IF_NULL(item.second);
  1277. return (!item.second->isa<abstract::AbstractFunction>());
  1278. });
  1279. return std::make_shared<AbstractDictionary>(kv);
  1280. }
  1281. };
  1282. class PartialEvaluator : public Evaluator {
  1283. public:
  1284. PartialEvaluator() : Evaluator("PartialEvaluator") {}
  1285. ~PartialEvaluator() override = default;
  1286. EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  1287. const AnfNodeConfigPtr &out_conf) override {
  1288. if (args_conf_list.size() == 0) {
  1289. MS_LOG(EXCEPTION) << "Args size should be greater than 0";
  1290. }
  1291. MS_EXCEPTION_IF_NULL(out_conf);
  1292. MS_EXCEPTION_IF_NULL(out_conf->node());
  1293. MS_EXCEPTION_IF_NULL(args_conf_list[0]);
  1294. MS_EXCEPTION_IF_NULL(args_conf_list[0]->ObtainEvalResult());
  1295. auto arg0_value = args_conf_list[0]->ObtainEvalResult()->abstract();
  1296. MS_EXCEPTION_IF_NULL(arg0_value);
  1297. AbstractBasePtrList args_spec_list{arg0_value};
  1298. // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
  1299. if (arg0_value->isa<AbstractError>()) {
  1300. MS_EXCEPTION_IF_NULL(arg0_value->GetValueTrack());
  1301. auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
  1302. MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
  1303. << " as func is: " << arg0_value->ToString();
  1304. auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  1305. evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
  1306. return eval_result;
  1307. }
  1308. auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
  1309. // Sometimes, node[0] in out_conf becomes phi0;
  1310. if (func->isa<PrimitiveAbstractClosure>()) {
  1311. auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
  1312. MS_EXCEPTION_IF_NULL(prim_func->prim());
  1313. if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
  1314. prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(prim_func->prim());
  1315. return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
  1316. }
  1317. }
  1318. (void)std::transform(
  1319. args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
  1320. [](const ConfigPtr &config) -> AbstractBasePtr { return config->ObtainEvalResult()->abstract(); });
  1321. AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
  1322. auto cnode = out_conf->node()->cast<CNodePtr>();
  1323. MS_EXCEPTION_IF_NULL(cnode);
  1324. if (cnode->size() != (args_conf_list.size() + 1)) {
  1325. MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString()
  1326. << ", args_conf_list: " << mindspore::ToString(args_conf_list);
  1327. }
  1328. AbstractFuncAtomPtrList partial_funcs_list;
  1329. auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
  1330. auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
  1331. partial_funcs_list.push_back(new_func);
  1332. };
  1333. func->Visit(build_partial);
  1334. auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
  1335. auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  1336. evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
  1337. return eval_result;
  1338. }
  1339. EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
  1340. MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
  1341. }
  1342. EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
  1343. const AnfNodeConfigPtr &out_conf) const {
  1344. MS_EXCEPTION_IF_NULL(engine);
  1345. MS_EXCEPTION_IF_NULL(out_conf);
  1346. MS_EXCEPTION_IF_NULL(out_conf->node());
  1347. auto cnode = out_conf->node()->cast<CNodePtr>();
  1348. if (cnode == nullptr) {
  1349. MS_LOG(EXCEPTION) << "Cnode is nullptr";
  1350. }
  1351. ScopeGuard scope_guard(out_conf->node()->scope());
  1352. TraceGuard trace_guard(std::make_shared<TraceDoSignature>(out_conf->node()->debug_info()));
  1353. std::vector<AnfNodePtr> new_nodes_inputs = cnode->inputs();
  1354. auto new_signature_value = std::make_shared<prim::DoSignatureMetaFuncGraph>("signature", signature_value);
  1355. new_nodes_inputs[1] = NewValueNode(new_signature_value);
  1356. FuncGraphPtr func_graph = cnode->func_graph();
  1357. MS_EXCEPTION_IF_NULL(func_graph);
  1358. CNodePtr new_cnode = func_graph->NewCNode(std::move(new_nodes_inputs));
  1359. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
  1360. return engine->ForwardConfig(out_conf, fn_conf);
  1361. }
  1362. };
  1363. struct PrimitiveImplInferValue {
  1364. PrimitiveImpl impl_; // implement function of primitive
  1365. bool eval_value_; // whether evaluate value
  1366. TypePtr specify_out_type_; // whether specify return type
  1367. bool in_white_list_; // true if this Primitive in white list, else false.
  1368. };
  1369. using PrimitiveToImplMap = mindspore::HashMap<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
  1370. PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
  1371. using R = PrimitiveToImplMap::mapped_type;
  1372. static PrimitiveToImplMap uniform_prim_implement_map{
  1373. {prim::kPrimScalarAdd, R{prim::ScalarAdd, true, nullptr, true}},
  1374. {prim::kPrimScalarSub, R{prim::ScalarSub, true, nullptr, true}},
  1375. {prim::kPrimScalarMul, R{prim::ScalarMul, true, nullptr, true}},
  1376. {prim::kPrimScalarDiv, R{prim::ScalarDiv, true, nullptr, true}},
  1377. {prim::kPrimScalarMod, R{prim::ScalarMod, true, nullptr, true}},
  1378. {prim::kPrimScalarPow, R{prim::ScalarPow, true, nullptr, true}},
  1379. {prim::kPrimScalarFloordiv, R{prim::ScalarFloordiv, true, nullptr, true}},
  1380. {prim::kPrimScalarUadd, R{prim::ScalarUAdd, true, nullptr, true}},
  1381. {prim::kPrimScalarUsub, R{prim::ScalarUSub, true, nullptr, true}},
  1382. {prim::kPrimScalarLog, R{prim::ScalarLog, true, nullptr, true}},
  1383. {prim::kPrimScalarEq, R{prim::ScalarEq, true, std::make_shared<Bool>(), true}},
  1384. {prim::kPrimScalarLt, R{prim::ScalarLt, true, std::make_shared<Bool>(), true}},
  1385. {prim::kPrimScalarGt, R{prim::ScalarGt, true, std::make_shared<Bool>(), true}},
  1386. {prim::kPrimScalarNe, R{prim::ScalarNe, true, std::make_shared<Bool>(), true}},
  1387. {prim::kPrimScalarLe, R{prim::ScalarLe, true, std::make_shared<Bool>(), true}},
  1388. {prim::kPrimScalarGe, R{prim::ScalarGe, true, std::make_shared<Bool>(), true}},
  1389. {prim::kPrimBoolNot, R{prim::BoolNot, true, std::make_shared<Bool>(), true}},
  1390. {prim::kPrimBoolAnd, R{prim::BoolAnd, true, std::make_shared<Bool>(), true}},
  1391. {prim::kPrimBoolEq, R{prim::BoolEq, true, std::make_shared<Bool>(), true}},
  1392. {prim::kPrimBoolOr, R{prim::BoolOr, true, std::make_shared<Bool>(), true}},
  1393. };
  1394. return uniform_prim_implement_map;
  1395. }
  1396. PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
  1397. std::mutex PrimEvaluatorConstructorMutex;
  1398. void InitPrimEvaluatorConstructors() {
  1399. PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
  1400. for (const auto &iter : GetPrimitiveToEvalImplMap()) {
  1401. constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second);
  1402. }
  1403. for (const auto &iter : GetUniformPrimitiveToImplMap()) {
  1404. constructor[iter.first] =
  1405. InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
  1406. }
  1407. constructor[prim::kPrimEmbed] = std::make_shared<EmbedEvaluator>();
  1408. constructor[prim::kPrimRefToEmbed] = std::make_shared<RefToEmbedEvaluator>();
  1409. constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
  1410. constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
  1411. constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
  1412. constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
  1413. constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
  1414. }
  1415. } // namespace
  1416. void ClearPrimEvaluatorMap() {
  1417. PrimEvaluatorConstructors.clear();
  1418. GetPrimitiveToEvalImplMap().clear();
  1419. GetUniformPrimitiveToImplMap().clear();
  1420. }
  1421. bool IsInWhiteList(const PrimitivePtr &primitive) {
  1422. MS_EXCEPTION_IF_NULL(primitive);
  1423. auto iter = GetPrimitiveToEvalImplMap().find(primitive);
  1424. if (iter != GetPrimitiveToEvalImplMap().end()) {
  1425. return iter->second.in_white_list_;
  1426. }
  1427. auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
  1428. if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
  1429. return uni_iter->second.in_white_list_;
  1430. }
  1431. return false;
  1432. }
  1433. PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
  1434. PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
  1435. if (!constructor.empty()) {
  1436. return constructor;
  1437. }
  1438. std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
  1439. if (constructor.empty()) {
  1440. InitPrimEvaluatorConstructors();
  1441. }
  1442. return constructor;
  1443. }
  1444. namespace {
  1445. bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
  1446. MS_EXCEPTION_IF_NULL(x);
  1447. MS_EXCEPTION_IF_NULL(model);
  1448. auto x_tuple = dyn_cast<AbstractTuple>(x);
  1449. auto model_tuple = dyn_cast<Tuple>(model);
  1450. if (x_tuple == nullptr || model_tuple == nullptr) {
  1451. return false;
  1452. }
  1453. if (model->IsGeneric()) {
  1454. return true;
  1455. }
  1456. if (x_tuple->size() != model_tuple->size()) {
  1457. return false;
  1458. }
  1459. for (size_t i = 0; i < x_tuple->size(); i++) {
  1460. bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]);
  1461. if (!is_subtype) {
  1462. return false;
  1463. }
  1464. }
  1465. return true;
  1466. }
  1467. bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
  1468. MS_EXCEPTION_IF_NULL(x);
  1469. MS_EXCEPTION_IF_NULL(model);
  1470. auto x_tensor = dyn_cast<AbstractTensor>(x);
  1471. auto model_tensor = dyn_cast<TensorType>(model);
  1472. if (x_tensor == nullptr || model_tensor == nullptr) {
  1473. return false;
  1474. }
  1475. if (model->IsGeneric()) {
  1476. return true;
  1477. }
  1478. return IsSubtype(x_tensor->element(), model_tensor->element());
  1479. }
  1480. bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
  1481. MS_EXCEPTION_IF_NULL(x);
  1482. MS_EXCEPTION_IF_NULL(model);
  1483. auto x_list = dyn_cast<AbstractList>(x);
  1484. auto model_list = dyn_cast<List>(model);
  1485. if (x_list == nullptr || model_list == nullptr) {
  1486. return false;
  1487. }
  1488. if (model->IsGeneric()) {
  1489. return true;
  1490. }
  1491. if (x_list->size() != model_list->size()) {
  1492. return false;
  1493. }
  1494. bool is_subtype = true;
  1495. for (size_t i = 0; i < x_list->size(); i++) {
  1496. is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]);
  1497. if (!is_subtype) {
  1498. return false;
  1499. }
  1500. }
  1501. return is_subtype;
  1502. }
  1503. bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) {
  1504. MS_EXCEPTION_IF_NULL(x);
  1505. MS_EXCEPTION_IF_NULL(model);
  1506. auto x_class = dyn_cast<AbstractClass>(x);
  1507. auto model_class = dyn_cast<Class>(model);
  1508. if (x_class == nullptr) {
  1509. return false;
  1510. }
  1511. if (model->IsGeneric()) {
  1512. return true;
  1513. }
  1514. MS_EXCEPTION_IF_NULL(model_class);
  1515. if (x_class->tag() == model_class->tag()) {
  1516. auto m_attributes = model_class->GetAttributes();
  1517. auto x_attributes = x_class->attributes();
  1518. if (m_attributes.size() != x_attributes.size()) {
  1519. return false;
  1520. }
  1521. for (size_t i = 0; i < m_attributes.size(); i++) {
  1522. if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) {
  1523. return false;
  1524. }
  1525. }
  1526. return true;
  1527. }
  1528. return false;
  1529. }
  1530. inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
  1531. MS_EXCEPTION_IF_NULL(x);
  1532. MS_EXCEPTION_IF_NULL(model);
  1533. if (dyn_cast<AbstractScalar>(x) == nullptr) {
  1534. return false;
  1535. }
  1536. TypePtr x_type = x->GetTypeTrack();
  1537. return IsSubType(x_type, model);
  1538. }
  1539. } // namespace
  1540. bool IsSubtype(const AbstractBasePtr x, const TypePtr model) {
  1541. MS_EXCEPTION_IF_NULL(x);
  1542. MS_EXCEPTION_IF_NULL(model);
  1543. TypeId model_typeid = model->type_id();
  1544. switch (model_typeid) {
  1545. case kMetaTypeObject:
  1546. return true;
  1547. case kObjectTypeTuple:
  1548. return IsSubtypeTuple(x, model);
  1549. case kObjectTypeTensorType:
  1550. return IsSubtypeArray(x, model);
  1551. case kObjectTypeList:
  1552. return IsSubtypeList(x, model);
  1553. case kObjectTypeClass:
  1554. return IsSubtypeClass(x, model);
  1555. default:
  1556. if (IsSubType(model, std::make_shared<Number>())) {
  1557. return IsSubtypeScalar(x, model);
  1558. }
  1559. MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << ".";
  1560. }
  1561. }
  1562. } // namespace abstract
  1563. } // namespace mindspore