# functional.py (6.6 kB), captured from a repository web view (last change shown as "5 years ago").
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  17. """The names of functional part are summarized here."""
  18. from mindspore.common._register_for_tensor import tensor_operator_registry
  19. from .primitive import Primitive
  20. from . import operations as P
  21. from .operations import _grad_ops
  22. from .._extends import builtin_operations as BP
  23. typeof = Primitive('typeof')
  24. hastype = Primitive('hastype')
  25. cast = P.Cast()
  26. dtype = P.DType()
  27. isconstant = Primitive('is_constant')
  28. isconstant.set_is_const_value(True)
  29. issubclass_ = P.IsSubClass()
  30. isinstance_ = P.IsInstance()
  31. fill = P.Fill()
  32. tile = P.Tile()
  33. select = P.Select()
  34. size = P.Size()
  35. ones_like = P.OnesLike()
  36. shape = P.Shape()
  37. rank = P.Rank()
  38. reshape = P.Reshape()
  39. # control_depend: represent dependency between two operators
  40. control_depend = P.ControlDepend()
  41. merge = P.Merge()
  42. geswitch = P.GeSwitch()
  43. addn = P.AddN()
  44. tensor_add = P.TensorAdd()
  45. neg_tensor = P.Neg()
  46. tensor_lt = P.Less()
  47. tensor_le = P.LessEqual()
  48. tensor_gt = P.Greater()
  49. tensor_ge = P.GreaterEqual()
  50. tensor_sub = P.Sub()
  51. tensor_mul = P.Mul()
  52. tensor_div = P.RealDiv()
  53. tensor_floordiv = P.FloorDiv()
  54. tensor_pow = P.Pow()
  55. tensor_mod = P.FloorMod()
  56. strided_slice = P.StridedSlice()
  57. same_type_shape = P.SameTypeShape()
  58. check_bprop = P.CheckBprop()
  59. equal = P.Equal()
  60. not_equal = P.NotEqual()
  61. assign_sub = P.AssignSub()
  62. assign_add = P.AssignAdd()
  63. assign = P.Assign()
  64. square = P.Square()
  65. sqrt = P.Sqrt()
  66. scalar_to_array = P.ScalarToArray()
  67. scalar_to_tensor = P.ScalarToTensor()
  68. tuple_to_array = P.TupleToArray()
  69. scalar_cast = P.ScalarCast()
  70. print_ = P.Print()
  71. expand_dims = P.ExpandDims()
  72. scatter_nd = P.ScatterNd()
  73. gather = P.GatherV2()
  74. gather_nd = P.GatherNd()
  75. scatter_update = P.ScatterUpdate()
  76. scatter_nd_update = P.ScatterNdUpdate()
  77. pack = P.Pack()
  78. partial = P.Partial()
  79. # depend: mount a node to another node
  80. depend = P.Depend()
  81. identity = P.identity()
  82. tuple_setitem = Primitive('tuple_setitem')
  83. tuple_getitem = Primitive('tuple_getitem')
  84. list_getitem = Primitive('list_getitem')
  85. list_setitem = Primitive('list_setitem')
  86. dict_getitem = Primitive('dict_getitem')
  87. dict_setitem = Primitive('dict_setitem')
  88. tuple_div = Primitive("tuple_div")
  89. tuple_len = Primitive("tuple_len")
  90. tuple_reversed = Primitive("tuple_reversed")
  91. make_range = Primitive("make_range")
  92. make_tuple = Primitive('make_tuple')
  93. make_dict = Primitive('make_dict')
  94. make_list = Primitive('make_list')
  95. make_slice = Primitive('make_slice')
  96. tuple_equal = Primitive("tuple_equal")
  97. list_equal = Primitive("list_equal")
  98. make_ref = Primitive("make_ref")
  99. scalar_add = Primitive('scalar_add')
  100. scalar_mul = Primitive('scalar_mul')
  101. scalar_sub = Primitive('scalar_sub')
  102. scalar_div = Primitive('scalar_div')
  103. scalar_floordiv = Primitive('scalar_floordiv')
  104. scalar_log = Primitive('scalar_log')
  105. scalar_pow = Primitive('scalar_pow')
  106. scalar_gt = Primitive('scalar_gt')
  107. scalar_ge = Primitive('scalar_ge')
  108. scalar_le = Primitive('scalar_le')
  109. scalar_lt = Primitive('scalar_lt')
  110. scalar_eq = Primitive('scalar_eq')
  111. scalar_ne = Primitive('scalar_ne')
  112. scalar_uadd = Primitive('scalar_uadd')
  113. scalar_usub = Primitive('scalar_usub')
  114. scalar_mod = Primitive('scalar_mod')
  115. string_eq = Primitive('string_equal')
  116. string_concat = Primitive('string_concat')
  117. bool_not = Primitive("bool_not")
  118. bool_or = Primitive("bool_or")
  119. bool_and = Primitive("bool_and")
  120. logical_and = P.LogicalAnd()
  121. logical_or = P.LogicalOr()
  122. logical_not = P.LogicalNot()
  123. array_to_scalar = Primitive('array_to_scalar')
  124. is_ = Primitive("is_")
  125. is_not = Primitive("is_not")
  126. in_dict = Primitive("in_dict")
  127. not_in_dict = Primitive("not_in_dict")
  128. mixed_precision_cast = Primitive("mixed_precision_cast")
  129. broadcast_gradient_args = Primitive('BroadcastGradientArgs')
  130. dot = Primitive('dot')
  131. array_reduce = Primitive('array_reduce')
  132. zeros_like = P.ZerosLike()
  133. distribute = Primitive('distribute')
  134. embed = Primitive('embed')
  135. ref_to_embed = _grad_ops.RefToEmbed()
  136. env_setitem = Primitive('env_setitem')
  137. env_getitem = Primitive('env_getitem')
  138. env_add = Primitive('env_add')
  139. J = Primitive('J')
  140. switch = Primitive('switch')
  141. switch_layer = Primitive('switch_layer')
  142. # for sum bprop
  143. reduced_shape = Primitive("reduced_shape")
  144. # shape_mul:input mush be shape multiply elemts in tuple(shape)
  145. shape_mul = Primitive("shape_mul")
  146. # a primitive to compare between tuple.
  147. stop_gradient = Primitive("stop_gradient")
  148. make_row_tensor = Primitive('MakeRowTensor')
  149. row_tensor_get_values = Primitive('RowTensorGetValues')
  150. row_tensor_get_indices = Primitive('RowTensorGetIndices')
  151. row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')
  152. make_sparse_tensor = Primitive('MakeSparseTensor')
  153. sparse_tensor_get_values = Primitive('SparseTensorGetValues')
  154. sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
  155. sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
  156. tensor_operator_registry.register('__add__', tensor_add)
  157. tensor_operator_registry.register('__sub__', tensor_sub)
  158. tensor_operator_registry.register('__mul__', tensor_mul)
  159. tensor_operator_registry.register('__truediv__', tensor_div)
  160. tensor_operator_registry.register('__mod__', tensor_mod)
  161. tensor_operator_registry.register('__pow__', tensor_pow)
  162. tensor_operator_registry.register('__floordiv__', tensor_floordiv)
  163. tensor_operator_registry.register('all', P.ReduceAll)
  164. tensor_operator_registry.register('any', P.ReduceAny)
  165. # ms cannot support Tensor(True) compare
  166. tensor_operator_registry.register('__eq__', equal)
  167. tensor_operator_registry.register('__ne__', not_equal)
  168. tensor_operator_registry.register('__neg__', neg_tensor)
  169. tensor_operator_registry.register('__lt__', tensor_lt)
  170. tensor_operator_registry.register('__le__', tensor_le)
  171. tensor_operator_registry.register('__gt__', tensor_gt)
  172. tensor_operator_registry.register('__ge__', tensor_ge)
  173. tensor_operator_registry.register('shape', shape)
  174. # support GE backend for no compare operators
  175. tensor_operator_registry.register('vm_compare', BP.vm_compare)
  176. tensor_operator_registry.register('cast', cast)