You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

array_ops_vm_impl.py 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate vm_impl function for array ops"""
  16. import numpy as np
  17. import mindspore.common.dtype as mstype
  18. from mindspore.common.tensor import Tensor
  19. from mindspore.ops import operations as P
  20. from mindspore.ops.operations import _grad_ops as G
  21. from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
  22. from .vm_interface import vm
  23. # pylint: disable=unused-argument
  24. @vm_impl_getters.register(P.ExpandDims)
  25. def vm_impl_expand_dims(self):
  26. """Generate vm_impl function for ExpandDims"""
  27. def vm_impl(x, axis):
  28. if isinstance(x, float):
  29. x = Tensor(np.array([x]))
  30. x = x.asnumpy()
  31. out = vm.expand_dims(x, axis)
  32. return Tensor(out)
  33. return vm_impl
  34. @vm_impl_getters.register(P.DType)
  35. def vm_impl_dType(self):
  36. """Generate vm_impl function for DType"""
  37. def vm_impl(x):
  38. # update the src type
  39. return x.dtype
  40. return vm_impl
  41. @vm_impl_getters.register(P.Cast)
  42. def vm_impl_cast(self):
  43. """Generate vm_impl function for Cast"""
  44. def vm_impl(x, t):
  45. if isinstance(t, type(mstype.tensor)):
  46. t = t.element_type()
  47. # update the src type
  48. x = x.asnumpy()
  49. out = x.astype(mstype.dtype_to_nptype(t))
  50. return Tensor(out)
  51. return vm_impl
  52. @vm_impl_getters.register(P.Reshape)
  53. def vm_impl_reshape(self):
  54. """Generate vm_impl function for Reshape"""
  55. def vm_impl(x, shp):
  56. x = x.asnumpy()
  57. out = vm.reshape(x, shp)
  58. return Tensor(out)
  59. return vm_impl
  60. @vm_impl_getters.register(P.Shape)
  61. def vm_impl_shape(self):
  62. """Generate vm_impl function for Shape"""
  63. def vm_impl(x):
  64. shp = vm.shape(x.asnumpy())
  65. return shp
  66. return vm_impl
  67. @vm_impl_getters.register(P.Squeeze)
  68. def vm_impl_squeeze(self):
  69. """Generate vm_impl function for Squeeze"""
  70. def vm_impl(x):
  71. x = x.asnumpy()
  72. out = vm.squeeze(x, self.axis)
  73. return Tensor(out)
  74. return vm_impl
  75. @vm_impl_getters.register(P.Transpose)
  76. def vm_impl_transpose(self):
  77. """Generate vm_impl function for Transpose"""
  78. def vm_impl(x, perm=None):
  79. x = x.asnumpy()
  80. if perm is None:
  81. perm = [i for i in reversed(range(len(x.shape)))]
  82. out = vm.transpose(x, perm)
  83. return Tensor(out)
  84. return vm_impl
  85. @vm_impl_getters.register(P.Split)
  86. def vm_impl_split(self):
  87. """Generate vm_impl function for Split"""
  88. def vm_impl(x):
  89. x = x.asnumpy()
  90. output = np.array_split(x, (self.pos,))
  91. return Tensor(output[0]), Tensor(output[1])
  92. return vm_impl
  93. @vm_impl_getters.register(P.Fill)
  94. def vm_impl_fill(self):
  95. """Generate vm_impl function for Fill"""
  96. def vm_impl(dims, x):
  97. if isinstance(x, int):
  98. ret = np.full(dims, x, np.int32)
  99. else:
  100. ret = np.full(dims, x, np.float32)
  101. return Tensor(ret)
  102. return vm_impl
  103. @vm_impl_getters.register(P.Eye)
  104. def vm_impl_eye(self):
  105. """Generate vm_impl function for Eye"""
  106. def vm_impl(n, m, t):
  107. np_type = mstype.dtype_to_nptype(t)
  108. ret = np.eye(n, m, dtype=np_type)
  109. return Tensor(ret)
  110. return vm_impl
  111. @vm_impl_getters.register(P.InvertPermutation)
  112. def vm_impl_invert_permutation(self):
  113. """Generate vm_impl function for InvertPermutation"""
  114. def vm_impl(x):
  115. out = vm.invert_permutation(x)
  116. return out
  117. return vm_impl
  118. @vm_impl_getters.register(P.Argmax)
  119. def vm_impl_argmax(self):
  120. """Generate vm_impl function for Argmax"""
  121. def vm_impl(x):
  122. output = np.argmax(x.asnumpy(), axis=self.axis)
  123. return Tensor(output.ravel())
  124. return vm_impl
  125. @vm_impl_getters.register(P.Tile)
  126. def vm_impl_tile(self):
  127. """Generate vm_impl function for Tile"""
  128. def vm_impl(x, multiples):
  129. x = x.asnumpy()
  130. multiples = multiples.asnumpy()
  131. out = vm.Tile(x, multiples)
  132. return Tensor(out)
  133. return vm_impl
  134. @vm_impl_getters.register(P.ReduceAll)
  135. def vm_impl_all(self):
  136. """Generate vm_impl function for All"""
  137. def vm_impl(x, axis):
  138. x = x.asnumpy()
  139. out = vm.all(x, axis)
  140. return Tensor(out)
  141. return vm_impl
  142. @vm_impl_getters.register(P.Concat)
  143. def vm_impl_concatV2(self):
  144. """Generate vm_impl function for Concat"""
  145. def vm_impl(x):
  146. x = x.asnumpy()
  147. out = vm.Concat(x, self.axis)
  148. return Tensor(out)
  149. return vm_impl
  150. @vm_impl_getters.register(P.Slice)
  151. def vm_impl_slice(self):
  152. """Generate vm_impl function for Slice"""
  153. def vm_impl(x, begin, size):
  154. x = x.asnumpy()
  155. begin = begin.asnumpy()
  156. size = size.asnumpy()
  157. out = vm.Slice(x, begin, size)
  158. return Tensor(out)
  159. return vm_impl
  160. @vm_impl_getters.register(G.ConcatOffset)
  161. def vm_impl_concatOffset(self):
  162. """Generate vm_impl function for ConcatOffset"""
  163. def vm_impl(x):
  164. out = vm.ConcatOffset(x) # out is tuple
  165. return out
  166. return vm_impl
  167. @vm_impl_getters.register(P.ReduceSum)
  168. def vm_impl_sum(self):
  169. """Generate vm_impl function for Sum"""
  170. def vm_impl(x, axis):
  171. x = x.asnumpy()
  172. out = vm.sum(x, axis)
  173. return Tensor(np.array(out))
  174. return vm_impl
  175. @vm_impl_getters.register(P.Select)
  176. def vm_impl_select(self):
  177. """Generate vm_impl function for Select"""
  178. def vm_impl(cond, x, y):
  179. """
  180. Args:
  181. cond: A `Tensor` of type `bool`
  182. x: A Tensor which may have the same shape as `condition`.
  183. y: A `Tensor` with the same shape and type as `x`.
  184. """
  185. cond = cond.asnumpy()
  186. x = x.asnumpy()
  187. y = y.asnumpy()
  188. out = vm.select(cond, x, y)
  189. return Tensor(out)
  190. return vm_impl
  191. @vm_impl_getters.register(P.Square)
  192. def vm_impl_square(self):
  193. """Generate vm_impl function for Square"""
  194. def vm_impl(x):
  195. x = x.asnumpy()
  196. return Tensor(x * x)
  197. return vm_impl
  198. @vm_impl_getters.register(P.ZerosLike)
  199. def vm_impl_zeros_like(self):
  200. """Generate vm_impl function for ZerosLike"""
  201. def vm_impl(x):
  202. return Tensor(np.zeros_like(x.asnumpy()))
  203. @vm_impl_getters.register(P.Partial)
  204. def vm_impl_partial(self):
  205. """Generate vm_impl function for Partial"""
  206. def vm_impl(*args):
  207. func = args[0].__call__
  208. partial_func = functools.partial(func, *args[1:])
  209. return partial_func
  210. return vm_impl
  211. @vm_impl_getters.register(P.Depend)
  212. def vm_impl_depend(self):
  213. """Generate vm_impl function for Depend"""
  214. def vm_impl(value, expr):
  215. return value
  216. return vm_impl