You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim.cc 57 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/static_analysis/prim.h"
  19. #include <algorithm>
  20. #include <limits>
  21. #include <mutex>
  22. #include <string>
  23. #include <utility>
  24. #include <unordered_set>
  25. #include "frontend/operator/cc_implementations.h"
  26. #include "frontend/operator/ops.h"
  27. #include "frontend/operator/composite/do_signature.h"
  28. #include "frontend/operator/prim_to_function.h"
  29. #include "abstract/utils.h"
  30. #include "utils/symbolic.h"
  31. #include "pipeline/jit/resource.h"
  32. #include "pipeline/jit/parse/resolve.h"
  33. #include "utils/convert_utils.h"
  34. #include "utils/convert_utils_py.h"
  35. #include "utils/ms_context.h"
  36. #include "pipeline/jit/parse/data_converter.h"
  37. #include "abstract/primitive_infer_map.h"
  38. #include "abstract/param_validator.h"
  39. #include "utils/ms_utils.h"
  40. #include "utils/shape_utils.h"
  41. namespace mindspore {
  42. namespace abstract {
  43. using mindspore::parse::PyObjectWrapper;
  // Primitives listed here bypass the early AbstractEval() path (the one logged as "eval Undetermined")
  // in DoSignatureEvaluator::Run and StandardPrimEvaluator::EvalPrim.
  std::unordered_set<std::string> prims_to_skip_undetermined_infer = {
    "make_tuple",
    "make_list",
    "switch",
    "env_setitem",
    "env_getitem",
  };
  46. EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  47. AnfNodeConfigPtr out_conf) {
  48. AbstractBasePtrList args_spec_list;
  49. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  50. [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
  51. auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
  52. auto &func = do_signature->function();
  53. if (func->isa<Primitive>()) {
  54. auto sig_prim = func->cast<PrimitivePtr>();
  55. if (prims_to_skip_undetermined_infer.find(sig_prim->name()) == prims_to_skip_undetermined_infer.end()) {
  56. auto ret_abstract = AbstractEval(args_spec_list);
  57. if (ret_abstract != nullptr) {
  58. MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined";
  59. return ret_abstract;
  60. }
  61. }
  62. }
  63. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  64. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  65. }
  66. auto out_node = dyn_cast<CNode>(out_conf->node());
  67. const auto &out_node_inputs = out_node->inputs();
  68. if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  69. MS_LOG(EXCEPTION) << "Op: " << do_signature->function()->ToString()
  70. << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
  71. << ", inputs size " << out_node_inputs.size();
  72. }
  73. AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
  74. ScopePtr scope = kDefaultScope;
  75. if (out_conf != nullptr) {
  76. scope = out_conf->node()->scope();
  77. }
  78. ScopeGuard scope_guard(scope);
  79. AnfNodePtr new_cnode = nullptr;
  80. if (bound_node() != nullptr) {
  81. TraceManager::DebugTrace(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
  82. new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
  83. args_inputs);
  84. TraceManager::EndTrace();
  85. } else {
  86. new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
  87. args_inputs);
  88. }
  89. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
  90. return engine->ForwardConfig(out_conf, fn_conf);
  91. }
  92. static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
  93. // arg[0] is the func graph to unpack, ignore it
  94. AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
  95. AbstractBasePtrList graph_specialize_args;
  96. if (need_unpack) {
  97. for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
  98. MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
  99. if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
  100. auto arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
  101. std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
  102. std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
  103. } else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
  104. auto arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
  105. auto dict_elems = arg_dict->elements();
  106. (void)std::transform(
  107. dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
  108. [](const AbstractAttribute &item) { return std::make_shared<AbstractKeywordArg>(item.first, item.second); });
  109. } else {
  110. MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got "
  111. << specialize_args_before_unpack[index]->ToString();
  112. }
  113. }
  114. } else {
  115. graph_specialize_args = specialize_args_before_unpack;
  116. }
  117. return graph_specialize_args;
  118. }
  119. EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  120. AnfNodeConfigPtr out_conf) {
  121. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  122. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  123. }
  124. auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
  125. auto out_node = out_conf->node()->cast<CNodePtr>();
  126. const auto &out_node_inputs = out_node->inputs();
  127. if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  128. MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
  129. << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
  130. << ", inputs size " << out_node_inputs.size();
  131. }
  132. AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
  133. AbstractBasePtrList args_spec_list;
  134. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  135. [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
  136. // get the forward graph
  137. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  138. auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
  139. if (fn == nullptr) {
  140. MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
  141. }
  142. auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
  143. MS_EXCEPTION_IF_NULL(real_fn);
  144. FuncGraphPtr forward_graph = real_fn->func_graph();
  145. MS_EXCEPTION_IF_NULL(forward_graph);
  146. AbstractBasePtrList graph_specialize_args =
  147. GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
  148. AbstractBasePtrList graph_specialize_args_without_sens;
  149. if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
  150. MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
  151. }
  152. (void)std::transform(graph_specialize_args.begin(),
  153. graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
  154. std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
  155. auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens);
  156. engine->func_graph_manager()->AddFuncGraph(new_graph);
  157. ScopePtr scope = kDefaultScope;
  158. if (out_conf != nullptr) {
  159. scope = out_conf->node()->scope();
  160. }
  161. ScopeGuard scope_guard(scope);
  162. AnfNodePtr new_vnode = NewValueNode(new_graph);
  163. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context());
  164. return engine->ForwardConfig(out_conf, fn_conf);
  165. }
  166. AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const AbstractBasePtr &node_type,
  167. const AnfNodePtr &target_type, const FuncGraphPtr &func_graph) {
  168. AnfNodePtr target_node = source_node;
  169. if (node_type->isa<AbstractTensor>()) {
  170. auto x = node_type->cast<AbstractTensorPtr>();
  171. if (x->element()->BuildType()->isa<Float>()) {
  172. auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
  173. MS_EXCEPTION_IF_NULL(cast);
  174. target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type});
  175. }
  176. } else if (node_type->isa<AbstractTuple>()) {
  177. auto x = node_type->cast<AbstractTuplePtr>();
  178. auto &items = x->elements();
  179. std::vector<AnfNodePtr> nodes;
  180. nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  181. int idx = 0;
  182. for (const auto &item : items) {
  183. AnfNodePtr tuple_node =
  184. func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)});
  185. AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, item, target_type, func_graph);
  186. nodes.emplace_back(node);
  187. ++idx;
  188. }
  189. target_node = func_graph->NewCNode(nodes);
  190. } else if (node_type->isa<AbstractDictionary>()) {
  191. auto x = node_type->cast<AbstractDictionaryPtr>();
  192. auto &items = x->elements();
  193. std::vector<AnfNodePtr> dict_key_nodes;
  194. std::vector<AnfNodePtr> dict_value_nodes;
  195. dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  196. dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  197. for (const auto &item : items) {
  198. AnfNodePtr dict_value_node =
  199. func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)});
  200. AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph);
  201. dict_key_nodes.emplace_back(NewValueNode(item.first));
  202. dict_value_nodes.emplace_back(node);
  203. }
  204. target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes),
  205. func_graph->NewCNode(dict_value_nodes)});
  206. } else if (node_type->isa<AbstractKeywordArg>()) {
  207. auto x = node_type->cast<AbstractKeywordArgPtr>();
  208. std::string kwarg_key = x->get_key();
  209. AnfNodePtr kwarg_value_node =
  210. func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node});
  211. AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph);
  212. target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node});
  213. }
  214. return target_node;
  215. }
  216. EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  217. AnfNodeConfigPtr out_conf) {
  218. AbstractBasePtrList args_spec_list;
  219. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  220. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  221. }
  222. auto out_node = out_conf->node()->cast<CNodePtr>();
  223. const auto &out_node_inputs = out_node->inputs();
  224. if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  225. MS_LOG(EXCEPTION) << "MixedPrecisionCast"
  226. << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
  227. << ", inputs size " << out_node_inputs.size();
  228. }
  229. AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
  230. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  231. [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
  232. ScopePtr scope = kDefaultScope;
  233. if (out_conf != nullptr) {
  234. scope = out_conf->node()->scope();
  235. }
  236. ScopeGuard scope_guard(scope);
  237. FuncGraphPtr func_graph = out_conf->node()->func_graph();
  238. AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph);
  239. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
  240. return engine->ForwardConfig(out_conf, fn_conf);
  241. }
  242. namespace {
  243. py::object BuildValue(const ValuePtr &value_ptr) {
  244. if (value_ptr == nullptr) {
  245. return py::none();
  246. } else {
  247. return ValuePtrToPyData(value_ptr);
  248. }
  249. }
  250. } // end anonymous namespace
  251. py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
  252. MS_EXCEPTION_IF_NULL(abs_base);
  253. py::dict dic;
  254. if (abs_base->isa<AbstractTensor>()) {
  255. auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
  256. dic[ATTR_SHAPE] = arg_tensor->shape()->shape();
  257. if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  258. const auto &min_shape = arg_tensor->shape()->min_shape();
  259. const auto &max_shape = arg_tensor->shape()->max_shape();
  260. if (!min_shape.empty() && !max_shape.empty()) {
  261. dic[ATTR_MIN_SHAPE] = min_shape;
  262. dic[ATTR_MAX_SHAPE] = max_shape;
  263. }
  264. }
  265. dic[ATTR_DTYPE] = arg_tensor->BuildType();
  266. dic[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
  267. } else if (abs_base->isa<AbstractRowTensor>()) {
  268. auto arg = dyn_cast<AbstractRowTensor>(abs_base);
  269. dic[ATTR_SHAPE] = arg->shape()->shape();
  270. dic[ATTR_DTYPE] = arg->BuildType();
  271. dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  272. } else if (abs_base->isa<AbstractSparseTensor>()) {
  273. auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
  274. dic[ATTR_SHAPE] = arg->shape()->shape();
  275. dic[ATTR_DTYPE] = arg->BuildType();
  276. dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  277. } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
  278. ShapeVector shape;
  279. dic[ATTR_SHAPE] = shape;
  280. dic[ATTR_DTYPE] = abs_base->BuildType();
  281. dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
  282. } else if (abs_base->isa<AbstractSlice>()) {
  283. auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
  284. ShapeVector shape;
  285. dic[ATTR_SHAPE] = shape;
  286. dic[ATTR_DTYPE] = arg_slice->BuildType();
  287. dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
  288. } else if (abs_base->isa<AbstractEllipsis>()) {
  289. dic[ATTR_SHAPE] = py::none();
  290. dic[ATTR_DTYPE] = py::ellipsis();
  291. dic[ATTR_VALUE] = py::ellipsis();
  292. } else if (abs_base->isa<AbstractTuple>()) {
  293. auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
  294. size_t len = arg_tuple->size();
  295. py::tuple shape_tuple(len);
  296. py::tuple dtype_tuple(len);
  297. for (size_t i = 0; i < len; i++) {
  298. py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
  299. shape_tuple[i] = out[ATTR_SHAPE];
  300. dtype_tuple[i] = out[ATTR_DTYPE];
  301. }
  302. dic[ATTR_SHAPE] = shape_tuple;
  303. dic[ATTR_DTYPE] = dtype_tuple;
  304. dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue());
  305. } else if (abs_base->isa<AbstractList>()) {
  306. auto arg_list = dyn_cast<AbstractList>(abs_base);
  307. size_t len = arg_list->size();
  308. py::list shape_list(len);
  309. py::list dtype_list(len);
  310. for (size_t i = 0; i < len; i++) {
  311. py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
  312. shape_list[i] = out[ATTR_SHAPE];
  313. dtype_list[i] = out[ATTR_DTYPE];
  314. }
  315. dic[ATTR_SHAPE] = shape_list;
  316. dic[ATTR_DTYPE] = dtype_list;
  317. dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue());
  318. } else if (abs_base->isa<AbstractNone>()) {
  319. dic[ATTR_SHAPE] = py::none();
  320. dic[ATTR_DTYPE] = py::none();
  321. dic[ATTR_VALUE] = py::none();
  322. } else if (abs_base->isa<AbstractFunction>()) {
  323. dic[ATTR_SHAPE] = py::none();
  324. dic[ATTR_DTYPE] = abs_base->BuildType();
  325. dic[ATTR_VALUE] = py::none();
  326. } else if (abs_base->isa<AbstractUndetermined>()) {
  327. auto arg = dyn_cast<AbstractUndetermined>(abs_base);
  328. dic[ATTR_SHAPE] = py::none();
  329. dic[ATTR_DTYPE] = arg->BuildType();
  330. dic[ATTR_VALUE] = py::none();
  331. } else {
  332. auto value = abs_base->BuildValue();
  333. if ((*value == *kAnyValue)) {
  334. auto value_desc = abs_base->value_desc();
  335. MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
  336. << " for python primitive." << abs_base->ToString();
  337. }
  338. MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
  339. << value->ToString();
  340. }
  341. return dic;
  342. }
  343. namespace {
  344. py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrList &args) {
  345. const AbstractBasePtrList *args_ptr;
  346. if (prim_py->is_tuple_input_) {
  347. if (args.empty()) {
  348. MS_LOG(EXCEPTION) << "Primitive args is empty";
  349. }
  350. if (args[0] == nullptr || !args[0]->isa<AbstractTuple>()) {
  351. MS_LOG(EXCEPTION) << "Custom Primitive inputs should be packed into a Tuple after converting"
  352. "prim convert pass for GE.";
  353. }
  354. args_ptr = &(args[0]->cast<AbstractTuplePtr>()->elements());
  355. } else {
  356. args_ptr = &args;
  357. }
  358. py::tuple py_args(args_ptr->size());
  359. for (size_t i = 0; i < args_ptr->size(); i++) {
  360. auto arg_i = (*args_ptr)[i];
  361. py_args[i] = ConvertAbstractToPython(arg_i);
  362. }
  363. return py_args;
  364. }
  365. AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
  366. // Convert to AbstractValue based on type and shape
  367. auto out_dtype = output[ATTR_DTYPE];
  368. if (output[ATTR_VALUE].is_none()) {
  369. auto out_shape = output[ATTR_SHAPE];
  370. py::object min_shape =
  371. output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
  372. py::object max_shape =
  373. output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
  374. return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
  375. }
  376. // Convert pyobject to Value, then to AbstractValue
  377. ValuePtr converted_ret = nullptr;
  378. TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
  379. bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
  380. if (!converted) {
  381. MS_LOG(EXCEPTION) << "Convert data failed";
  382. }
  383. auto res_spec = FromValue(converted_ret);
  384. MS_EXCEPTION_IF_NULL(res_spec);
  385. if (res_spec->isa<AbstractTensor>()) {
  386. // Replace to tensor constant node in specialize
  387. auto res_tensor = res_spec->cast<AbstractTensorPtr>();
  388. res_tensor->set_value(converted_ret);
  389. }
  390. if (prim_py->IsCustomPrim()) {
  391. // Raise error if output_num is not match the infer result.
  392. int output_num = GetValue<int>(prim_py->GetAttr("output_num"));
  393. if (res_spec->isa<AbstractTensor>() && output_num != 1) {
  394. MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
  395. << " not matches the infer result.";
  396. } else if (res_spec->isa<AbstractTuple>() &&
  397. (res_spec->cast<AbstractTuplePtr>()->size() != IntToSize(output_num))) {
  398. MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
  399. << " not matches the infer result.";
  400. }
  401. }
  402. return res_spec;
  403. }
  404. } // end anonymous namespace
  405. EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
  406. auto prim_py = dyn_cast<PrimitivePy>(prim_);
  407. if (prim_py == nullptr) {
  408. MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
  409. }
  410. // Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
  411. MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
  412. auto py_args = PreparePyInputs(prim_py, args);
  413. prim_py->RunCheck(py_args);
  414. prim_->BeginRecordAddAttr();
  415. AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
  416. prim_->EndRecordAddAttr();
  417. auto added_attrs = prim_->evaluate_added_attrs();
  418. if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
  419. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  420. }
  421. // Call method 'infer_value' for primitive with this method for constant propagation
  422. py::tuple py_vals(py_args.size());
  423. for (size_t i = 0; i < py_args.size(); ++i) {
  424. py_vals[i] = py_args[i][ATTR_VALUE];
  425. }
  426. py::object py_ret = prim_py->RunInferValue(py_vals);
  427. if (py::isinstance<py::none>(py_ret)) {
  428. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  429. }
  430. // Convert pyobject to Value, then to AbstractValue
  431. ValuePtr converted_ret = nullptr;
  432. TypePtr dtype = abs_base->BuildType();
  433. bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
  434. if (!converted) {
  435. MS_LOG(EXCEPTION) << "Convert data failed";
  436. }
  437. auto res_spec = FromValue(converted_ret);
  438. MS_EXCEPTION_IF_NULL(res_spec);
  439. if (res_spec->isa<AbstractTensor>()) {
  440. // Replace to tensor constant node in specialize
  441. auto res_tensor = res_spec->cast<AbstractTensorPtr>();
  442. res_tensor->set_value(converted_ret);
  443. }
  444. return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
  445. }
  446. EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
  447. if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
  448. auto ret_abstract = AbstractEval(args);
  449. if (ret_abstract != nullptr) {
  450. MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
  451. return ret_abstract;
  452. }
  453. }
  454. if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
  455. return EvalPyCheckPrim(engine, args);
  456. }
  457. prim_->BeginRecordAddAttr();
  458. AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
  459. prim_->EndRecordAddAttr();
  460. auto added_attrs = prim_->evaluate_added_attrs();
  461. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  462. }
  463. EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  464. auto ret_abstract = AbstractEval(args);
  465. if (ret_abstract != nullptr) {
  466. MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
  467. return ret_abstract;
  468. }
  469. MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
  470. const auto &iter = cache_->find(args);
  471. if (iter != cache_->end()) {
  472. return iter->second;
  473. }
  474. auto py_args = PreparePyInputs(prim_py_, args);
  475. prim_py_->BeginRecordAddAttr();
  476. py::dict output = prim_py_->RunInfer(py_args);
  477. prim_py_->EndRecordAddAttr();
  478. auto added_attrs = prim_py_->evaluate_added_attrs();
  479. MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
  480. auto res_spec = PyInferRes2Abstract(prim_py_, output);
  481. MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
  482. auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
  483. (*cache_)[args] = infer_result;
  484. return infer_result;
  485. }
  486. EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  487. auto ret_abstract = AbstractEval(args);
  488. if (ret_abstract != nullptr) {
  489. MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
  490. return ret_abstract;
  491. }
  492. // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
  493. if (nargs_ != args.size()) {
  494. MS_LOG(EXCEPTION) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
  495. }
  496. TypePtr ret_value_type = return_value_type_;
  497. ValuePtrList value_list;
  498. for (const auto &arg : args) {
  499. // Check if all arguments are scalar type.
  500. MS_EXCEPTION_IF_NULL(arg);
  501. if (arg->isa<AbstractScalar>()) {
  502. auto arg_scalar = dyn_cast<AbstractScalar>(arg);
  503. auto arg_value = arg_scalar->GetValueTrack();
  504. value_list.push_back(arg_value);
  505. } else {
  506. // Raise TypeError Expected Scalar.
  507. MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives.";
  508. }
  509. }
  510. for (const auto &item : type_map_) {
  511. TypePtrList selections;
  512. MS_EXCEPTION_IF_NULL(item.second);
  513. (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
  514. [&args](size_t arg_idx) -> TypePtr { return args[arg_idx]->GetTypeTrack(); });
  515. TypePtr res = CheckTypeList(item.first, selections);
  516. if (*return_value_type_ == *(item.first)) {
  517. ret_value_type = res;
  518. }
  519. }
  520. ValuePtr evaluated_value = RunImpl(value_list);
  521. if (!(*evaluated_value == *kAnyValue)) {
  522. ret_value_type = evaluated_value->type();
  523. }
  524. // for comparison primitives , return type shall have be specified to be bool.
  525. if (specify_out_type_ != nullptr) {
  526. ret_value_type = specify_out_type_;
  527. }
  528. AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
  529. return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
  530. }
  531. ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
  532. if (!eval_value_) {
  533. return kAnyValue;
  534. } else {
  535. if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) {
  536. MS_EXCEPTION_IF_NULL(arg);
  537. return arg->isa<AnyValue>();
  538. })) {
  539. return kAnyValue;
  540. }
  541. return impl_(args);
  542. }
  543. }
  544. // Primitive implementation
  545. // static function start
  546. namespace {
  547. EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) {
  548. EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
  549. return prim_evaluator;
  550. }
  551. EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value,
  552. const TypePtr &specify_out_type) {
  553. FunctionPtr func = nullptr;
  554. (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func);
  555. MS_EXCEPTION_IF_NULL(func);
  556. EvaluatorPtr uniform_primitive_evaluator =
  557. std::make_shared<UniformPrimEvaluator>(func, prim_impl, eval_value, specify_out_type);
  558. return uniform_primitive_evaluator;
  559. }
  560. const int kResolveCaseUserDefineClass = 1;
  561. const int kResolveCaseBuiltInType = 2;
  562. const int kResolveCaseFunction = 3;
  563. int GetResolveCase(const TypePtr &data_type) {
  564. MS_EXCEPTION_IF_NULL(data_type);
  565. if (data_type->type_id() == kObjectTypeClass) {
  566. return kResolveCaseUserDefineClass;
  567. }
  568. // try method map, if not in method map, the data_type should be External type.
  569. if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
  570. return kResolveCaseBuiltInType;
  571. }
  572. return kResolveCaseFunction;
  573. }
  574. FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) {
  575. MS_EXCEPTION_IF_NULL(engine);
  576. MS_EXCEPTION_IF_NULL(method);
  577. if (!method->isa<parse::PyObjectWrapper>()) {
  578. MS_LOG(EXCEPTION) << "Method type error: " << method->ToString();
  579. }
  580. std::shared_ptr<PyObjectWrapper> obj = method->cast<std::shared_ptr<PyObjectWrapper>>();
  581. FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj());
  582. if (func_graph == nullptr) {
  583. MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed";
  584. }
  585. FuncGraphManagerPtr manager = engine->func_graph_manager();
  586. manager->AddFuncGraph(func_graph);
  587. return func_graph;
  588. }
  589. inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) {
  590. MS_EXCEPTION_IF_NULL(engine);
  591. FuncGraphManagerPtr manager = engine->func_graph_manager();
  592. manager->AddFuncGraph(func_graph);
  593. }
  594. enum REQUIRE_TYPE { ATTR, METHOD };
  595. EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
  596. REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
  597. MS_EXCEPTION_IF_NULL(old_conf);
  598. AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
  599. AbstractFunctionPtr abs_func = dyn_cast<abstract::AbstractFunction>(abs_ptr);
  600. MS_EXCEPTION_IF_NULL(abs_func);
  601. // Create new cnode
  602. std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
  603. auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abs_func);
  604. if (func_graph_func != nullptr) {
  605. FuncGraphPtr fg = func_graph_func->func_graph();
  606. input.push_back(NewValueNode(fg));
  607. } else {
  608. auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abs_func);
  609. MS_EXCEPTION_IF_NULL(prim_func);
  610. PrimitivePtr prim = prim_func->prim();
  611. input.push_back(NewValueNode(prim));
  612. }
  613. AnfNodeConfigPtr conf = dyn_cast<abstract::AnfNodeConfig>(data_conf);
  614. MS_EXCEPTION_IF_NULL(conf);
  615. input.push_back(conf->node());
  616. MS_EXCEPTION_IF_NULL(old_conf);
  617. FuncGraphPtr func_graph = old_conf->node()->func_graph();
  618. CNodePtr new_cnode = func_graph->NewCNode(input);
  619. if (require_type == REQUIRE_TYPE::ATTR) {
  620. new_cnode = func_graph->NewCNode({new_cnode});
  621. }
  622. AnalysisEnginePtr eng = old_conf->engine();
  623. AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context());
  624. return eng->ForwardConfig(old_conf, fn_conf);
  625. }
  626. EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
  627. const AbstractBasePtrList &args_spec_list,
  628. const AnfNodeConfigPtr &out_conf) {
  629. // args_spec_list: same as StaticGetter
  630. if (args_spec_list.size() < 2) {
  631. MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
  632. }
  633. MS_EXCEPTION_IF_NULL(out_conf);
  634. // An external type.
  635. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  636. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  637. MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
  638. MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
  639. auto data_v = args_spec_list[0]->BuildValue();
  640. if (!data_v->isa<parse::NameSpace>()) {
  641. MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString();
  642. }
  643. auto item_v = args_spec_list[1]->BuildValue();
  644. if (item_v->isa<StringImm>()) {
  645. item_v = std::make_shared<parse::Symbol>(item_v->cast<StringImmPtr>()->value());
  646. }
  647. if (!item_v->isa<parse::Symbol>()) {
  648. MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_v->ToString();
  649. }
  650. // item_name to func addr from obj_map
  651. parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
  652. parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
  653. FuncGraphPtr func_graph = out_conf->node()->func_graph();
  654. auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node());
  655. if (new_node == nullptr) {
  656. MS_LOG(EXCEPTION) << "Resolve node failed";
  657. }
  658. AnalysisEnginePtr eng = out_conf->engine();
  659. AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context());
  660. return eng->ForwardConfig(out_conf, fn_conf);
  661. }
  662. EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
  663. const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
  664. const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
  665. if (args_spec_list.empty()) {
  666. MS_LOG(EXCEPTION) << "args_spec_list is empty";
  667. }
  668. AbstractClassPtr cls = CheckArg<AbstractClass>("__FUNC__", args_spec_list, 0);
  669. // If item_v is an attribute, get abstract value from AbstractClass
  670. MS_EXCEPTION_IF_NULL(item_v);
  671. if (!item_v->isa<StringImm>()) {
  672. MS_LOG(EXCEPTION) << "Attribute type error";
  673. }
  674. std::string item_name = item_v->cast<StringImmPtr>()->value();
  675. MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
  676. MS_LOG(DEBUG) << "Resolve item: " << item_name;
  677. AbstractBasePtr attr = cls->GetAttribute(item_name);
  678. if (attr != nullptr) {
  679. return std::make_shared<EvalResult>(attr, nullptr);
  680. }
  681. ValuePtr method = cls->GetMethod(item_name);
  682. if (method->isa<AnyValue>()) {
  683. MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
  684. << ", item value: " << item_v->ToString();
  685. }
  686. // Infer class method
  687. ValuePtr converted_v = PyObjToGraph(engine, method);
  688. return StaticGetterInferred(converted_v, data_conf, out_conf);
  689. }
  690. EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
  691. const TypePtr &data_type, const ConfigPtr &data_conf,
  692. const AnfNodeConfigPtr &out_conf) {
  693. MS_EXCEPTION_IF_NULL(item_v);
  694. MS_EXCEPTION_IF_NULL(data_type);
  695. // The method maybe a Primitive or Composite
  696. if (!item_v->isa<StringImm>()) {
  697. MS_LOG(EXCEPTION) << "Error item is not string";
  698. }
  699. std::string item_name = item_v->cast<StringImmPtr>()->value();
  700. REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
  701. Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
  702. if (require.empty()) {
  703. require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
  704. if (require.empty()) {
  705. MS_LOG(EXCEPTION) << "The object of type: " << data_type->ToString() << " has no method or attr: " << item_name;
  706. }
  707. require_type = REQUIRE_TYPE::ATTR;
  708. }
  709. ValuePtr converted_v = nullptr;
  710. if (require.is<std::string>()) {
  711. // composite registered in standard_method_map go to this branch
  712. converted_v = prim::GetPythonOps(require.cast<std::string>());
  713. if (!converted_v->isa<Primitive>()) {
  714. AddToManager(engine, converted_v->cast<FuncGraphPtr>());
  715. }
  716. } else if (require.is<PrimitivePtr>()) {
  717. converted_v = require.cast<PrimitivePtr>();
  718. } else {
  719. MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
  720. }
  721. return StaticGetterInferred(converted_v, data_conf, out_conf, require_type);
  722. }
  723. EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  724. const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
  725. // Inputs: namespace and its static function; or class and its member function
  726. CheckArgsSize("StaticGetter", args_spec_list, 2);
  727. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  728. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  729. TypePtr data_type = args_spec_list[0]->BuildType();
  730. ValuePtr item_value = args_spec_list[1]->BuildValue();
  731. ScopePtr scope = kDefaultScope;
  732. if (out_conf != nullptr) {
  733. scope = out_conf->node()->scope();
  734. }
  735. ScopeGuard scope_guard(scope);
  736. if (item_value->isa<AnyValue>()) {
  737. MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
  738. }
  739. int case_v = GetResolveCase(data_type);
  740. if (case_v == kResolveCaseUserDefineClass) {
  741. return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
  742. } else if (case_v == kResolveCaseBuiltInType) {
  743. return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
  744. } else {
  745. return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
  746. }
  747. }
  748. } // end anonymous namespace
  749. // static variable start;
  750. namespace {
  751. class EmbedEvaluator : public SymbolicPrimEvaluator {
  752. public:
  753. EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
  754. ~EmbedEvaluator() override = default;
  755. MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
  756. EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
  757. // arg: free variable to be embedded
  758. if (args_conf_list.size() != 1) {
  759. MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
  760. }
  761. AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
  762. MS_EXCEPTION_IF_NULL(node_conf);
  763. AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract();
  764. x = SensitivityTransform(x);
  765. SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
  766. AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
  767. return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
  768. }
  769. };
  770. static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
  771. auto root_g_set = manager->roots();
  772. if (root_g_set.size() != 1) {
  773. return nullptr;
  774. }
  775. const FuncGraphPtr &root_g = root_g_set.back();
  776. for (auto &param_node : root_g->parameters()) {
  777. auto param = param_node->cast<ParameterPtr>();
  778. if (param && name == param->name()) {
  779. return param;
  780. }
  781. }
  782. return nullptr;
  783. }
  784. class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
  785. public:
  786. RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
  787. ~RefToEmbedEvaluator() override = default;
  788. MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
  789. EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
  790. if (args_conf_list.size() != 1) {
  791. MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
  792. return nullptr;
  793. }
  794. static TypePtr type = std::make_shared<SymbolicKeyType>();
  795. auto node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
  796. if (node_conf == nullptr) {
  797. MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
  798. return nullptr;
  799. }
  800. AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
  801. AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
  802. if (ref_abs == nullptr) {
  803. MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
  804. return nullptr;
  805. }
  806. auto key_abs = ref_abs->ref_key();
  807. if (key_abs == nullptr) {
  808. MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr.";
  809. return nullptr;
  810. }
  811. auto key_value = key_abs->BuildValue();
  812. if (key_value == nullptr) {
  813. MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
  814. return nullptr;
  815. }
  816. auto refkey = key_value->cast<RefKeyPtr>();
  817. if (refkey == nullptr) {
  818. auto ret = std::make_shared<AbstractScalar>(type);
  819. auto ref_value = ref_abs->ref();
  820. MS_EXCEPTION_IF_NULL(ref_value);
  821. return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  822. }
  823. std::string name = refkey->tag();
  824. const auto &manager = node_conf->node()->func_graph()->manager();
  825. auto node = FindParameterNodeByString(manager, name);
  826. if (node == nullptr) {
  827. MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
  828. return nullptr;
  829. }
  830. AbstractBasePtr x = ref_abs->ref();
  831. x = SensitivityTransform(x);
  832. std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
  833. std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
  834. return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
  835. }
  836. };
  837. class GetAttrEvaluator : public TransitionPrimEvaluator {
  838. public:
  839. GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
  840. ~GetAttrEvaluator() override = default;
  841. MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
  842. EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  843. const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
  844. auto ret_abstract = AbstractEval(args_spec_list);
  845. if (ret_abstract != nullptr) {
  846. MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
  847. return ret_abstract;
  848. }
  849. // Inputs: data, item
  850. if (args_spec_list.size() != 2) {
  851. MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
  852. }
  853. EvalResultPtr ret = nullptr;
  854. if (bound_node() != nullptr) {
  855. TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
  856. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  857. TraceManager::EndTrace();
  858. } else {
  859. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  860. }
  861. // don't lookup from cache, as different out_conf with same node but different context
  862. // may add different entry to anfnode_config_map, like getattr primitive;
  863. (*cache_)[args_spec_list] = ret;
  864. return ret;
  865. }
  866. };
  867. class ResolveEvaluator : public TransitionPrimEvaluator {
  868. public:
  869. ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
  870. ~ResolveEvaluator() override = default;
  871. MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
  872. EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  873. const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
  874. // Inputs: namespace, symbol
  875. if (args_spec_list.size() != 2) {
  876. MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
  877. }
  878. EvalResultPtr ret = nullptr;
  879. if (bound_node() != nullptr) {
  880. TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
  881. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  882. TraceManager::EndTrace();
  883. } else {
  884. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  885. }
  886. return ret;
  887. }
  888. };
  889. class CreateInstanceEvaluator : public TransitionPrimEvaluator {
  890. public:
  891. CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
  892. ~CreateInstanceEvaluator() override = default;
  893. MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
  894. EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
  895. const AnfNodeConfigPtr &out_conf) override {
  896. if (args_spec_list.empty()) {
  897. MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
  898. }
  899. // get the type parameter
  900. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  901. TypePtr type = args_spec_list[0]->GetTypeTrack();
  902. if (type->type_id() != kMetaTypeTypeType) {
  903. MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
  904. << type->ToString();
  905. }
  906. ValuePtr value_track = args_spec_list[0]->GetValueTrack();
  907. MS_EXCEPTION_IF_NULL(value_track);
  908. std::shared_ptr<parse::PyObjectWrapper> type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
  909. if (type_obj == nullptr) {
  910. MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
  911. }
  912. if (!type_obj->isa<parse::ClassType>()) {
  913. MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got "
  914. << type_obj->ToString() << ".";
  915. }
  916. auto class_type = type_obj->obj();
  917. MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
  918. // get the create instance obj's parameters
  919. pybind11::tuple params = GetParameters(args_spec_list);
  920. // create class instance
  921. auto obj = parse::data_converter::CreatePythonObject(class_type, params);
  922. if (py::isinstance<py::none>(obj)) {
  923. MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type)
  924. << " failed, only support create Cell or Primitive object.";
  925. }
  926. // process the object
  927. ValuePtr converted_ret = nullptr;
  928. bool converted = parse::ConvertData(obj, &converted_ret, true);
  929. if (!converted) {
  930. MS_LOG(EXCEPTION) << "Convert the python object failed";
  931. }
  932. MS_EXCEPTION_IF_NULL(converted_ret);
  933. if (converted_ret->isa<FuncGraph>()) {
  934. AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
  935. }
  936. AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
  937. auto infer_result = std::make_shared<EvalResult>(ret, nullptr);
  938. (*cache_)[args_spec_list] = infer_result;
  939. return infer_result;
  940. }
  941. pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
  942. // Exclude class type by minus 1;
  943. std::size_t params_size = args_spec_list.size() - 1;
  944. auto params = py::tuple(params_size);
  945. if (params_size > 0) {
  946. for (size_t i = 0; i < params_size; i++) {
  947. // Only support the Scalar parameters type. Bypass class type by offset with 1.
  948. auto arg = args_spec_list[i + 1];
  949. MS_EXCEPTION_IF_NULL(arg);
  950. // Because the Tensor's AbstractTensor can't get value from GetValueTrack.
  951. ValuePtr param_value = arg->BuildValue();
  952. py::object param = ValuePtrToPyData(param_value);
  953. params[i] = param;
  954. }
  955. }
  956. return params;
  957. }
  958. };
  959. class PartialEvaluator : public Evaluator {
  960. public:
  961. PartialEvaluator() : Evaluator("PartialEvaluator") {}
  962. ~PartialEvaluator() override = default;
  963. EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  964. AnfNodeConfigPtr out_conf = nullptr) override {
  965. if (args_conf_list.size() == 0) {
  966. MS_LOG(EXCEPTION) << "Args size should be greater than 0";
  967. }
  968. MS_EXCEPTION_IF_NULL(out_conf);
  969. MS_EXCEPTION_IF_NULL(out_conf->node());
  970. auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract();
  971. AbstractBasePtrList args_spec_list{arg0_value};
  972. // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
  973. if (arg0_value->isa<AbstractError>()) {
  974. auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
  975. MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
  976. << " as func is: " << arg0_value->ToString();
  977. auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  978. (*cache_)[args_spec_list] = eval_result;
  979. return eval_result;
  980. }
  981. auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
  982. // Sometimes, node[0] in out_conf becomes phi0;
  983. if (func->isa<PrimitiveAbstractClosure>()) {
  984. auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
  985. if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
  986. prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(prim_func->prim());
  987. return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
  988. }
  989. }
  990. (void)std::transform(
  991. args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
  992. [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); });
  993. AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
  994. auto cnode = out_conf->node()->cast<CNodePtr>();
  995. MS_EXCEPTION_IF_NULL(cnode);
  996. if (cnode->size() != (args_conf_list.size() + 1)) {
  997. MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString()
  998. << ", args_conf_list: " << mindspore::ToString(args_conf_list);
  999. }
  1000. AbstractFuncAtomPtrList partial_funcs_list;
  1001. auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
  1002. auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
  1003. partial_funcs_list.push_back(new_func);
  1004. };
  1005. func->Visit(build_partial);
  1006. auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
  1007. auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
  1008. (*cache_)[args_spec_list] = infer_result;
  1009. return infer_result;
  1010. }
  1011. EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  1012. MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
  1013. }
  1014. EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
  1015. const AnfNodeConfigPtr &out_conf = nullptr) const {
  1016. MS_EXCEPTION_IF_NULL(out_conf);
  1017. MS_EXCEPTION_IF_NULL(out_conf->node());
  1018. auto cnode = out_conf->node()->cast<CNodePtr>();
  1019. if (cnode == nullptr) {
  1020. MS_LOG(EXCEPTION) << "Cnode is nullptr";
  1021. }
  1022. std::vector<AnfNodePtr> new_nodes_inputs = cnode->inputs();
  1023. auto new_signature_value = std::make_shared<prim::DoSignatureMetaFuncGraph>("signature", signature_value);
  1024. new_nodes_inputs[1] = NewValueNode(new_signature_value);
  1025. FuncGraphPtr func_graph = cnode->func_graph();
  1026. ScopePtr scope = out_conf->node()->scope();
  1027. ScopeGuard scope_guard(scope);
  1028. CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
  1029. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
  1030. return engine->ForwardConfig(out_conf, fn_conf);
  1031. }
  1032. };
  1033. struct PrimitiveImplInferValue {
  1034. PrimitiveImpl impl_; // implement function of primitive
  1035. bool eval_value_; // whether evaluate value
  1036. TypePtr specify_out_type_; // whether specify return type
  1037. bool in_white_list_; // true if this Primitive in white list, else false.
  1038. };
  1039. using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
  1040. PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
  1041. static PrimitiveToImplMap uniform_prim_implement_map = {
  1042. {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
  1043. {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
  1044. {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
  1045. {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
  1046. {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
  1047. {prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}},
  1048. {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}},
  1049. {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
  1050. {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
  1051. {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
  1052. {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared<Bool>(), true}},
  1053. {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared<Bool>(), true}},
  1054. {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared<Bool>(), true}},
  1055. {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared<Bool>(), true}},
  1056. {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared<Bool>(), true}},
  1057. {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared<Bool>(), true}},
  1058. {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared<Bool>(), true}},
  1059. {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
  1060. {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
  1061. {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
  1062. };
  1063. return uniform_prim_implement_map;
  1064. }
  1065. PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
  1066. std::mutex PrimEvaluatorConstructorMutex;
  1067. void InitPrimEvaluatorConstructors() {
  1068. PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
  1069. for (const auto &iter : GetPrimitiveToEvalImplMap()) {
  1070. constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_);
  1071. }
  1072. for (const auto &iter : GetUniformPrimitiveToImplMap()) {
  1073. constructor[iter.first] =
  1074. InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
  1075. }
  1076. constructor[prim::kPrimEmbed] = std::make_shared<EmbedEvaluator>();
  1077. constructor[prim::kPrimRefToEmbed] = std::make_shared<RefToEmbedEvaluator>();
  1078. constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
  1079. constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
  1080. constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
  1081. constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
  1082. }
  1083. } // namespace
  1084. void ClearPrimEvaluatorMap() {
  1085. PrimEvaluatorConstructors.clear();
  1086. GetPrimitiveToEvalImplMap().clear();
  1087. GetUniformPrimitiveToImplMap().clear();
  1088. }
  1089. bool IsInWhiteList(const PrimitivePtr &primitive) {
  1090. MS_EXCEPTION_IF_NULL(primitive);
  1091. auto iter = GetPrimitiveToEvalImplMap().find(primitive);
  1092. if (iter != GetPrimitiveToEvalImplMap().end()) {
  1093. return iter->second.in_white_list_;
  1094. }
  1095. auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
  1096. if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
  1097. return uni_iter->second.in_white_list_;
  1098. }
  1099. return false;
  1100. }
  1101. StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
  1102. MS_EXCEPTION_IF_NULL(primitive);
  1103. auto iter = GetPrimitiveToEvalImplMap().find(primitive);
  1104. if (iter == GetPrimitiveToEvalImplMap().end()) {
  1105. return nullptr;
  1106. }
  1107. return iter->second.impl_;
  1108. }
  1109. PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
  1110. PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
  1111. if (!constructor.empty()) {
  1112. return constructor;
  1113. }
  1114. std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
  1115. if (constructor.empty()) {
  1116. InitPrimEvaluatorConstructors();
  1117. }
  1118. return constructor;
  1119. }
  1120. namespace {
  1121. bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
  1122. MS_EXCEPTION_IF_NULL(x);
  1123. MS_EXCEPTION_IF_NULL(model);
  1124. auto x_tuple = dyn_cast<AbstractTuple>(x);
  1125. auto model_tuple = dyn_cast<Tuple>(model);
  1126. if (x_tuple == nullptr || model_tuple == nullptr) {
  1127. return false;
  1128. }
  1129. if (model->IsGeneric()) {
  1130. return true;
  1131. }
  1132. if (x_tuple->size() != model_tuple->size()) {
  1133. return false;
  1134. }
  1135. for (size_t i = 0; i < x_tuple->size(); i++) {
  1136. bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]);
  1137. if (!is_subtype) {
  1138. return false;
  1139. }
  1140. }
  1141. return true;
  1142. }
  1143. bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
  1144. MS_EXCEPTION_IF_NULL(x);
  1145. MS_EXCEPTION_IF_NULL(model);
  1146. auto x_tensor = dyn_cast<AbstractTensor>(x);
  1147. auto model_tensor = dyn_cast<TensorType>(model);
  1148. if (x_tensor == nullptr || model_tensor == nullptr) {
  1149. return false;
  1150. }
  1151. if (model->IsGeneric()) {
  1152. return true;
  1153. }
  1154. return IsSubtype(x_tensor->element(), model_tensor->element());
  1155. }
  1156. bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
  1157. MS_EXCEPTION_IF_NULL(x);
  1158. MS_EXCEPTION_IF_NULL(model);
  1159. auto x_list = dyn_cast<AbstractList>(x);
  1160. auto model_list = dyn_cast<List>(model);
  1161. if (x_list == nullptr || model_list == nullptr) {
  1162. return false;
  1163. }
  1164. if (model->IsGeneric()) {
  1165. return true;
  1166. }
  1167. if (x_list->size() != model_list->size()) {
  1168. return false;
  1169. }
  1170. bool is_subtype = true;
  1171. for (size_t i = 0; i < x_list->size(); i++) {
  1172. is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]);
  1173. if (!is_subtype) {
  1174. return false;
  1175. }
  1176. }
  1177. return is_subtype;
  1178. }
  1179. bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) {
  1180. MS_EXCEPTION_IF_NULL(x);
  1181. MS_EXCEPTION_IF_NULL(model);
  1182. auto x_class = dyn_cast<AbstractClass>(x);
  1183. auto model_class = dyn_cast<Class>(model);
  1184. if (x_class == nullptr) {
  1185. return false;
  1186. }
  1187. if (model->IsGeneric()) {
  1188. return true;
  1189. }
  1190. if (x_class->tag() == model_class->tag()) {
  1191. auto m_attributes = model_class->GetAttributes();
  1192. auto x_attributes = x_class->attributes();
  1193. if (m_attributes.size() != x_attributes.size()) {
  1194. return false;
  1195. }
  1196. for (size_t i = 0; i < m_attributes.size(); i++) {
  1197. if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) {
  1198. return false;
  1199. }
  1200. }
  1201. return true;
  1202. }
  1203. return false;
  1204. }
  1205. inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
  1206. MS_EXCEPTION_IF_NULL(x);
  1207. MS_EXCEPTION_IF_NULL(model);
  1208. if (dyn_cast<AbstractScalar>(x) == nullptr) {
  1209. return false;
  1210. }
  1211. TypePtr x_type = x->GetTypeTrack();
  1212. return IsSubType(x_type, model);
  1213. }
  1214. } // namespace
  1215. bool IsSubtype(const AbstractBasePtr x, const TypePtr model) {
  1216. MS_EXCEPTION_IF_NULL(x);
  1217. MS_EXCEPTION_IF_NULL(model);
  1218. TypeId model_typeid = model->type_id();
  1219. switch (model_typeid) {
  1220. case kMetaTypeObject:
  1221. return true;
  1222. case kObjectTypeTuple:
  1223. return IsSubtypeTuple(x, model);
  1224. case kObjectTypeTensorType:
  1225. return IsSubtypeArray(x, model);
  1226. case kObjectTypeList:
  1227. return IsSubtypeList(x, model);
  1228. case kObjectTypeClass:
  1229. return IsSubtypeClass(x, model);
  1230. default:
  1231. if (IsSubType(model, std::make_shared<Number>())) {
  1232. return IsSubtypeScalar(x, model);
  1233. }
  1234. MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << ".";
  1235. }
  1236. }
  1237. } // namespace abstract
  1238. } // namespace mindspore