You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim.h 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef PIPELINE_STATIC_ANALYSIS_PRIM_H_
  19. #define PIPELINE_STATIC_ANALYSIS_PRIM_H_
  20. #include <algorithm>
  21. #include <memory>
  22. #include <string>
  23. #include <unordered_map>
  24. #include <vector>
  25. #include "pipeline/static_analysis/evaluator.h"
  26. namespace mindspore {
  27. namespace abstract {
  28. using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const AnalysisEnginePtr &, const PrimitivePtr &,
  29. const AbstractBasePtrList &);
  30. struct StandartPrimitiveImplReg {
  31. StandardPrimitiveEvalImpl impl_; // Implement function of Primitive.
  32. bool in_white_list_; // true if this Primitive in white list, else false.
  33. };
  34. using PrimitiveEvalImplMap =
  35. std::unordered_map<PrimitivePtr, StandartPrimitiveImplReg, PrimitiveHasher, PrimitiveEqual>;
  36. class StandardPrimEvaluator : public TrivialPrimEvaluator {
  37. public:
  38. StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl)
  39. : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {}
  40. ~StandardPrimEvaluator() override = default;
  41. MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator);
  42. AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
  43. PrimitivePtr prim() { return prim_; }
  44. std::string ToString() const override { return identifier_ + prim_->name(); }
  45. private:
  46. PrimitivePtr prim_;
  47. const StandardPrimitiveEvalImpl eval_impl_;
  48. };
  49. using StandardPrimEvaluatorPtr = std::shared_ptr<StandardPrimEvaluator>;
  50. class PythonPrimEvaluator : public TrivialPrimEvaluator {
  51. public:
  52. explicit PythonPrimEvaluator(const PrimitivePyPtr primitive)
  53. : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {}
  54. ~PythonPrimEvaluator() override = default;
  55. MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator);
  56. AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
  57. PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); }
  58. std::string ToString() const override { return identifier_ + prim_py_->name(); }
  59. private:
  60. PrimitivePyPtr prim_py_;
  61. };
  62. class DoSignatureEvaluator : public Evaluator {
  63. public:
  64. explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
  65. ~DoSignatureEvaluator() override = default;
  66. AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
  67. AnfNodeConfigPtr out_config = nullptr) override;
  68. AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  69. MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called";
  70. }
  71. private:
  72. PrimitivePtr prim_;
  73. };
  74. class UnpackGraphEvaluator : public Evaluator {
  75. public:
  76. explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
  77. ~UnpackGraphEvaluator() override = default;
  78. AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
  79. AnfNodeConfigPtr out_config = nullptr) override;
  80. AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  81. MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called";
  82. }
  83. private:
  84. PrimitivePtr prim_;
  85. };
  86. bool IsInWhiteList(PrimitivePtr primitive);
  87. StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
  88. using ValuePtrList = std::vector<ValuePtr>;
  89. using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &);
  90. class UniformPrimEvaluator : public TrivialPrimEvaluator {
  91. public:
  92. UniformPrimEvaluator(const FunctionPtr func_desc, PrimitiveImpl impl, bool eval_value, const TypePtr specify_out_type)
  93. : TrivialPrimEvaluator("UniformPrimEvaluator"),
  94. impl_(impl),
  95. eval_value_(eval_value),
  96. func_desc_(func_desc),
  97. nargs_(func_desc_->args().size()),
  98. return_value_type_(func_desc_->retval()),
  99. specify_out_type_(specify_out_type) {
  100. for (size_t i = 0; i < nargs_; ++i) {
  101. TypePtr type = func_desc_->args()[i];
  102. if (type_map_[type]) {
  103. type_map_[type]->push_back(i);
  104. } else {
  105. type_map_[type] = std::make_shared<std::vector<size_t>>();
  106. type_map_[type]->push_back(i);
  107. }
  108. }
  109. }
  110. ~UniformPrimEvaluator() override = default;
  111. MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);
  112. AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
  113. ValuePtr RunImpl(const ValuePtrList &args) const;
  114. // If eval_value_ is False, return broadened arguments.
  115. AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override {
  116. if (!eval_value_) {
  117. AbstractBasePtrList broadened_args_spec_list;
  118. (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened_args_spec_list),
  119. [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
  120. return broadened_args_spec_list;
  121. }
  122. return args_spec_list;
  123. }
  124. private:
  125. PrimitiveImpl impl_;
  126. bool eval_value_;
  127. const FunctionPtr func_desc_;
  128. const std::size_t nargs_;
  129. const TypePtr return_value_type_;
  130. const TypePtr specify_out_type_;
  131. std::unordered_map<TypePtr, std::shared_ptr<std::vector<size_t>>, TypeHasher, TypeEqual> type_map_;
  132. };
  133. PrimEvaluatorMap &GetPrimEvaluatorConstructors();
  134. // Check whether type x is a subtype of model.
  135. bool IsSubtype(const AbstractBasePtr x, const TypePtr model);
  136. void ClearPrimEvaluatorMap();
  137. py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base);
  138. AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
  139. const AbstractBasePtrList &args_spec_list);
  140. AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
  141. const AbstractBasePtrList &args_spec_list);
  142. AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  143. const AbstractBasePtrList &args_spec_list);
  144. AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  145. const AbstractBasePtrList &args_spec_list);
  146. AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
  147. const AbstractBasePtrList &args_spec_list);
  148. AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
  149. const AbstractBasePtrList &args_spec_list);
  150. AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
  151. const AbstractBasePtrList &args_spec_list);
  152. AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
  153. const AbstractBasePtrList &args_spec_list);
  154. AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
  155. const AbstractBasePtrList &args_spec_list);
  156. AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  157. const AbstractBasePtrList &args_spec_list);
  158. AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  159. const AbstractBasePtrList &args_spec_list);
  160. AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  161. const AbstractBasePtrList &args_spec_list);
  162. AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  163. const AbstractBasePtrList &args_spec_list);
  164. AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  165. const AbstractBasePtrList &args_spec_list);
  166. AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  167. const AbstractBasePtrList &args_spec_list);
  168. AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  169. const AbstractBasePtrList &args_spec_list);
  170. AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  171. const AbstractBasePtrList &args_spec_list);
  172. AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  173. const AbstractBasePtrList &args_spec_list);
  174. AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  175. const AbstractBasePtrList &args_spec_list);
  176. AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  177. const AbstractBasePtrList &args_spec_list);
  178. AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  179. const AbstractBasePtrList &args_spec_list);
  180. AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  181. const AbstractBasePtrList &args_spec_list);
  182. AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  183. const AbstractBasePtrList &args_spec_list);
  184. AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  185. const AbstractBasePtrList &args_spec_list);
  186. AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  187. const AbstractBasePtrList &args_spec_list);
  188. AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  189. const AbstractBasePtrList &args_spec_list);
  190. AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  191. const AbstractBasePtrList &args_spec_list);
  192. AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  193. const AbstractBasePtrList &args_spec_list);
  194. AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  195. const AbstractBasePtrList &args_spec_list);
  196. AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  197. const AbstractBasePtrList &args_spec_list);
  198. AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  199. const AbstractBasePtrList &args_spec_list);
  200. AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  201. const AbstractBasePtrList &args_spec_list);
  202. AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  203. const AbstractBasePtrList &args_spec_list);
  204. AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  205. const AbstractBasePtrList &args_spec_list);
  206. AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  207. const AbstractBasePtrList &args_spec_list);
  208. AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  209. const AbstractBasePtrList &args_spec_list);
  210. AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  211. const AbstractBasePtrList &args_spec_list);
  212. AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  213. const AbstractBasePtrList &args_spec_list);
  214. AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  215. const AbstractBasePtrList &args_spec_list);
  216. AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  217. const AbstractBasePtrList &args_spec_list);
  218. AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  219. const AbstractBasePtrList &args_spec_list);
  220. AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  221. const AbstractBasePtrList &args_spec_list);
  222. AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  223. const AbstractBasePtrList &args_spec_list);
  224. AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  225. const AbstractBasePtrList &args_spec_list);
  226. AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  227. const AbstractBasePtrList &args_spec_list);
  228. AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  229. const AbstractBasePtrList &args_spec_list);
  230. AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  231. const AbstractBasePtrList &args_spec_list);
  232. AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  233. const AbstractBasePtrList &args_spec_list);
  234. AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  235. const AbstractBasePtrList &args_spec_list);
  236. AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  237. const AbstractBasePtrList &args_spec_list);
  238. AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  239. const AbstractBasePtrList &args_spec_list);
  240. AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  241. const AbstractBasePtrList &args_spec_list);
  242. AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  243. const AbstractBasePtrList &args_spec_list);
  244. AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  245. const AbstractBasePtrList &args_spec_list);
  246. AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  247. const AbstractBasePtrList &args_spec_list);
  248. AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  249. const AbstractBasePtrList &args_spec_list);
  250. AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  251. const AbstractBasePtrList &args_spec_list);
  252. AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  253. const AbstractBasePtrList &args_spec_list);
  254. AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  255. const AbstractBasePtrList &args_spec_list);
  256. AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  257. const AbstractBasePtrList &args_spec_list);
  258. AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  259. const AbstractBasePtrList &args_spec_list);
  260. AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  261. const AbstractBasePtrList &args_spec_list);
  262. AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  263. const AbstractBasePtrList &args_spec_list);
  264. AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  265. const AbstractBasePtrList &args_spec_list);
  266. AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  267. const AbstractBasePtrList &args_spec_list);
  268. AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  269. const AbstractBasePtrList &args_spec_list);
  270. AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  271. const AbstractBasePtrList &args_spec_list);
  272. AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  273. const AbstractBasePtrList &args_spec_list);
  274. AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  275. const AbstractBasePtrList &args_spec_list);
  276. AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  277. const AbstractBasePtrList &args_spec_list);
  278. AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  279. const AbstractBasePtrList &args_spec_list);
  280. AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  281. const AbstractBasePtrList &args_spec_list);
  282. AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  283. const AbstractBasePtrList &args_spec_list);
  284. AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  285. const AbstractBasePtrList &args_spec_list);
  286. AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  287. const AbstractBasePtrList &args_spec_list);
  288. AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  289. const AbstractBasePtrList &args_spec_list);
  290. AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  291. const AbstractBasePtrList &args_spec_list);
  292. AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  293. const AbstractBasePtrList &args_spec_list);
  294. AbstractBasePtr InferImplScalarSummary(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  295. const AbstractBasePtrList &args_spec_list);
  296. AbstractBasePtr InferImplTensorSummary(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  297. const AbstractBasePtrList &args_spec_list);
  298. } // namespace abstract
  299. } // namespace mindspore
  300. #endif // PIPELINE_STATIC_ANALYSIS_PRIM_H_