You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.cc 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "operator/ops.h"
  17. #include <memory>
  18. #include <string>
  19. #include "pipeline/parse/python_adapter.h"
  20. #include "pipeline/parse/data_converter.h"
  21. namespace mindspore {
  22. // namespace to support primitive operators
  23. namespace prim {
  24. // Arithmetic
  25. const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
  26. const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub");
  27. const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul");
  28. const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div");
  29. const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv");
  30. const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod");
  31. const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow");
  32. const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc");
  33. const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor");
  34. const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd");
  35. const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub");
  36. const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp");
  37. const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log");
  38. const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
  39. const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos");
  40. const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan");
  41. // Comparisons
  42. const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq");
  43. const PrimitivePtr kPrimScalarLt = std::make_shared<Primitive>("scalar_lt");
  44. const PrimitivePtr kPrimScalarGt = std::make_shared<Primitive>("scalar_gt");
  45. const PrimitivePtr kPrimScalarNe = std::make_shared<Primitive>("scalar_ne");
  46. const PrimitivePtr kPrimScalarLe = std::make_shared<Primitive>("scalar_le");
  47. const PrimitivePtr kPrimScalarGe = std::make_shared<Primitive>("scalar_ge");
  48. const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not");
  49. const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and");
  50. const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or");
  51. const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq");
  52. // Type introspection
  53. const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
  54. const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
  55. // Statements
  56. const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
  57. const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
  58. const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
  59. const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
  60. const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
  61. const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
  62. const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
  63. const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
  64. const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
  65. const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im");
  66. const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1");
  67. const PrimitivePtr kPrimCol2ImV1 = std::make_shared<Primitive>("col2im_v1");
  68. const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
  69. const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
  70. const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
  71. const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
  72. // Structure
  73. const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
  74. const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
  75. const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
  76. const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
  77. const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict");
  78. const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg");
  79. const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg");
  80. const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice");
  81. const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
  82. const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem");
  83. const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem");
  84. const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem");
  85. const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem");
  86. const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
  87. const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem");
  88. const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
  89. const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
  90. const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
  91. const PrimitivePtr kPrimGetAttr = std::make_shared<Primitive>("getattr");
  92. const PrimitivePtr kPrimTupleLen = std::make_shared<Primitive>("tuple_len");
  93. const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
  94. const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
  95. const PrimitivePtr kPrimArrayLen = std::make_shared<Primitive>("array_len");
  96. const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
  97. const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
  98. const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
  99. const PrimitivePtr kPrimTileShape = std::make_shared<Primitive>("tile_shape");
  100. const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
  101. const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
  102. const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
  103. const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
  104. const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>("generate_shape_index");
  105. const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index");
  106. const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
  107. const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
  108. const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
  109. const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
  110. // Arrays
  111. const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
  112. const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar");
  113. const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape");
  114. const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
  115. const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
  116. const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
  117. const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
  118. const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
  119. const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
  120. const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
  121. const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
  122. const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
  123. const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
  124. const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
  125. const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
  126. const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
  127. const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
  128. const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
  129. const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
  130. const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
  131. // Maths
  132. const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
  133. const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
  134. const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
  135. const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
  136. const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad");
  137. const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean");
  138. const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum");
  139. const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll");
  140. const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax");
  141. const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin");
  142. const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
  143. const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub");
  144. const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
  145. const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
  146. const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
  147. const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
  148. const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
  149. const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
  150. const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
  151. // NN
  152. const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
  153. const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
  154. const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
  155. const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
  156. const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
  157. const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
  158. const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
  159. const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
  160. const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
  161. const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
  162. const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
  163. const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
  164. const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
  165. const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
  166. const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
  167. const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
  168. const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
  169. std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
  170. const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
  171. std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
  172. const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
  173. const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
  174. const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
  175. std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
  176. const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
  177. const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
  178. const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
  179. const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
  180. const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
  181. const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
  182. const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
  183. const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
  184. const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
  185. const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
  186. const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
  187. const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
  188. const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
  189. // Other miscellaneous
  190. const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
  191. const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("partial");
  192. const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
  193. const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
  194. const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
  195. const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
  196. const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
  197. const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
  198. const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
  199. const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
  200. const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
  201. const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
  202. const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
  203. const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
  204. const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
  205. const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
  206. const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("depend");
  207. const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
  208. const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
  209. const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
  210. const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
  211. const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
  212. const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
  213. const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
  214. // Comm ops
  215. const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
  216. const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
  217. const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
  218. // Debug ops
  219. const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
  220. const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
  221. const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
  222. const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
  223. ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) {
  224. py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
  225. ValuePtr node = nullptr;
  226. bool succ = parse::ConvertData(obj, &node);
  227. if (!succ) {
  228. MS_LOG(EXCEPTION) << "get Python op " << op_name << " from " << module_name << " fail";
  229. }
  230. return node;
  231. }
  232. } // namespace prim
  233. } // namespace mindspore