You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim.cc 50 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/static_analysis/prim.h"
  19. #include <algorithm>
  20. #include <limits>
  21. #include <mutex>
  22. #include <set>
  23. #include <string>
  24. #include <utility>
  25. #include "operator/cc_implementations.h"
  26. #include "operator/ops.h"
  27. #include "operator/composite/do_signature.h"
  28. #include "operator/prim_to_function.h"
  29. #include "pipeline/static_analysis/utils.h"
  30. #include "utils/symbolic.h"
  31. #include "./common.h"
  32. #include "pipeline/resource.h"
  33. #include "pipeline/parse/resolve.h"
  34. #include "ir/meta_tensor.h"
  35. #include "utils/convert_utils.h"
  36. #include "pipeline/parse/data_converter.h"
  37. #include "pipeline/static_analysis/param_validator.h"
  38. #include "common/utils.h"
  39. namespace mindspore {
  40. namespace abstract {
  41. PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
  42. static PrimitiveEvalImplMap prim_eval_implement_map = {
  43. // Statements
  44. {prim::kPrimReturn, {InferImplReturn, true}},
  45. {prim::kPrimTypeOf, {InferImplTypeof, false}},
  46. {prim::kPrimHasType, {InferImplHasType, false}},
  47. {prim::kPrimDot, {InferImplDot, true}},
  48. {prim::kPrimSwitch, {InferImplSwitch, true}},
  49. {prim::kPrimIs_, {InferImplIs_, true}},
  50. {prim::kPrimIsNot, {InferImplIsNot, true}},
  51. {prim::kPrimInDict, {InferImplInDict, true}},
  52. {prim::kPrimNotInDict, {InferImplNotInDict, true}},
  53. // Maths
  54. {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
  55. {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
  56. // Array
  57. {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
  58. {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
  59. {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
  60. {prim::kPrimShape, {InferImplShape, true}},
  61. {prim::kPrimPack, {InferImplPack, true}},
  62. // Structure
  63. {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
  64. {prim::kPrimMakeList, {InferImplMakeList, true}},
  65. {prim::kPrimMakeDict, {InferImplMakeDict, true}},
  66. {prim::kPrimMakeSlice, {InferImplMakeSlice, true}},
  67. {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}},
  68. {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}},
  69. {prim::kPrimMakeRecord, {InferImplMakeRecord, false}},
  70. {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}},
  71. {prim::kPrimListGetItem, {InferImplListGetItem, true}},
  72. {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}},
  73. {prim::kPrimListSetItem, {InferImplListSetItem, true}},
  74. {prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
  75. {prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
  76. {prim::kPrimListAppend, {InferImplListAppend, true}},
  77. {prim::kPrimTupleLen, {InferImplTupleLen, true}},
  78. {prim::kPrimListLen, {InferImplListLen, true}},
  79. {prim::kPrimArrayLen, {InferImplArrayLen, true}},
  80. {prim::kPrimListMap, {InferImplListMap, false}},
  81. {prim::kPrimListReduce, {InferImplListReduce, false}},
  82. {prim::kPrimTupleReversed, {InferImplTupleReversed, false}},
  83. {prim::kPrimReducedShape, {InferImplReduceShape, false}},
  84. {prim::kPrimTupleDiv, {InferImplTupleDiv, false}},
  85. {prim::kPrimTupleToArray, {InferImplTuple2Array, false}},
  86. {prim::kPrimShapeMul, {InferImplShapeMul, false}},
  87. {prim::kPrimTupleEqual, {InferImplTupleEqual, false}},
  88. {prim::kPrimListEqual, {InferImplListEqual, false}},
  89. {prim::kPrimMakeRange, {InferImplMakeRange, false}},
  90. {prim::kPrimStopGradient, {InferImplStopGradient, false}},
  91. {prim::kPrimStringEqual, {InferImplStringEqual, false}},
  92. {prim::kPrimStringConcat, {InferImplStringConcat, false}},
  93. {prim::kPrimDictLen, {InferImplDictLen, false}},
  94. // NN
  95. {prim::kPrimPooling, {InferImplPooling, true}},
  96. {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
  97. {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
  98. {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
  99. {prim::kPrimReluGrad, {InferImplReluGrad, true}},
  100. {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
  101. {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
  102. {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
  103. {prim::kPrimRelu, {InferImplRelu, true}},
  104. {prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
  105. {prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
  106. {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
  107. {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
  108. {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
  109. // Others
  110. {prim::kPrimIdentity, {InferImplIdentity, true}},
  111. // Set impl to null as it will use PartialEvaluator;
  112. {prim::kPrimPartial, {nullptr, true}},
  113. {prim::kPrimJ, {InferImplJ, false}},
  114. {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}},
  115. {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}},
  116. {prim::kPrimEnvAdd, {InferImplEnvAdd, true}},
  117. {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}},
  118. {prim::kPrimMakeRef, {InferImplMakeRef, true}},
  119. {prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
  120. {prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
  121. {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}},
  122. {prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
  123. {prim::kPrimDepend, {InferImplDepend, true}},
  124. {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}},
  125. {prim::kPrimControlDepend, {InferImplControlDepend, true}},
  126. // Debug
  127. {prim::kPrimScalarSummary, {InferImplScalarSummary, true}},
  128. {prim::kPrimImageSummary, {InferImplTensorSummary, true}},
  129. {prim::kPrimTensorSummary, {InferImplTensorSummary, true}},
  130. {prim::kPrimHistogramSummary, {InferImplTensorSummary, true}},
  131. };
  132. return prim_eval_implement_map;
  133. }
  134. using mindspore::parse::PyObjectWrapper;
  135. AbstractBasePtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
  136. AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
  137. return abs_base;
  138. }
  139. AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  140. AnfNodeConfigPtr out_conf) {
  141. AbstractBasePtrList args_spec_list;
  142. if (!prim_->isa<prim::DoSignaturePrimitive>()) {
  143. MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString();
  144. }
  145. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  146. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  147. }
  148. auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim_);
  149. auto out_node = dyn_cast<CNode>(out_conf->node());
  150. const auto &out_node_inputs = out_node->inputs();
  151. if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  152. MS_LOG(EXCEPTION) << "Op: " << do_signature->function()->ToString()
  153. << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
  154. << ", inputs size " << out_node_inputs.size();
  155. }
  156. AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
  157. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  158. [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
  159. ScopePtr scope = kDefaultScope;
  160. if (out_conf != nullptr) {
  161. scope = out_conf->node()->scope();
  162. }
  163. ScopeGuard scope_guard(scope);
  164. AnfNodePtr new_cnode = nullptr;
  165. if (bound_node() != nullptr) {
  166. TraceManager::DebugTrace(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
  167. new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
  168. args_inputs);
  169. TraceManager::EndTrace();
  170. } else {
  171. new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
  172. args_inputs);
  173. }
  174. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
  175. return engine->ForwardConfig(out_conf, fn_conf);
  176. }
  177. static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
  178. // arg[0] is the func graph to unpack, ignore it
  179. AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
  180. AbstractBasePtrList graph_specialize_args;
  181. if (need_unpack) {
  182. for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
  183. MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
  184. if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
  185. AbstractTuplePtr arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
  186. std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
  187. std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
  188. } else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
  189. AbstractDictionaryPtr arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
  190. auto dict_elems = arg_dict->elements();
  191. (void)std::transform(
  192. dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
  193. [](const AbstractAttribute &item) { return std::make_shared<AbstractKeywordArg>(item.first, item.second); });
  194. } else {
  195. MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got "
  196. << specialize_args_before_unpack[index]->ToString();
  197. }
  198. }
  199. } else {
  200. graph_specialize_args = specialize_args_before_unpack;
  201. }
  202. return graph_specialize_args;
  203. }
  204. AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  205. AnfNodeConfigPtr out_conf) {
  206. if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
  207. MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  208. }
  209. if (!prim_->isa<prim::UnpackGraphPrimitive>()) {
  210. MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString();
  211. }
  212. auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
  213. auto out_node = out_conf->node()->cast<CNodePtr>();
  214. const auto &out_node_inputs = out_node->inputs();
  215. if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
  216. MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
  217. << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
  218. << ", inputs size " << out_node_inputs.size();
  219. }
  220. AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
  221. AbstractBasePtrList args_spec_list;
  222. (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
  223. [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
  224. // get the forward graph
  225. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  226. AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
  227. if (fn == nullptr) {
  228. MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
  229. }
  230. auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
  231. MS_EXCEPTION_IF_NULL(real_fn);
  232. FuncGraphPtr forward_graph = real_fn->func_graph();
  233. MS_EXCEPTION_IF_NULL(forward_graph);
  234. AbstractBasePtrList graph_specialize_args =
  235. GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
  236. AbstractBasePtrList graph_specialize_args_without_sens;
  237. (void)std::transform(graph_specialize_args.begin(),
  238. graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
  239. std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
  240. auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens);
  241. engine->func_graph_manager()->AddFuncGraph(new_graph);
  242. ScopePtr scope = kDefaultScope;
  243. if (out_conf != nullptr) {
  244. scope = out_conf->node()->scope();
  245. }
  246. ScopeGuard scope_guard(scope);
  247. AnfNodePtr new_vnode = NewValueNode(new_graph);
  248. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context());
  249. return engine->ForwardConfig(out_conf, fn_conf);
  250. }
  251. namespace {
  252. py::object BuildValue(const ValuePtr &value_ptr) {
  253. if (value_ptr == nullptr) {
  254. return py::none();
  255. } else {
  256. return ValuePtrToPyData(value_ptr);
  257. }
  258. }
  259. } // end anonymous namespace
  260. py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
  261. MS_EXCEPTION_IF_NULL(abs_base);
  262. py::dict dic;
  263. if (abs_base->isa<AbstractTensor>()) {
  264. auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
  265. dic["shape"] = arg_tensor->shape()->shape();
  266. dic["dtype"] = arg_tensor->BuildType();
  267. dic["value"] = BuildValue(arg_tensor->BuildValue());
  268. } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
  269. std::vector<int> shape;
  270. dic["shape"] = shape;
  271. dic["dtype"] = abs_base->BuildType();
  272. dic["value"] = BuildValue(abs_base->BuildValue());
  273. } else if (abs_base->isa<AbstractSlice>()) {
  274. auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
  275. std::vector<int> shape;
  276. dic["shape"] = shape;
  277. dic["dtype"] = arg_slice->BuildType();
  278. dic["value"] = BuildValue(arg_slice->BuildValue());
  279. } else if (abs_base->isa<AbstractRef>()) {
  280. auto value = abs_base->cast<AbstractRefPtr>()->ref();
  281. dic = ConvertAbstractToPython(value);
  282. } else if (abs_base->isa<AbstractTuple>()) {
  283. auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
  284. size_t len = arg_tuple->size();
  285. py::tuple shape_tuple(len);
  286. py::tuple dtype_tuple(len);
  287. for (size_t i = 0; i < len; i++) {
  288. py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
  289. shape_tuple[i] = out["shape"];
  290. dtype_tuple[i] = out["dtype"];
  291. }
  292. dic["shape"] = shape_tuple;
  293. dic["dtype"] = dtype_tuple;
  294. dic["value"] = BuildValue(arg_tuple->BuildValue());
  295. } else if (abs_base->isa<AbstractList>()) {
  296. auto arg_list = dyn_cast<AbstractList>(abs_base);
  297. size_t len = arg_list->size();
  298. py::list shape_list(len);
  299. py::list dtype_list(len);
  300. for (size_t i = 0; i < len; i++) {
  301. py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
  302. shape_list[i] = out["shape"];
  303. dtype_list[i] = out["dtype"];
  304. }
  305. dic["shape"] = shape_list;
  306. dic["dtype"] = dtype_list;
  307. dic["value"] = BuildValue(arg_list->BuildValue());
  308. } else if (abs_base->isa<AbstractNone>()) {
  309. dic["shape"] = py::none();
  310. dic["dtype"] = py::none();
  311. dic["value"] = py::none();
  312. } else if (abs_base->isa<AbstractFunction>()) {
  313. dic["shape"] = py::none();
  314. dic["dtype"] = abs_base->BuildType();
  315. dic["value"] = py::none();
  316. } else {
  317. auto value = abs_base->BuildValue();
  318. if ((*value == *kAnyValue)) {
  319. auto value_desc = abs_base->value_desc();
  320. MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
  321. << " for python primitive.";
  322. }
  323. MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
  324. << value->ToString();
  325. }
  326. return dic;
  327. }
  328. namespace {
  329. py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrList &args) {
  330. const AbstractBasePtrList *args_ptr;
  331. if (prim_py->is_tuple_input_) {
  332. if (args.empty()) {
  333. MS_LOG(EXCEPTION) << "Primitive args is empty";
  334. }
  335. if (args[0] == nullptr || !args[0]->isa<AbstractTuple>()) {
  336. MS_LOG(EXCEPTION) << "Custom Primitive inputs should be packed into a Tuple after converting"
  337. "prim convert pass for GE.";
  338. }
  339. args_ptr = &(args[0]->cast<AbstractTuplePtr>()->elements());
  340. } else {
  341. args_ptr = &args;
  342. }
  343. py::tuple py_args(args_ptr->size());
  344. for (size_t i = 0; i < args_ptr->size(); i++) {
  345. auto arg_i = (*args_ptr)[i];
  346. py_args[i] = ConvertAbstractToPython(arg_i);
  347. }
  348. return py_args;
  349. }
  350. AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
  351. // Convert to AbstractValue based on type and shape
  352. if (output["value"].is_none()) {
  353. auto out_shape = output["shape"];
  354. auto out_dtype = output["dtype"];
  355. return PyListDtype2AbstractTensor(out_shape, out_dtype);
  356. }
  357. // Convert pyobject to Value, then to AbstractValue
  358. ValuePtr converted_ret = nullptr;
  359. bool converted = parse::ConvertData(output["value"], &converted_ret);
  360. if (!converted) {
  361. MS_LOG(EXCEPTION) << "Convert data failed";
  362. }
  363. auto res_spec = FromValue(converted_ret);
  364. MS_EXCEPTION_IF_NULL(res_spec);
  365. if (res_spec->isa<AbstractTensor>()) {
  366. // Replace to tensor constant node in specialize
  367. auto res_tensor = res_spec->cast<AbstractTensorPtr>();
  368. res_tensor->set_value(converted_ret);
  369. }
  370. if (prim_py->IsCustomPrim()) {
  371. // Raise error if output_num is not match the infer result.
  372. int output_num = GetValue<int>(prim_py->GetAttr("output_num"));
  373. if (res_spec->isa<AbstractTensor>() && output_num != 1) {
  374. MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
  375. << " not matches the infer result.";
  376. } else if (res_spec->isa<AbstractTuple>() &&
  377. (res_spec->cast<AbstractTuplePtr>()->size() != IntToSize(output_num))) {
  378. MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
  379. << " not matches the infer result.";
  380. }
  381. }
  382. return res_spec;
  383. }
  384. } // end anonymous namespace
  385. AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  386. MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
  387. const auto &iter = cache_->find(args);
  388. if (iter != cache_->end()) {
  389. return iter->second;
  390. }
  391. auto py_args = PreparePyInputs(prim_py_, args);
  392. auto pyobj = prim_py_->GetPyObj();
  393. if (pyobj == nullptr) {
  394. MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty";
  395. }
  396. auto infer_fuc = pyobj.attr("__infer__");
  397. py::dict output = infer_fuc(*py_args);
  398. MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
  399. auto res_spec = PyInferRes2Abstract(prim_py_, output);
  400. MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
  401. (*cache_)[args] = res_spec;
  402. return res_spec;
  403. }
  404. AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  405. // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
  406. if (nargs_ != args.size()) {
  407. MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
  408. return nullptr;
  409. }
  410. TypePtr ret_value_type = return_value_type_;
  411. ValuePtrList value_list;
  412. for (const auto &arg : args) {
  413. // Check if all arguments are scalar type.
  414. MS_EXCEPTION_IF_NULL(arg);
  415. if (arg->isa<AbstractScalar>()) {
  416. auto arg_scalar = dyn_cast<AbstractScalar>(arg);
  417. auto arg_value = arg_scalar->GetValueTrack();
  418. value_list.push_back(arg_value);
  419. } else {
  420. // Raise TypeError Expected Scalar.
  421. MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives.";
  422. }
  423. }
  424. for (const auto &item : type_map_) {
  425. TypePtrList selections;
  426. MS_EXCEPTION_IF_NULL(item.second);
  427. (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
  428. [&args](size_t arg_idx) -> TypePtr { return args[arg_idx]->GetTypeTrack(); });
  429. TypePtr res = CheckTypeList(item.first, selections);
  430. if (*return_value_type_ == *(item.first)) {
  431. ret_value_type = res;
  432. }
  433. }
  434. ValuePtr inferred_value = RunImpl(value_list);
  435. if (!(*inferred_value == *kAnyValue)) {
  436. ret_value_type = inferred_value->type();
  437. }
  438. // for comparison primitives , return type shall have be specified to be bool.
  439. if (specify_out_type_ != nullptr) {
  440. ret_value_type = specify_out_type_;
  441. }
  442. AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(inferred_value, ret_value_type);
  443. return abs_base;
  444. }
  445. ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
  446. if (!eval_value_) {
  447. return kAnyValue;
  448. } else {
  449. if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) {
  450. MS_EXCEPTION_IF_NULL(arg);
  451. return arg->isa<AnyValue>();
  452. })) {
  453. return kAnyValue;
  454. }
  455. return impl_(args);
  456. }
  457. }
  458. // Primitive implementation
  459. // static function start
  460. namespace {
  461. EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) {
  462. EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
  463. return prim_evaluator;
  464. }
  465. EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value,
  466. const TypePtr &specify_out_type) {
  467. FunctionPtr func = nullptr;
  468. (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func);
  469. MS_EXCEPTION_IF_NULL(func);
  470. EvaluatorPtr uniform_primitive_evaluator =
  471. std::make_shared<UniformPrimEvaluator>(func, prim_impl, eval_value, specify_out_type);
  472. return uniform_primitive_evaluator;
  473. }
  // Result codes returned by GetResolveCase.
  constexpr int kResolveCaseUserDefineClass = 1;    // data type is a class object type
  constexpr int kResolveCaseBuildinTypeMethod = 2;  // data type is found in the method map
  constexpr int kResolveCaseFunction = 3;           // every other data type
  477. int GetResolveCase(const TypePtr &data_type) {
  478. MS_EXCEPTION_IF_NULL(data_type);
  479. if (data_type->type_id() == kObjectTypeClass) {
  480. return kResolveCaseUserDefineClass;
  481. }
  482. // try method map, if not in method map, the data_type should be External type.
  483. if (pipeline::Resource::IsTypeInMethodMap(data_type->type_id())) {
  484. return kResolveCaseBuildinTypeMethod;
  485. }
  486. return kResolveCaseFunction;
  487. }
  488. FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) {
  489. MS_EXCEPTION_IF_NULL(engine);
  490. MS_EXCEPTION_IF_NULL(method);
  491. if (!method->isa<parse::PyObjectWrapper>()) {
  492. MS_LOG(EXCEPTION) << "Method type error: " << method->ToString();
  493. }
  494. std::shared_ptr<PyObjectWrapper> obj = method->cast<std::shared_ptr<PyObjectWrapper>>();
  495. FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj());
  496. if (func_graph == nullptr) {
  497. MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed";
  498. }
  499. FuncGraphManagerPtr manager = engine->func_graph_manager();
  500. manager->AddFuncGraph(func_graph);
  501. return func_graph;
  502. }
  503. inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) {
  504. MS_EXCEPTION_IF_NULL(engine);
  505. FuncGraphManagerPtr manager = engine->func_graph_manager();
  506. manager->AddFuncGraph(func_graph);
  507. }
  508. AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
  509. const AnfNodeConfigPtr &old_conf) {
  510. MS_EXCEPTION_IF_NULL(old_conf);
  511. AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
  512. AbstractFunctionPtr abs_func = dyn_cast<abstract::AbstractFunction>(abs_ptr);
  513. MS_EXCEPTION_IF_NULL(abs_func);
  514. // Create new cnode
  515. std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
  516. auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abs_func);
  517. if (func_graph_func != nullptr) {
  518. FuncGraphPtr fg = func_graph_func->func_graph();
  519. input.push_back(NewValueNode(fg));
  520. } else {
  521. auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abs_func);
  522. MS_EXCEPTION_IF_NULL(prim_func);
  523. PrimitivePtr prim = prim_func->prim();
  524. input.push_back(NewValueNode(prim));
  525. }
  526. AnfNodeConfigPtr conf = dyn_cast<abstract::AnfNodeConfig>(data_conf);
  527. MS_EXCEPTION_IF_NULL(conf);
  528. input.push_back(conf->node());
  529. MS_EXCEPTION_IF_NULL(old_conf);
  530. FuncGraphPtr func_graph = old_conf->node()->func_graph();
  531. CNodePtr new_cnode = func_graph->NewCNode(input);
  532. AnalysisEnginePtr eng = old_conf->engine();
  533. AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context());
  534. return eng->ForwardConfig(old_conf, fn_conf);
  535. }
  536. AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
  537. const AbstractBasePtrList &args_spec_list,
  538. const AnfNodeConfigPtr &out_conf) {
  539. // args_spec_list: same as StaticGetter
  540. if (args_spec_list.size() < 2) {
  541. MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
  542. }
  543. MS_EXCEPTION_IF_NULL(out_conf);
  544. // An external type.
  545. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  546. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  547. MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
  548. MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
  549. auto data_v = args_spec_list[0]->BuildValue();
  550. if (!data_v->isa<parse::NameSpace>()) {
  551. MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString();
  552. }
  553. auto item_v = args_spec_list[1]->BuildValue();
  554. if (item_v->isa<StringImm>()) {
  555. item_v = std::make_shared<parse::Symbol>(item_v->cast<StringImmPtr>()->value());
  556. }
  557. if (!item_v->isa<parse::Symbol>()) {
  558. MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_v->ToString();
  559. }
  560. // item_name to func addr from obj_map
  561. parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
  562. parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
  563. FuncGraphPtr func_graph = out_conf->node()->func_graph();
  564. auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node());
  565. if (new_node == nullptr) {
  566. MS_LOG(EXCEPTION) << "Resolve node failed";
  567. }
  568. AnalysisEnginePtr eng = out_conf->engine();
  569. AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context());
  570. return eng->ForwardConfig(out_conf, fn_conf);
  571. }
  572. AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
  573. const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
  574. const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
  575. if (args_spec_list.empty()) {
  576. MS_LOG(EXCEPTION) << "args_spec_list is empty";
  577. }
  578. AbstractClassPtr cls = CheckArg<AbstractClass>("__FUNC__", args_spec_list, 0);
  579. // If item_v is an attribute, get abstract value from AbstractClass
  580. MS_EXCEPTION_IF_NULL(item_v);
  581. if (!item_v->isa<StringImm>()) {
  582. MS_LOG(EXCEPTION) << "Attribute type error";
  583. }
  584. std::string item_name = item_v->cast<StringImmPtr>()->value();
  585. MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
  586. MS_LOG(DEBUG) << "Resolve item: " << item_name;
  587. AbstractBasePtr attr = cls->GetAttribute(item_name);
  588. if (attr != nullptr) {
  589. return attr;
  590. }
  591. ValuePtr method = cls->GetMethod(item_name);
  592. if (method->isa<AnyValue>()) {
  593. MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
  594. << ", item value: " << item_v->ToString();
  595. }
  596. // Infer class method
  597. ValuePtr converted_v = PyObjToGraph(engine, method);
  598. return StaticGetterInferred(converted_v, data_conf, out_conf);
  599. }
  600. AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
  601. const TypePtr &data_type, const ConfigPtr &data_conf,
  602. const AnfNodeConfigPtr &out_conf) {
  603. MS_EXCEPTION_IF_NULL(item_v);
  604. MS_EXCEPTION_IF_NULL(data_type);
  605. // The method maybe a Primitive or Composite
  606. if (!item_v->isa<StringImm>()) {
  607. MS_LOG(EXCEPTION) << "Error item is not string";
  608. }
  609. std::string item_name = item_v->cast<StringImmPtr>()->value();
  610. Any method = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
  611. if (method.empty()) {
  612. MS_LOG(EXCEPTION) << "Object type: " << data_type->ToString() << " has no method: " << item_name;
  613. }
  614. ValuePtr converted_v = nullptr;
  615. if (method.is<std::string>()) {
  616. // composite registered in standard_method_map go to this branch
  617. converted_v = prim::GetPythonOps(method.cast<std::string>());
  618. AddToManager(engine, converted_v->cast<FuncGraphPtr>());
  619. } else if (method.is<PrimitivePtr>()) {
  620. converted_v = method.cast<PrimitivePtr>();
  621. } else {
  622. MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from method map, but got " << method.ToString();
  623. }
  624. return StaticGetterInferred(converted_v, data_conf, out_conf);
  625. }
  626. AbstractBasePtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  627. const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
  628. // Inputs: namespace and its static function; or class and its member function
  629. CheckArgsSize("StaticGetter", args_spec_list, 2);
  630. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  631. MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  632. TypePtr data_type = args_spec_list[0]->BuildType();
  633. ValuePtr item_value = args_spec_list[1]->BuildValue();
  634. ScopePtr scope = kDefaultScope;
  635. if (out_conf != nullptr) {
  636. scope = out_conf->node()->scope();
  637. }
  638. ScopeGuard scope_guard(scope);
  639. if (item_value->isa<AnyValue>()) {
  640. MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
  641. }
  642. int case_v = GetResolveCase(data_type);
  643. if (case_v == kResolveCaseUserDefineClass) {
  644. return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
  645. } else if (case_v == kResolveCaseBuildinTypeMethod) {
  646. return GetEvaluatedValueForBuiltinTypeMethod(engine, item_value, data_type, data_conf, out_conf);
  647. } else {
  648. return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
  649. }
  650. }
  651. } // end anonymous namespace
  652. // static variable start;
  653. namespace {
  654. class EmbedEvaluator : public SymbolicPrimEvaluator {
  655. public:
  656. EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
  657. ~EmbedEvaluator() override = default;
  658. MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
  659. AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override {
  660. // arg: free variable to be embedded
  661. if (args_conf_list.size() != 1) {
  662. MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
  663. }
  664. AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
  665. MS_EXCEPTION_IF_NULL(node_conf);
  666. AbstractBasePtr x = node_conf->GetEvaluatedValue();
  667. x = SensitivityTransform(x);
  668. SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
  669. AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
  670. return abs_scalar;
  671. }
  672. };
  673. static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
  674. auto root_g_set = manager->roots();
  675. if (root_g_set.size() != 1) {
  676. return nullptr;
  677. }
  678. const FuncGraphPtr &root_g = root_g_set.back();
  679. for (auto &param_node : root_g->parameters()) {
  680. auto param = param_node->cast<ParameterPtr>();
  681. if (param && name == param->name()) {
  682. return param;
  683. }
  684. }
  685. return nullptr;
  686. }
  687. class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
  688. public:
  689. RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
  690. ~RefToEmbedEvaluator() override = default;
  691. MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
  692. AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override {
  693. if (args_conf_list.size() != 1) {
  694. MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
  695. return nullptr;
  696. }
  697. static TypePtr type = std::make_shared<SymbolicKeyType>();
  698. auto node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
  699. if (node_conf == nullptr) {
  700. MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
  701. return nullptr;
  702. }
  703. AbstractBasePtr abs = node_conf->GetEvaluatedValue();
  704. AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
  705. if (ref_abs == nullptr) {
  706. MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref.";
  707. return nullptr;
  708. }
  709. auto key_abs = ref_abs->ref_key();
  710. if (key_abs == nullptr) {
  711. MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr.";
  712. return nullptr;
  713. }
  714. auto key_value = key_abs->BuildValue();
  715. if (key_value == nullptr) {
  716. MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
  717. return nullptr;
  718. }
  719. auto refkey = key_value->cast<RefKeyPtr>();
  720. if (refkey == nullptr) {
  721. return std::make_shared<AbstractScalar>(type);
  722. }
  723. std::string name = refkey->tag();
  724. const auto &manager = node_conf->node()->func_graph()->manager();
  725. auto node = FindParameterNodeByString(manager, name);
  726. if (node == nullptr) {
  727. MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
  728. return nullptr;
  729. }
  730. AbstractBasePtr x = ref_abs->ref();
  731. x = SensitivityTransform(x);
  732. std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
  733. std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
  734. return abs_scalar;
  735. }
  736. };
  737. class GetAttrEvaluator : public TransitionPrimEvaluator {
  738. public:
  739. GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
  740. ~GetAttrEvaluator() override = default;
  741. MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
  742. AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  743. const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
  744. // Inputs: data, item
  745. if (args_spec_list.size() != 2) {
  746. MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
  747. }
  748. AbstractBasePtr ret = nullptr;
  749. if (bound_node() != nullptr) {
  750. TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
  751. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  752. TraceManager::EndTrace();
  753. } else {
  754. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  755. }
  756. // don't lookup from cache, as different out_conf with same node but different context
  757. // may add different entry to anfnode_config_map, like getattr primitive;
  758. (*cache_)[args_spec_list] = ret;
  759. return ret;
  760. }
  761. };
  762. class ResolveEvaluator : public TransitionPrimEvaluator {
  763. public:
  764. ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
  765. ~ResolveEvaluator() override = default;
  766. MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
  767. AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  768. const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
  769. // Inputs: namespace, symbol
  770. if (args_spec_list.size() != 2) {
  771. MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
  772. }
  773. AbstractBasePtr ret = nullptr;
  774. if (bound_node() != nullptr) {
  775. TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
  776. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  777. TraceManager::EndTrace();
  778. } else {
  779. ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
  780. }
  781. return ret;
  782. }
  783. };
  784. class CreateInstanceEvaluator : public TransitionPrimEvaluator {
  785. public:
  786. CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
  787. ~CreateInstanceEvaluator() override = default;
  788. MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
  789. AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  790. const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override {
  791. if (args_spec_list.empty()) {
  792. MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
  793. }
  794. // get the type parameter
  795. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  796. TypePtr type = args_spec_list[0]->GetTypeTrack();
  797. if (type->type_id() != kMetaTypeTypeType) {
  798. MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
  799. << type->ToString();
  800. }
  801. ValuePtr value_track = args_spec_list[0]->GetValueTrack();
  802. MS_EXCEPTION_IF_NULL(value_track);
  803. std::shared_ptr<parse::PyObjectWrapper> type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
  804. if (type_obj == nullptr) {
  805. MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
  806. }
  807. if (!type_obj->isa<parse::ClassType>()) {
  808. MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got "
  809. << type_obj->ToString() << ".";
  810. }
  811. auto class_type = type_obj->obj();
  812. MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
  813. // get the create instance obj's parameters
  814. pybind11::tuple params = GetParameters(args_spec_list);
  815. // create class instance
  816. auto obj = parse::data_converter::CreatePythonObject(class_type, params);
  817. if (py::isinstance<py::none>(obj)) {
  818. MS_LOG(EXCEPTION) << "Create python object failed, only support Cell and Primitive type";
  819. }
  820. // process the object
  821. ValuePtr converted_ret = nullptr;
  822. bool converted = parse::ConvertData(obj, &converted_ret, true);
  823. if (!converted) {
  824. MS_LOG(EXCEPTION) << "Convert the python object failed";
  825. }
  826. MS_EXCEPTION_IF_NULL(converted_ret);
  827. if (converted_ret->isa<FuncGraph>()) {
  828. AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
  829. }
  830. AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
  831. (*cache_)[args_spec_list] = ret;
  832. return ret;
  833. }
  834. pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
  835. // Exclude class type by minus 1;
  836. std::size_t params_size = args_spec_list.size() - 1;
  837. auto params = py::tuple(params_size);
  838. if (params_size > 0) {
  839. for (size_t i = 0; i < params_size; i++) {
  840. // Only support the Scalar parameters type. Bypass class type by offset with 1.
  841. auto arg = args_spec_list[i + 1];
  842. MS_EXCEPTION_IF_NULL(arg);
  843. // Because the Tensor's AbstractTensor can't get value from GetValueTrack.
  844. ValuePtr param_value = arg->BuildValue();
  845. py::object param = ValuePtrToPyData(param_value);
  846. params[i] = param;
  847. }
  848. }
  849. return params;
  850. }
  851. };
  852. class PartialEvaluator : public Evaluator {
  853. public:
  854. PartialEvaluator() : Evaluator("PartialEvaluator") {}
  855. ~PartialEvaluator() override = default;
  856. AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  857. AnfNodeConfigPtr out_conf = nullptr) override {
  858. if (args_conf_list.size() == 0) {
  859. MS_LOG(EXCEPTION) << "Args size should be greater than 0";
  860. }
  861. auto arg0_value = args_conf_list[0]->GetEvaluatedValue();
  862. AbstractBasePtrList args_spec_list{arg0_value};
  863. auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
  864. // Sometimes, node[0] in out_conf becomes phi0;
  865. if (func->isa<PrimitiveAbstractClosure>()) {
  866. auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
  867. if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
  868. prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(prim_func->prim());
  869. return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
  870. }
  871. }
  872. (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
  873. [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
  874. AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
  875. AbstractFuncAtomPtrList partialPtrList;
  876. auto build_partial = [args, &partialPtrList](const AbstractFuncAtomPtr &atom_func) {
  877. auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args);
  878. partialPtrList.push_back(new_func);
  879. };
  880. func->Visit(build_partial);
  881. auto ret = AbstractFunction::MakeAbstractFunction(partialPtrList);
  882. (*cache_)[args_spec_list] = ret;
  883. return ret;
  884. }
  885. AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  886. MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called";
  887. }
  888. AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
  889. const AnfNodeConfigPtr &out_conf = nullptr) const {
  890. MS_EXCEPTION_IF_NULL(out_conf);
  891. MS_EXCEPTION_IF_NULL(out_conf->node());
  892. auto cnode = out_conf->node()->cast<CNodePtr>();
  893. if (cnode == nullptr) {
  894. MS_LOG(EXCEPTION) << "Cnode is nullptr";
  895. }
  896. std::vector<AnfNodePtr> new_nodes_inputs = cnode->inputs();
  897. auto new_signature_value = std::make_shared<prim::DoSignatureMetaFuncGraph>("signature", signature_value);
  898. new_nodes_inputs[1] = NewValueNode(new_signature_value);
  899. FuncGraphPtr func_graph = cnode->func_graph();
  900. ScopePtr scope = out_conf->node()->scope();
  901. ScopeGuard scope_guard(scope);
  902. CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
  903. AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
  904. return engine->ForwardConfig(out_conf, fn_conf);
  905. }
  906. };
// Registration record for one primitive handled by a uniform primitive evaluator.
// Instances are brace-initialized in member order by the table in GetUniformPrimitiveToImplMap().
struct PrimitiveImplInferValue {
  PrimitiveImpl impl_;        // implement function of primitive
  bool eval_value_;           // whether evaluate value
  TypePtr specify_out_type_;  // explicitly specified return type; the table passes nullptr where none is given
  bool in_white_list_;        // true if this Primitive in white list, else false.
};
// Maps a primitive to its registration record, hashed with PrimitiveHasher and compared with PrimitiveEqual.
using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
  914. PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
  915. static PrimitiveToImplMap uniform_prim_implement_map = {
  916. {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
  917. {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
  918. {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
  919. {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
  920. {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
  921. {prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}},
  922. {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}},
  923. {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
  924. {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
  925. {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
  926. {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared<Bool>(), true}},
  927. {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared<Bool>(), true}},
  928. {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared<Bool>(), true}},
  929. {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared<Bool>(), true}},
  930. {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared<Bool>(), true}},
  931. {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared<Bool>(), true}},
  932. {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared<Bool>(), true}},
  933. {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
  934. {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
  935. {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
  936. };
  937. return uniform_prim_implement_map;
  938. }
// Registry from primitive to its evaluator. Filled by InitPrimEvaluatorConstructors() and emptied by
// ClearPrimEvaluatorMap().
PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
// Locked by GetPrimEvaluatorConstructors() around the initialization of the registry.
std::mutex PrimEvaluatorConstructorMutex;
  941. void InitPrimEvaluatorConstructors() {
  942. PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
  943. for (const auto &iter : GetPrimitiveToEvalImplMap()) {
  944. constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_);
  945. }
  946. for (const auto &iter : GetUniformPrimitiveToImplMap()) {
  947. constructor[iter.first] =
  948. InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
  949. }
  950. constructor[prim::kPrimEmbed] = std::make_shared<EmbedEvaluator>();
  951. constructor[prim::kPrimRefToEmbed] = std::make_shared<RefToEmbedEvaluator>();
  952. constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
  953. constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
  954. constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
  955. constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
  956. }
  957. } // namespace
  958. void ClearPrimEvaluatorMap() {
  959. PrimEvaluatorConstructors.clear();
  960. GetPrimitiveToEvalImplMap().clear();
  961. GetUniformPrimitiveToImplMap().clear();
  962. }
  963. bool IsInWhiteList(const PrimitivePtr primitive) {
  964. MS_EXCEPTION_IF_NULL(primitive);
  965. auto iter = GetPrimitiveToEvalImplMap().find(primitive);
  966. if (iter != GetPrimitiveToEvalImplMap().end()) {
  967. return iter->second.in_white_list_;
  968. }
  969. auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
  970. if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
  971. return uni_iter->second.in_white_list_;
  972. }
  973. return false;
  974. }
  975. StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
  976. MS_EXCEPTION_IF_NULL(primitive);
  977. auto iter = GetPrimitiveToEvalImplMap().find(primitive);
  978. if (iter == GetPrimitiveToEvalImplMap().end()) {
  979. return nullptr;
  980. }
  981. return iter->second.impl_;
  982. }
  983. PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
  984. PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
  985. if (!constructor.empty()) {
  986. return constructor;
  987. }
  988. std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
  989. if (constructor.empty()) {
  990. InitPrimEvaluatorConstructors();
  991. }
  992. return constructor;
  993. }
  994. namespace {
  995. bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
  996. MS_EXCEPTION_IF_NULL(x);
  997. MS_EXCEPTION_IF_NULL(model);
  998. auto x_tuple = dyn_cast<AbstractTuple>(x);
  999. auto model_tuple = dyn_cast<Tuple>(model);
  1000. if (x_tuple == nullptr || model_tuple == nullptr) {
  1001. return false;
  1002. }
  1003. if (model->IsGeneric()) {
  1004. return true;
  1005. }
  1006. if (x_tuple->size() != model_tuple->size()) {
  1007. return false;
  1008. }
  1009. for (size_t i = 0; i < x_tuple->size(); i++) {
  1010. bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]);
  1011. if (!is_subtype) {
  1012. return false;
  1013. }
  1014. }
  1015. return true;
  1016. }
  1017. bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
  1018. MS_EXCEPTION_IF_NULL(x);
  1019. MS_EXCEPTION_IF_NULL(model);
  1020. auto x_tensor = dyn_cast<AbstractTensor>(x);
  1021. auto model_tensor = dyn_cast<TensorType>(model);
  1022. if (x_tensor == nullptr || model_tensor == nullptr) {
  1023. return false;
  1024. }
  1025. if (model->IsGeneric()) {
  1026. return true;
  1027. }
  1028. return IsSubtype(x_tensor->element(), model_tensor->element());
  1029. }
  1030. bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
  1031. MS_EXCEPTION_IF_NULL(x);
  1032. MS_EXCEPTION_IF_NULL(model);
  1033. auto x_list = dyn_cast<AbstractList>(x);
  1034. auto model_list = dyn_cast<List>(model);
  1035. if (x_list == nullptr || model_list == nullptr) {
  1036. return false;
  1037. }
  1038. if (model->IsGeneric()) {
  1039. return true;
  1040. }
  1041. if (x_list->size() != model_list->size()) {
  1042. return false;
  1043. }
  1044. bool is_subtype = true;
  1045. for (size_t i = 0; i < x_list->size(); i++) {
  1046. is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]);
  1047. if (!is_subtype) {
  1048. return false;
  1049. }
  1050. }
  1051. return is_subtype;
  1052. }
  1053. bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) {
  1054. MS_EXCEPTION_IF_NULL(x);
  1055. MS_EXCEPTION_IF_NULL(model);
  1056. auto x_class = dyn_cast<AbstractClass>(x);
  1057. auto model_class = dyn_cast<Class>(model);
  1058. if (x_class == nullptr) {
  1059. return false;
  1060. }
  1061. if (model->IsGeneric()) {
  1062. return true;
  1063. }
  1064. if (x_class->tag() == model_class->tag()) {
  1065. auto m_attributes = model_class->GetAttributes();
  1066. auto x_attributes = x_class->attributes();
  1067. if (m_attributes.size() != x_attributes.size()) {
  1068. return false;
  1069. }
  1070. for (size_t i = 0; i < m_attributes.size(); i++) {
  1071. if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) {
  1072. return false;
  1073. }
  1074. }
  1075. return true;
  1076. }
  1077. return false;
  1078. }
  1079. inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
  1080. MS_EXCEPTION_IF_NULL(x);
  1081. MS_EXCEPTION_IF_NULL(model);
  1082. if (dyn_cast<AbstractScalar>(x) == nullptr) {
  1083. return false;
  1084. }
  1085. TypePtr x_type = x->GetTypeTrack();
  1086. return IsSubType(x_type, model);
  1087. }
  1088. } // namespace
  1089. bool IsSubtype(const AbstractBasePtr x, const TypePtr model) {
  1090. MS_EXCEPTION_IF_NULL(x);
  1091. MS_EXCEPTION_IF_NULL(model);
  1092. TypeId model_typeid = model->type_id();
  1093. switch (model_typeid) {
  1094. case kMetaTypeObject:
  1095. return true;
  1096. case kObjectTypeTuple:
  1097. return IsSubtypeTuple(x, model);
  1098. case kObjectTypeTensorType:
  1099. return IsSubtypeArray(x, model);
  1100. case kObjectTypeList:
  1101. return IsSubtypeList(x, model);
  1102. case kObjectTypeClass:
  1103. return IsSubtypeClass(x, model);
  1104. default:
  1105. if (IsSubType(model, std::make_shared<Number>())) {
  1106. return IsSubtypeScalar(x, model);
  1107. }
  1108. MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << ".";
  1109. }
  1110. }
  1111. } // namespace abstract
  1112. } // namespace mindspore