You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops_utils.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_
  #define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_

  // This header names size_t and int32_t itself, so it includes their
  // definitions directly instead of depending on what an including file
  // happened to pull in first.
  #include <cstddef>
  #include <cstdint>

  // Compile-time constants in mindspore::parallel: numeric sizes/indices/flags,
  // attribute and option key strings, and operator (primitive) name strings.
  //
  // Namespace-scope constexpr variables are implicitly const and therefore have
  // internal linkage: every translation unit that includes this header gets its
  // own copy of each array below.
  namespace mindspore {
  namespace parallel {
  // Numeric size / index / flag constants.
  // NOTE(review): the names suggest expected input/output/attribute counts and
  // indices for particular operators; confirm exact meaning at each use site.
  constexpr size_t PRELU_INPUTS_SIZE = 2;
  constexpr size_t PRELU_OUTPUTS_SIZE = 1;
  constexpr size_t PRELU_SECOND_INPUT_SIZE = 1;
  constexpr int32_t PRELU_CHANNEL_INDEX = 1;
  constexpr int32_t PRELU_CHANNEL_STRATEGY = 1;
  constexpr int32_t NO_SPLIT_MAP = -1;
  constexpr int32_t NO_SPLIT_STRATEGY = 1;
  constexpr int32_t SPLIT_FLAG = 1;
  constexpr int32_t NO_SPLIT_FLAG = 0;
  constexpr size_t MATMUL_ATTRS_SIZE = 2;
  constexpr size_t MATMUL_INPUTS_SIZE = 2;
  constexpr size_t MATMUL_OUTPUTS_SIZE = 1;
  constexpr size_t ACTIVATION_ATTR_SIZE = 1;
  constexpr size_t SOFTMAX_ATTR_SIZE = 1;
  constexpr size_t ACTIVATION_INPUTS_SIZE = 1;
  constexpr size_t ACTIVATION_OUTPUTS_SIZE = 1;
  constexpr size_t EXPANDDIMS_INPUT_SIZE = 2;
  constexpr size_t DROPOUT_DO_MASK_CNODE_INPUT_SIZE = 4;
  constexpr size_t DROPOUT_GEN_MASK_CNODE_INPUT_SIZE = 3;
  constexpr size_t DROPOUT_GEN_MASK_INDEX = 2;
  constexpr size_t DROPOUT_DO_MASK_KEEP_PROB_INDEX = 3;
  constexpr size_t SoftmaxCrossEntropyWithLogitsAttrSize = 1;
  constexpr size_t SoftmaxCrossEntropyWithLogitsInputsSize = 2;
  constexpr size_t SoftmaxCrossEntropyWithLogitsOutputsSize = 2;

  // Floating-point constants. INF is a large *finite* value (1e20), not
  // std::numeric_limits<double>::infinity().
  constexpr double EPS = 1e-6;
  constexpr double INF = 1e20;

  // Attribute names, option keys and other identifier strings.
  constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only";
  constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only";
  constexpr char STRATEGY[] = "strategy";
  constexpr char GEN_STRATEGY[] = "gen_strategy";
  constexpr char REDUCE_OP_SUM[] = "sum";
  constexpr char REDUCE_OP_MAX[] = "max";
  constexpr char REDUCE_OP_MIN[] = "min";
  constexpr char OP_PATH[] = "mindspore.ops.operations";
  constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
  constexpr char GET_OP_FUNCTION[] = "_get_python_op";
  constexpr char KEEP_DIMS[] = "keep_dims";
  constexpr char CROSS_BATCH[] = "cross_batch";
  constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin";
  constexpr char STEP_PARALLEL_END[] = "step_parallel_end";
  // NOTE: unlike STEP_PARALLEL_BEGIN/END, this value carries a ".dot" suffix.
  constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot";
  constexpr char RELU_TYPE[] = "relu";
  constexpr char RELU6_TYPE[] = "relu6";
  constexpr char SIGMOID_TYPE[] = "sigmoid";
  constexpr char OP[] = "op";
  constexpr char IDENTITY_INFO[] = "identity_info";
  constexpr char DIVISOR[] = "divisor";
  constexpr char NONE[] = "None";
  constexpr char DEPEND[] = "depend";
  constexpr char BATCH_PARALLEL[] = "BatchParallel";
  constexpr char ACTIVATION_TYPE[] = "activation_type";
  constexpr char TRANSPOSE_A[] = "transpose_a";
  constexpr char TRANSPOSE_B[] = "transpose_b";
  constexpr char SHAPE[] = "shape";
  constexpr char BEGIN_MASK[] = "begin_mask";
  constexpr char END_MASK[] = "end_mask";
  constexpr char ELLIPSIS_MASK[] = "ellipsis_mask";
  constexpr char NEW_AXIS_MASK[] = "new_axis_mask";
  constexpr char SHRINK_AXIS_MASK[] = "shrink_axis_mask";
  constexpr char BEGIN[] = "begin";
  constexpr char END[] = "end";
  constexpr char STRIDES[] = "strides";
  constexpr char GROUP[] = "group";
  constexpr char AXIS[] = "axis";
  constexpr char OUTPUT_NUM[] = "output_num";
  constexpr char SPLIT_COUNT[] = "split_count";
  constexpr char SPLIT_DIM[] = "split_dim";
  constexpr char CONCAT_DIM[] = "concat_dim";
  constexpr char FORWARD[] = "forward";
  constexpr char BACKWARD[] = "backward";
  constexpr char REDISTRIBUTION[] = "redistribution";
  constexpr char REPLACE[] = "replace";
  constexpr char CONNSYMBOL[] = "/";
  constexpr char INSTANCE_NAME[] = "instance_name";
  constexpr char SPLIT_SENS[] = "split_sens";
  constexpr char SPLIT_TENSOR[] = "split_tensor";
  constexpr char DEV_MAT[] = "dev_mat";
  constexpr char TENSOR_MAP[] = "tensor_map";
  constexpr char SEED0[] = "Seed0";
  constexpr char SEED1[] = "Seed1";
  constexpr char KEEP_PROB[] = "keep_prob";
  constexpr char SRC[] = "src";
  constexpr char CLONE_INFO[] = "clone_info";
  constexpr char CLONED[] = "cloned";
  constexpr char BE_CLONED[] = "be_cloned";
  constexpr char CLONED_INDEX[] = "cloned_index";
  constexpr char BE_CLONED_INDEX[] = "be_cloned_index";
  constexpr char GROUP_RANKS[] = "group_ranks";
  constexpr char IS_IN_FORWARD[] = "is_in_forward";
  constexpr char DEFAULT_INPUT[] = "default_input";
  constexpr char DTYPE[] = "dtype";
  constexpr char DEV_NUM[] = "dev_num";
  constexpr char MEAN_FLAG[] = "mean_flag";
  constexpr char TYPES[] = "types";
  constexpr char SHAPES[] = "shapes";
  constexpr char GETNEXT_NUM[] = "output_num";  // NOTE: same string as OUTPUT_NUM.
  constexpr char SHARED_NAME[] = "shared_name";
  constexpr char MIRROR_OP[] = "mirror_op";
  constexpr char FORWARD_OP[] = "forward_op";
  constexpr char REDISTRIBUTION_OP[] = "redistribution_op";

  // Operator (primitive) names.
  constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
  constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice";
  constexpr char SPLIT[] = "Split";
  constexpr char ALL_TO_ALL[] = "_AlltoAll";
  constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis";
  constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis";
  constexpr char SPLIT_BY_AXIS[] = "SplitByAxis";
  constexpr char ALL_REDUCE[] = "AllReduce";
  constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
  constexpr char STRIDED_SLICE[] = "StridedSlice";
  constexpr char ALL_GATHER[] = "AllGather";
  constexpr char REDUCE_SCATTER[] = "ReduceScatter";
  constexpr char CONCAT[] = "Concat";
  constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits";
  constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits";
  constexpr char MATMUL[] = "MatMul";
  constexpr char GELU[] = "Gelu";
  constexpr char TANH[] = "Tanh";
  constexpr char SOFTMAX[] = "Softmax";
  constexpr char LOG_SOFTMAX[] = "LogSoftmax";
  constexpr char ACTIVATION[] = "Activation";
  constexpr char PRELU[] = "PReLU";
  constexpr char FLOORDIV[] = "FloorDiv";
  constexpr char MAXPOOL[] = "MaxPool";
  constexpr char MAXPOOLV2[] = "MaxPoolV2";
  constexpr char L2_NORMALIZE[] = "L2Normalize";
  constexpr char TRANSPOSE[] = "Transpose";
  constexpr char RESHAPE[] = "Reshape";
  constexpr char TENSOR_ADD[] = "TensorAdd";
  constexpr char BIAS_ADD[] = "BiasAdd";
  constexpr char SUB[] = "Sub";
  constexpr char MUL[] = "Mul";
  constexpr char DIV[] = "Div";
  constexpr char REAL_DIV[] = "RealDiv";
  constexpr char ASSIGN_SUB[] = "AssignSub";
  constexpr char GREATER[] = "Greater";
  constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset";
  constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo";
  constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits";
  constexpr char RELU[] = "ReLU";
  constexpr char ONEHOT[] = "OneHot";
  constexpr char DROPOUT_DO_MASK[] = "DropoutDoMask";
  constexpr char DROPOUT_GEN_MASK[] = "DropoutGenMask";
  constexpr char REDUCE_MAX[] = "ReduceMax";
  constexpr char REDUCE_MIN[] = "ReduceMin";
  constexpr char REDUCE_SUM[] = "ReduceSum";
  constexpr char REDUCE_MEAN[] = "ReduceMean";
  constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue";
  constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
  constexpr char CONV2D[] = "Conv2D";
  constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
  constexpr char BATCH_NORM[] = "BatchNorm";
  constexpr char LAYER_NORM[] = "LayerNorm";
  constexpr char POOLING[] = "Pooling";
  constexpr char CAST[] = "Cast";
  constexpr char MAX_POOL_WITH_ARGMAX[] = "MaxPoolWithArgmax";
  constexpr char SIMPLE_MEAN[] = "SimpleMean";
  constexpr char FLATTEN[] = "Flatten";
  constexpr char J[] = "J";
  constexpr char TMPIDENTITY_INFO_NAME[] = "identity_info";  // NOTE: same string as IDENTITY_INFO.
  constexpr char COS[] = "Cos";
  constexpr char ACOS[] = "ACos";
  constexpr char EXP[] = "Exp";
  constexpr char LOG[] = "Log";
  constexpr char SIGMOID[] = "Sigmoid";
  constexpr char POW[] = "Pow";
  constexpr char MAXIMUM[] = "Maximum";
  constexpr char MINIMUM[] = "Minimum";
  constexpr char EQUAL[] = "Equal";
  constexpr char NOT_EQUAL[] = "NotEqual";
  constexpr char LOGICALNOT[] = "LogicalNot";
  constexpr char GATHERV2[] = "GatherV2";
  constexpr char STRIDEDSLICE[] = "StridedSlice";  // NOTE: same string as STRIDED_SLICE.
  constexpr char BROADCAST[] = "Broadcast";
  constexpr char SQRT[] = "Sqrt";
  constexpr char ASSIGN[] = "Assign";
  constexpr char GET_NEXT[] = "GetNext";
  constexpr char SQUEEZE[] = "Squeeze";
  constexpr char NEG[] = "Neg";
  constexpr char BATCH_MATMUL[] = "BatchMatMul";
  constexpr char EXPAND_DIMS[] = "ExpandDims";
  constexpr char SQUARE[] = "Square";

  // Operator names the parallel pass does not care about.
  constexpr char TUPLE_GETITEM[] = "tuple_getitem";
  constexpr char STRING_EQUAL[] = "string_equal";
  constexpr char MAKE_TUPLE[] = "make_tuple";
  constexpr char MAKE_LIST[] = "make_list";
  constexpr char MAKE_DICT[] = "make_dict";
  constexpr char MAKE_SLICE[] = "make_slice";
  constexpr char MAKE_RECORD[] = "make_record";
  constexpr char LIST_GETITEM[] = "list_getitem";
  constexpr char ARRAY_GETITEM[] = "array_getitem";
  constexpr char TUPLE_SETITEM[] = "tuple_setitem";
  constexpr char LIST_SETITEM[] = "list_setitem";
  constexpr char ARRAY_SETITEM[] = "array_setitem";
  constexpr char DICT_GETITEM[] = "dict_getitem";
  constexpr char LIST_APPEND[] = "list_append";
  constexpr char LIST_MAP[] = "list_map";
  constexpr char LIST_REDUCE[] = "list_reduce";
  constexpr char TUPLE_REVERSED[] = "tuple_reversed";
  constexpr char TILE_SHAPE[] = "tile_shape";
  constexpr char REDUCED_SHAPE[] = "reduced_shape";
  constexpr char TUPLE_DIV[] = "tuple_div";
  constexpr char TUPLE_TO_ARRAY[] = "tuple_to_array";
  constexpr char VIRTUALLOSS[] = "VirtualLoss";
  constexpr char RETURN[] = "return";
  constexpr char ENV_GETITEM[] = "env_getitem";
  constexpr char IDENTITY[] = "identity";
  constexpr char PARTIAL[] = "partial";
  constexpr char ENVSETITEM[] = "env_setitem";
  constexpr char ENVGETITEM[] = "env_getitem";  // NOTE: same string as ENV_GETITEM.
  constexpr char ENVADD[] = "env_add";
  constexpr char MAKEREFKEY[] = "MakeRefKey";
  constexpr char MAKEREF[] = "make_ref";
  constexpr char GETREFKEY[] = "get_ref_key";
  constexpr char GETREFVALUE[] = "get_ref_value";
  constexpr char GETREFORIGIN[] = "get_ref_origin";
  constexpr char STATESETITEM[] = "state_setitem";
  constexpr char SCALARSUMMARY[] = "ScalarSummary";
  constexpr char IMAGESUMMARY[] = "ImageSummary";
  constexpr char TENSORSUMMARY[] = "TensorSummary";
  constexpr char HISTOGRAMSUMMARY[] = "HistogramSummary";
  constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs";
  constexpr char INVERTPERMUTATION[] = "InvertPermutation";
  constexpr char CONTROLDEPEND[] = "ControlDepend";
  constexpr char DOT[] = "dot";
  constexpr char IM2COL[] = "im2col";
  constexpr char COL2IM[] = "col2im";
  constexpr char IM2COLV1[] = "im2col_v1";
  constexpr char COL2IMV1[] = "col2im_v1";
  constexpr char RESOLVE[] = "resolve";
  constexpr char EMBED[] = "embed";
  constexpr char CREATINSTANCE[] = "create_instance";
  constexpr char ZEROSLIKETENSOR[] = "zeros_like_tensor";
  constexpr char REF_TO_EMBED[] = "RefToEmbed";
  constexpr char STOP_GRADIENT[] = "stop_gradient";

  // Index helpers for a sequence of length s, usable in constant expressions.
  // They use plain unsigned arithmetic and perform no checking. Callers must
  // guarantee s >= 1 (LAST_INDEX), s >= 2 (SECOND_FROM_END), or s >= 3
  // (THIRD_FROM_END); otherwise the subtraction wraps around to a huge size_t
  // (e.g. LAST_INDEX(0) == SIZE_MAX) instead of failing.

  /// Returns s - 1: the index of the last element of a length-s sequence.
  constexpr size_t LAST_INDEX(size_t s) { return s - 1; }
  /// Returns s - 2: the index of the second-to-last element.
  constexpr size_t SECOND_FROM_END(size_t s) { return s - 2; }
  /// Returns s - 3: the index of the third-to-last element.
  constexpr size_t THIRD_FROM_END(size_t s) { return s - 3; }
  }  // namespace parallel
  }  // namespace mindspore
  #endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_