You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functional.py 7.5 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2021 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """The names of functional part are summarized here."""
  18. from mindspore.common._register_for_tensor import tensor_operator_registry
  19. from mindspore.ops import _constants
  20. from .primitive import Primitive
  21. from . import operations as P
  22. from .operations import _grad_ops
  23. typeof = Primitive('typeof')
  24. hastype = Primitive('hastype')
  25. cast = P.Cast()
  26. dtype = P.DType()
  27. isconstant = Primitive('is_constant')
  28. isconstant.set_const_prim(True)
  29. issubclass_ = P.IsSubClass()
  30. isinstance_ = P.IsInstance()
  31. eye = P.Eye()
  32. fill = P.Fill()
  33. tile = P.Tile()
  34. select = P.Select()
  35. size = P.Size()
  36. ones_like = P.OnesLike()
  37. shape = P.Shape()
  38. rank = P.Rank()
  39. reshape = P.Reshape()
  40. # control_depend: represent dependency between two operators
  41. def control_depend(src, dst):
  42. control_depend_op = P.ControlDepend()
  43. return control_depend_op(src, dst)
  44. merge = P.Merge()
  45. geswitch = P.GeSwitch()
  46. addn = P.AddN()
  47. absolute = P.Abs()
  48. tensor_add = P.Add()
  49. neg_tensor = P.Neg()
  50. tensor_lt = P.Less()
  51. tensor_le = P.LessEqual()
  52. tensor_gt = P.Greater()
  53. tensor_ge = P.GreaterEqual()
  54. tensor_sub = P.Sub()
  55. tensor_mul = P.Mul()
  56. tensor_div = P.RealDiv()
  57. tensor_floordiv = P.FloorDiv()
  58. tensor_pow = P.Pow()
  59. tensor_mod = P.FloorMod()
  60. tensor_exp = P.Exp()
  61. tensor_expm1 = P.Expm1()
  62. strided_slice = P.StridedSlice()
  63. same_type_shape = P.SameTypeShape()
  64. check_bprop = P.CheckBprop()
  65. equal = P.Equal()
  66. not_equal = P.NotEqual()
  67. isfinite = P.IsFinite()
  68. assign_sub = P.AssignSub()
  69. assign_add = P.AssignAdd()
  70. assign = P.Assign()
  71. square = P.Square()
  72. sqrt = P.Sqrt()
  73. log = P.Log()
  74. reduce_sum = P.ReduceSum()
  75. tensor_slice = P.Slice()
  76. maximum = P.Maximum()
  77. minimum = P.Minimum()
  78. floor = P.Floor()
  79. scalar_to_array = P.ScalarToArray()
  80. scalar_to_tensor = P.ScalarToTensor()
  81. tuple_to_array = P.TupleToArray()
  82. scalar_cast = P.ScalarCast()
  83. print_ = P.Print()
  84. expand_dims = P.ExpandDims()
  85. transpose = P.Transpose()
  86. squeeze = P.Squeeze()
  87. scatter_nd = P.ScatterNd()
  88. gather = P.Gather()
  89. gather_d = P.GatherD()
  90. gather_nd = P.GatherNd()
  91. scatter_update = P.ScatterUpdate()
  92. scatter_nd_update = P.ScatterNdUpdate()
  93. pack = P.Pack()
  94. stack = P.Stack()
  95. partial = P.Partial()
  96. # depend: mount a node to another node
  97. depend = P.Depend()
  98. identity = P.identity()
  99. tuple_setitem = Primitive('tuple_setitem')
  100. tuple_getitem = Primitive(_constants.kTupleGetItem)
  101. list_getitem = Primitive('list_getitem')
  102. list_setitem = Primitive('list_setitem')
  103. dict_getitem = Primitive('dict_getitem')
  104. dict_setitem = Primitive('dict_setitem')
  105. tuple_div = Primitive("tuple_div")
  106. tuple_len = Primitive("tuple_len")
  107. list_len = Primitive("list_len")
  108. tuple_reversed = Primitive("tuple_reversed")
  109. make_range = Primitive("make_range")
  110. make_tuple = Primitive('make_tuple')
  111. make_dict = Primitive('make_dict')
  112. make_list = Primitive('make_list')
  113. make_slice = Primitive('make_slice')
  114. tuple_equal = Primitive("tuple_equal")
  115. list_equal = Primitive("list_equal")
  116. make_ref = Primitive("make_ref")
  117. scalar_add = Primitive(_constants.kScalarAdd)
  118. scalar_mul = Primitive(_constants.kScalarMul)
  119. scalar_sub = Primitive(_constants.kScalarSub)
  120. scalar_div = Primitive(_constants.kScalarDiv)
  121. scalar_floordiv = Primitive(_constants.kScalarFloordiv)
  122. scalar_log = Primitive('scalar_log')
  123. scalar_pow = Primitive(_constants.kScalarPow)
  124. scalar_gt = Primitive('scalar_gt')
  125. scalar_ge = Primitive('scalar_ge')
  126. scalar_le = Primitive('scalar_le')
  127. scalar_lt = Primitive('scalar_lt')
  128. scalar_eq = Primitive('scalar_eq')
  129. scalar_ne = Primitive('scalar_ne')
  130. scalar_uadd = Primitive(_constants.kScalarUadd)
  131. scalar_usub = Primitive(_constants.kScalarUsub)
  132. scalar_mod = Primitive(_constants.kScalarMod)
  133. string_eq = Primitive('string_equal')
  134. string_concat = Primitive('string_concat')
  135. bool_not = Primitive("bool_not")
  136. bool_or = Primitive("bool_or")
  137. bool_and = Primitive("bool_and")
  138. bool_eq = Primitive("bool_eq")
  139. logical_and = P.LogicalAnd()
  140. logical_or = P.LogicalOr()
  141. logical_not = P.LogicalNot()
  142. array_to_scalar = Primitive('array_to_scalar')
  143. is_ = Primitive("is_")
  144. is_not = Primitive("is_not")
  145. in_dict = Primitive("in_dict")
  146. not_in_dict = Primitive("not_in_dict")
  147. mixed_precision_cast = Primitive("mixed_precision_cast")
  148. broadcast_gradient_args = Primitive('BroadcastGradientArgs')
  149. array_reduce = Primitive('array_reduce')
  150. zeros_like = P.ZerosLike()
  151. distribute = Primitive('distribute')
  152. embed = Primitive('embed')
  153. ref_to_embed = _grad_ops.RefToEmbed()
  154. env_setitem = Primitive('env_setitem')
  155. env_getitem = Primitive('env_getitem')
  156. env_add = Primitive('env_add')
  157. J = Primitive('J')
  158. switch = Primitive('switch')
  159. switch_layer = Primitive('switch_layer')
  160. # for sum bprop
  161. reduced_shape = Primitive("reduced_shape")
  162. # shape_mul:input mush be shape multiply elemts in tuple(shape)
  163. shape_mul = Primitive("shape_mul")
  164. # a primitive to compare between tuple.
  165. stop_gradient = Primitive("stop_gradient")
  166. make_row_tensor = Primitive('MakeRowTensor')
  167. row_tensor_get_values = Primitive('RowTensorGetValues')
  168. row_tensor_get_indices = Primitive('RowTensorGetIndices')
  169. row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')
  170. row_tensor_add = Primitive('RowTensorAdd')
  171. make_sparse_tensor = Primitive('MakeSparseTensor')
  172. sparse_tensor_get_values = Primitive('SparseTensorGetValues')
  173. sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
  174. sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
  175. tensor_operator_registry.register('__add__', tensor_add)
  176. tensor_operator_registry.register('__sub__', tensor_sub)
  177. tensor_operator_registry.register('__mul__', tensor_mul)
  178. tensor_operator_registry.register('__truediv__', tensor_div)
  179. tensor_operator_registry.register('__mod__', tensor_mod)
  180. tensor_operator_registry.register('__pow__', tensor_pow)
  181. tensor_operator_registry.register('__floordiv__', tensor_floordiv)
  182. tensor_operator_registry.register('all', P.ReduceAll)
  183. tensor_operator_registry.register('any', P.ReduceAny)
  184. tensor_operator_registry.register('abs', P.Abs)
  185. tensor_operator_registry.register('mean', P.ReduceMean)
  186. tensor_operator_registry.register('reshape', P.Reshape)
  187. tensor_operator_registry.register('transpose', P.Transpose)
  188. tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
  189. # ms cannot support Tensor(True) compare
  190. tensor_operator_registry.register('__eq__', equal)
  191. tensor_operator_registry.register('__ne__', not_equal)
  192. tensor_operator_registry.register('__neg__', neg_tensor)
  193. tensor_operator_registry.register('__lt__', tensor_lt)
  194. tensor_operator_registry.register('__le__', tensor_le)
  195. tensor_operator_registry.register('__gt__', tensor_gt)
  196. tensor_operator_registry.register('__ge__', tensor_ge)
  197. tensor_operator_registry.register('shape', shape)
  198. tensor_operator_registry.register('squeeze', squeeze)
  199. # support GE backend for no compare operators
  200. tensor_operator_registry.register('cast', cast)
  201. __all__ = [name for name in dir() if name[0] != "_"]
  202. __all__.remove('Primitive')