You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functional.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  17. """The names of functional part are summarized here."""
  18. from mindspore.common._register_for_tensor import tensor_operator_registry
  19. from .primitive import Primitive
  20. from . import operations as P
  21. from .operations import _grad_ops
  22. typeof = Primitive('typeof')
  23. hastype = Primitive('hastype')
  24. cast = P.Cast()
  25. dtype = P.DType()
  26. issubclass_ = P.IsSubClass()
  27. isinstance_ = P.IsInstance()
  28. fill = P.Fill()
  29. select = P.Select()
  30. size = P.Size()
  31. ones_like = P.OnesLike()
  32. shape = P.Shape()
  33. rank = P.Rank()
  34. reshape = P.Reshape()
  35. # control_depend: represent dependency between two operators
  36. control_depend = P.ControlDepend()
  37. merge = P.Merge()
  38. geswitch = P.GeSwitch()
  39. addn = P.AddN()
  40. tensor_add = P.TensorAdd()
  41. neg_tensor = P.Neg()
  42. tensor_lt = P.Less()
  43. tensor_le = P.LessEqual()
  44. tensor_gt = P.Greater()
  45. tensor_ge = P.GreaterEqual()
  46. tensor_sub = P.Sub()
  47. tensor_mul = P.Mul()
  48. tensor_div = P.RealDiv()
  49. tensor_floordiv = P.FloorDiv()
  50. tensor_pow = P.Pow()
  51. tensor_mod = P.FloorMod()
  52. strided_slice = P.StridedSlice()
  53. same_type_shape = P.SameTypeShape()
  54. check_bprop = P.CheckBprop()
  55. equal = P.Equal()
  56. not_equal = P.NotEqual()
  57. assign_sub = P.AssignSub()
  58. assign = P.Assign()
  59. square = P.Square()
  60. sqrt = P.Sqrt()
  61. scalar_to_array = P.ScalarToArray()
  62. scalar_to_tensor = P.ScalarToTensor()
  63. tuple_to_array = P.TupleToArray()
  64. scalar_cast = P.ScalarCast()
  65. print_ = P.Print()
  66. expand_dims = P.ExpandDims()
  67. scatter_nd = P.ScatterNd()
  68. tuple_setitem = Primitive('tuple_setitem')
  69. tuple_getitem = Primitive('tuple_getitem')
  70. list_getitem = Primitive('list_getitem')
  71. list_setitem = Primitive('list_setitem')
  72. dict_getitem = Primitive('dict_getitem')
  73. dict_setitem = Primitive('dict_setitem')
  74. tuple_div = Primitive("tuple_div")
  75. tuple_len = Primitive("tuple_len")
  76. tuple_reversed = Primitive("tuple_reversed")
  77. make_range = Primitive("make_range")
  78. make_tuple = Primitive('make_tuple')
  79. make_dict = Primitive('make_dict')
  80. make_list = Primitive('make_list')
  81. make_slice = Primitive('make_slice')
  82. tuple_equal = Primitive("tuple_equal")
  83. list_equal = Primitive("list_equal")
  84. make_ref = Primitive("make_ref")
  85. scalar_add = Primitive('scalar_add')
  86. scalar_mul = Primitive('scalar_mul')
  87. scalar_sub = Primitive('scalar_sub')
  88. scalar_div = Primitive('scalar_div')
  89. scalar_floordiv = Primitive('scalar_floordiv')
  90. scalar_log = Primitive('scalar_log')
  91. scalar_pow = Primitive('scalar_pow')
  92. scalar_gt = Primitive('scalar_gt')
  93. scalar_ge = Primitive('scalar_ge')
  94. scalar_le = Primitive('scalar_le')
  95. scalar_lt = Primitive('scalar_lt')
  96. scalar_eq = Primitive('scalar_eq')
  97. scalar_ne = Primitive('scalar_ne')
  98. scalar_uadd = Primitive('scalar_uadd')
  99. scalar_usub = Primitive('scalar_usub')
  100. scalar_mod = Primitive('scalar_mod')
  101. string_eq = Primitive('string_equal')
  102. string_concat = Primitive('string_concat')
  103. bool_not = Primitive("bool_not")
  104. bool_or = Primitive("bool_or")
  105. bool_and = Primitive("bool_and")
  106. logical_and = P.LogicalAnd()
  107. logical_or = P.LogicalOr()
  108. logical_not = P.LogicalNot()
  109. array_to_scalar = Primitive('array_to_scalar')
  110. is_ = Primitive("is_")
  111. is_not = Primitive("is_not")
  112. in_dict = Primitive("in_dict")
  113. not_in_dict = Primitive("not_in_dict")
  114. broadcast_gradient_args = Primitive('BroadcastGradientArgs')
  115. dot = Primitive('dot')
  116. array_reduce = Primitive('array_reduce')
  117. partial = Primitive('partial')
  118. zeros_like_tensor = Primitive('zeros_like_tensor')
  119. identity = Primitive('identity')
  120. distribute = Primitive('distribute')
  121. # depend: mount a node to another node
  122. depend = Primitive('depend')
  123. embed = Primitive('embed')
  124. ref_to_embed = _grad_ops.RefToEmbed()
  125. env_setitem = Primitive('env_setitem')
  126. env_getitem = Primitive('env_getitem')
  127. env_add = Primitive('env_add')
  128. J = Primitive('J')
  129. switch = Primitive('switch')
  130. # for sum bprop
  131. reduced_shape = Primitive("reduced_shape")
  132. # shape_mul:input mush be shape multiply elemts in tuple(shape)
  133. shape_mul = Primitive("shape_mul")
  134. # a primitive to compare between tuple.
  135. stop_gradient = Primitive("stop_gradient")
  136. tensor_operator_registry.register('__add__', tensor_add)
  137. tensor_operator_registry.register('__mul__', tensor_mul)
  138. tensor_operator_registry.register('__div__', tensor_div)
  139. #ms cannot support Tensor(True) compare
  140. tensor_operator_registry.register('__eq__', equal)