You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

array_ops_vm_impl.py 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate vm_impl function for array ops"""
  16. import numpy as np
  17. import mindspore.common.dtype as mstype
  18. from mindspore.common.tensor import Tensor
  19. from mindspore.ops import operations as P
  20. from mindspore.ops.operations import _grad_ops as G
  21. from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
  22. from .vm_interface import vm
  23. # pylint: disable=unused-argument
  24. @vm_impl_getters.register(P.Assign)
  25. def vm_impl_assign(self):
  26. """Generate vm_impl function for Assign"""
  27. def vm_impl(x, value, u=None):
  28. x.assign_value(value)
  29. return x
  30. return vm_impl
  31. @vm_impl_getters.register(P.ExpandDims)
  32. def vm_impl_expand_dims(self):
  33. """Generate vm_impl function for ExpandDims"""
  34. def vm_impl(x, axis):
  35. if isinstance(x, float):
  36. x = Tensor(np.array([x]))
  37. x = x.asnumpy()
  38. out = vm.expand_dims(x, axis)
  39. return Tensor(out)
  40. return vm_impl
  41. @vm_impl_getters.register(P.DType)
  42. def vm_impl_dType(self):
  43. """Generate vm_impl function for DType"""
  44. def vm_impl(x):
  45. # update the src type
  46. return x.dtype
  47. return vm_impl
  48. @vm_impl_getters.register(P.Cast)
  49. def vm_impl_cast(self):
  50. """Generate vm_impl function for Cast"""
  51. def vm_impl(x, t):
  52. if isinstance(t, type(mstype.tensor)):
  53. t = t.element_type()
  54. # update the src type
  55. x = x.asnumpy()
  56. out = x.astype(mstype.dtype_to_nptype(t))
  57. return Tensor(out)
  58. return vm_impl
  59. @vm_impl_getters.register(P.Reshape)
  60. def vm_impl_reshape(self):
  61. """Generate vm_impl function for Reshape"""
  62. def vm_impl(x, shp):
  63. x = x.asnumpy()
  64. out = vm.reshape(x, shp)
  65. return Tensor(out)
  66. return vm_impl
  67. @vm_impl_getters.register(P.Shape)
  68. def vm_impl_shape(self):
  69. """Generate vm_impl function for Shape"""
  70. def vm_impl(x):
  71. shp = vm.shape(x.asnumpy())
  72. return shp
  73. return vm_impl
  74. @vm_impl_getters.register(P.Squeeze)
  75. def vm_impl_squeeze(self):
  76. """Generate vm_impl function for Squeeze"""
  77. def vm_impl(x):
  78. x = x.asnumpy()
  79. out = vm.squeeze(x, self.axis)
  80. return Tensor(out)
  81. return vm_impl
  82. @vm_impl_getters.register(P.Transpose)
  83. def vm_impl_transpose(self):
  84. """Generate vm_impl function for Transpose"""
  85. def vm_impl(x, perm=None):
  86. x = x.asnumpy()
  87. if perm is None:
  88. perm = [i for i in reversed(range(len(x.shape)))]
  89. out = vm.transpose(x, perm)
  90. return Tensor(out)
  91. return vm_impl
  92. @vm_impl_getters.register(P.Split)
  93. def vm_impl_split(self):
  94. """Generate vm_impl function for Split"""
  95. def vm_impl(x):
  96. x = x.asnumpy()
  97. output = np.array_split(x, (self.pos,))
  98. return Tensor(output[0]), Tensor(output[1])
  99. return vm_impl
  100. @vm_impl_getters.register(P.Fill)
  101. def vm_impl_fill(self):
  102. """Generate vm_impl function for Fill"""
  103. def vm_impl(dims, x):
  104. if isinstance(x, int):
  105. ret = np.full(dims, x, np.int32)
  106. else:
  107. ret = np.full(dims, x, np.float32)
  108. return Tensor(ret)
  109. return vm_impl
  110. @vm_impl_getters.register(P.Eye)
  111. def vm_impl_eye(self):
  112. """Generate vm_impl function for Eye"""
  113. def vm_impl(n, m, t):
  114. np_type = mstype.dtype_to_nptype(t)
  115. ret = np.eye(n, m, dtype=np_type)
  116. return Tensor(ret)
  117. return vm_impl
  118. @vm_impl_getters.register(P.InvertPermutation)
  119. def vm_impl_invert_permutation(self):
  120. """Generate vm_impl function for InvertPermutation"""
  121. def vm_impl(x):
  122. out = vm.invert_permutation(x)
  123. return out
  124. return vm_impl
  125. @vm_impl_getters.register(P.Argmax)
  126. def vm_impl_argmax(self):
  127. """Generate vm_impl function for Argmax"""
  128. def vm_impl(x):
  129. output = np.argmax(x.asnumpy(), axis=self.axis)
  130. return Tensor(output.ravel())
  131. return vm_impl
  132. @vm_impl_getters.register(P.Tile)
  133. def vm_impl_tile(self):
  134. """Generate vm_impl function for Tile"""
  135. def vm_impl(x, multiples):
  136. x = x.asnumpy()
  137. out = np.tile(x, multiples)
  138. return Tensor(out)
  139. return vm_impl
  140. @vm_impl_getters.register(P.ReduceAll)
  141. def vm_impl_all(self):
  142. """Generate vm_impl function for All"""
  143. def vm_impl(x, axis):
  144. x = x.asnumpy()
  145. out = vm.all(x, axis, self.keep_dims)
  146. return Tensor(out)
  147. return vm_impl
  148. @vm_impl_getters.register(P.ReduceAny)
  149. def vm_impl_any(self):
  150. """Generate vm_impl function for Any"""
  151. def vm_impl(x, axis):
  152. x = x.asnumpy()
  153. out = vm.any(x, axis, self.keep_dims)
  154. return Tensor(out)
  155. return vm_impl
  156. @vm_impl_getters.register(P.Concat)
  157. def vm_impl_concatV2(self):
  158. """Generate vm_impl function for Concat"""
  159. def vm_impl(x):
  160. x = x.asnumpy()
  161. out = vm.Concat(x, self.axis)
  162. return Tensor(out)
  163. return vm_impl
  164. @vm_impl_getters.register(P.Slice)
  165. def vm_impl_slice(self):
  166. """Generate vm_impl function for Slice"""
  167. def vm_impl(x, begin, size):
  168. x = x.asnumpy()
  169. begin = begin.asnumpy()
  170. size = size.asnumpy()
  171. out = vm.Slice(x, begin, size)
  172. return Tensor(out)
  173. return vm_impl
  174. @vm_impl_getters.register(G.ConcatOffset)
  175. def vm_impl_concatOffset(self):
  176. """Generate vm_impl function for ConcatOffset"""
  177. def vm_impl(x):
  178. out = vm.ConcatOffset(x) # out is tuple
  179. return out
  180. return vm_impl
  181. @vm_impl_getters.register(P.ReduceSum)
  182. def vm_impl_sum(self):
  183. """Generate vm_impl function for Sum"""
  184. def vm_impl(x, axis):
  185. x = x.asnumpy()
  186. if axis == ():
  187. out = np.sum(x)
  188. else:
  189. out = np.sum(x, axis=axis)
  190. return Tensor(np.array(out))
  191. return vm_impl
  192. @vm_impl_getters.register(P.Select)
  193. def vm_impl_select(self):
  194. """Generate vm_impl function for Select"""
  195. def vm_impl(cond, x, y):
  196. """
  197. Args:
  198. cond: A `Tensor` of type `bool`
  199. x: A Tensor which may have the same shape as `condition`.
  200. y: A `Tensor` with the same shape and type as `x`.
  201. """
  202. cond = cond.asnumpy()
  203. x = x.asnumpy()
  204. y = y.asnumpy()
  205. out = vm.select(cond, x, y)
  206. return Tensor(out)
  207. return vm_impl
  208. @vm_impl_getters.register(P.Square)
  209. def vm_impl_square(self):
  210. """Generate vm_impl function for Square"""
  211. def vm_impl(x):
  212. x = x.asnumpy()
  213. return Tensor(x * x)
  214. return vm_impl
  215. @vm_impl_getters.register(P.ZerosLike)
  216. def vm_impl_zeros_like(self):
  217. """Generate vm_impl function for ZerosLike"""
  218. def vm_impl(x):
  219. return Tensor(np.zeros_like(x.asnumpy()))
  220. @vm_impl_getters.register(P.Partial)
  221. def vm_impl_partial(self):
  222. """Generate vm_impl function for Partial"""
  223. def vm_impl(*args):
  224. func = args[0].__call__
  225. partial_func = functools.partial(func, *args[1:])
  226. return partial_func
  227. return vm_impl
  228. @vm_impl_getters.register(P.Depend)
  229. def vm_impl_depend(self):
  230. """Generate vm_impl function for Depend"""
  231. def vm_impl(value, expr):
  232. return value
  233. return vm_impl
  234. @vm_impl_getters.register(P.UpdateState)
  235. def vm_impl_updatestate(self):
  236. """Generate vm_impl function for UpdateState"""
  237. def vm_impl(monad, expr):
  238. return monad
  239. return vm_impl
  240. @vm_impl_getters.register(P.Load)
  241. def vm_impl_load(self):
  242. """Generate vm_impl function for Load"""
  243. def vm_impl(value, u=None):
  244. return value
  245. return vm_impl