You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

array_ops_vm_impl.py 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate vm_impl function for array ops"""
  16. import numpy as np
  17. import mindspore.common.dtype as mstype
  18. from mindspore.common.tensor import Tensor
  19. from mindspore.ops import operations as P
  20. from mindspore.ops.operations import _grad_ops as G
  21. from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
  22. from .vm_interface import vm
  23. # pylint: disable=unused-argument
  24. @vm_impl_getters.register(P.Assign)
  25. def vm_impl_assign(self):
  26. """Generate vm_impl function for Assign"""
  27. def vm_impl(x, value, u=None):
  28. x.assign_value(value)
  29. return x
  30. return vm_impl
  31. @vm_impl_getters.register(P.AssignAdd)
  32. def vm_impl_assignadd(self):
  33. """Generate vm_impl function for Assign"""
  34. def vm_impl(x, value, u=None):
  35. x.assign_value(value)
  36. return x
  37. return vm_impl
  38. @vm_impl_getters.register(P.ExpandDims)
  39. def vm_impl_expand_dims(self):
  40. """Generate vm_impl function for ExpandDims"""
  41. def vm_impl(x, axis):
  42. if isinstance(x, float):
  43. x = Tensor(np.array([x]))
  44. x = x.asnumpy()
  45. out = vm.expand_dims(x, axis)
  46. return Tensor(out)
  47. return vm_impl
  48. @vm_impl_getters.register(P.DType)
  49. def vm_impl_dType(self):
  50. """Generate vm_impl function for DType"""
  51. def vm_impl(x):
  52. # update the src type
  53. return x.dtype
  54. return vm_impl
  55. @vm_impl_getters.register(P.Cast)
  56. def vm_impl_cast(self):
  57. """Generate vm_impl function for Cast"""
  58. def vm_impl(x, t):
  59. if isinstance(t, type(mstype.tensor)):
  60. t = t.element_type()
  61. # update the src type
  62. x = x.asnumpy()
  63. out = x.astype(mstype.dtype_to_nptype(t))
  64. return Tensor(out)
  65. return vm_impl
  66. @vm_impl_getters.register(P.Reshape)
  67. def vm_impl_reshape(self):
  68. """Generate vm_impl function for Reshape"""
  69. def vm_impl(x, shp):
  70. x = x.asnumpy()
  71. out = vm.reshape(x, shp)
  72. return Tensor(out)
  73. return vm_impl
  74. @vm_impl_getters.register(P.Shape)
  75. def vm_impl_shape(self):
  76. """Generate vm_impl function for Shape"""
  77. def vm_impl(x):
  78. shp = vm.shape(x.asnumpy())
  79. return shp
  80. return vm_impl
  81. @vm_impl_getters.register(P.Squeeze)
  82. def vm_impl_squeeze(self):
  83. """Generate vm_impl function for Squeeze"""
  84. def vm_impl(x):
  85. x = x.asnumpy()
  86. out = vm.squeeze(x, self.axis)
  87. return Tensor(out)
  88. return vm_impl
  89. @vm_impl_getters.register(P.Transpose)
  90. def vm_impl_transpose(self):
  91. """Generate vm_impl function for Transpose"""
  92. def vm_impl(x, perm=None):
  93. x = x.asnumpy()
  94. if perm is None:
  95. perm = [i for i in reversed(range(len(x.shape)))]
  96. out = vm.transpose(x, perm)
  97. return Tensor(out)
  98. return vm_impl
  99. @vm_impl_getters.register(P.Split)
  100. def vm_impl_split(self):
  101. """Generate vm_impl function for Split"""
  102. def vm_impl(x):
  103. x = x.asnumpy()
  104. output = np.array_split(x, (self.pos,))
  105. return Tensor(output[0]), Tensor(output[1])
  106. return vm_impl
  107. @vm_impl_getters.register(P.Fill)
  108. def vm_impl_fill(self):
  109. """Generate vm_impl function for Fill"""
  110. def vm_impl(dims, x):
  111. if isinstance(x, int):
  112. ret = np.full(dims, x, np.int32)
  113. else:
  114. ret = np.full(dims, x, np.float32)
  115. return Tensor(ret)
  116. return vm_impl
  117. @vm_impl_getters.register(P.Eye)
  118. def vm_impl_eye(self):
  119. """Generate vm_impl function for Eye"""
  120. def vm_impl(n, m, t):
  121. np_type = mstype.dtype_to_nptype(t)
  122. ret = np.eye(n, m, dtype=np_type)
  123. return Tensor(ret)
  124. return vm_impl
  125. @vm_impl_getters.register(P.InvertPermutation)
  126. def vm_impl_invert_permutation(self):
  127. """Generate vm_impl function for InvertPermutation"""
  128. def vm_impl(x):
  129. out = vm.invert_permutation(x)
  130. return out
  131. return vm_impl
  132. @vm_impl_getters.register(P.Argmax)
  133. def vm_impl_argmax(self):
  134. """Generate vm_impl function for Argmax"""
  135. def vm_impl(x):
  136. output = np.argmax(x.asnumpy(), axis=self.axis)
  137. return Tensor(output.ravel())
  138. return vm_impl
  139. @vm_impl_getters.register(P.Tile)
  140. def vm_impl_tile(self):
  141. """Generate vm_impl function for Tile"""
  142. def vm_impl(x, multiples):
  143. x = x.asnumpy()
  144. out = np.tile(x, multiples)
  145. return Tensor(out)
  146. return vm_impl
  147. @vm_impl_getters.register(P.ReduceAll)
  148. def vm_impl_all(self):
  149. """Generate vm_impl function for All"""
  150. def vm_impl(x, axis):
  151. x = x.asnumpy()
  152. out = vm.all(x, axis, self.keep_dims)
  153. return Tensor(out)
  154. return vm_impl
  155. @vm_impl_getters.register(P.ReduceAny)
  156. def vm_impl_any(self):
  157. """Generate vm_impl function for Any"""
  158. def vm_impl(x, axis):
  159. x = x.asnumpy()
  160. out = vm.any(x, axis, self.keep_dims)
  161. return Tensor(out)
  162. return vm_impl
  163. @vm_impl_getters.register(P.Concat)
  164. def vm_impl_concatV2(self):
  165. """Generate vm_impl function for Concat"""
  166. def vm_impl(x):
  167. x = x.asnumpy()
  168. out = vm.Concat(x, self.axis)
  169. return Tensor(out)
  170. return vm_impl
  171. @vm_impl_getters.register(P.Slice)
  172. def vm_impl_slice(self):
  173. """Generate vm_impl function for Slice"""
  174. def vm_impl(x, begin, size):
  175. x = x.asnumpy()
  176. begin = begin.asnumpy()
  177. size = size.asnumpy()
  178. out = vm.Slice(x, begin, size)
  179. return Tensor(out)
  180. return vm_impl
  181. @vm_impl_getters.register(G.ConcatOffset)
  182. def vm_impl_concatOffset(self):
  183. """Generate vm_impl function for ConcatOffset"""
  184. def vm_impl(x):
  185. out = vm.ConcatOffset(x) # out is tuple
  186. return out
  187. return vm_impl
  188. @vm_impl_getters.register(P.ReduceSum)
  189. def vm_impl_sum(self):
  190. """Generate vm_impl function for Sum"""
  191. def vm_impl(x, axis):
  192. x = x.asnumpy()
  193. if axis == ():
  194. out = np.sum(x)
  195. else:
  196. out = np.sum(x, axis=axis)
  197. return Tensor(np.array(out))
  198. return vm_impl
  199. @vm_impl_getters.register(P.Select)
  200. @vm_impl_getters.register(P.Select.__name__)
  201. def vm_impl_select(self):
  202. """Generate vm_impl function for Select"""
  203. def vm_impl(cond, x, y):
  204. """
  205. Args:
  206. cond: A `Tensor` of type `bool`
  207. x: A Tensor which may have the same shape as `condition`.
  208. y: A `Tensor` with the same shape and type as `x`.
  209. """
  210. cond = cond.asnumpy()
  211. x = x.asnumpy()
  212. y = y.asnumpy()
  213. out = vm.select(cond, x, y)
  214. return Tensor(out)
  215. return vm_impl
  216. @vm_impl_getters.register(P.Square)
  217. def vm_impl_square(self):
  218. """Generate vm_impl function for Square"""
  219. def vm_impl(x):
  220. x = x.asnumpy()
  221. return Tensor(x * x)
  222. return vm_impl
  223. @vm_impl_getters.register(P.ZerosLike)
  224. def vm_impl_zeros_like(self):
  225. """Generate vm_impl function for ZerosLike"""
  226. def vm_impl(x):
  227. return Tensor(np.zeros_like(x.asnumpy()))
  228. @vm_impl_getters.register(P.Partial)
  229. def vm_impl_partial(self):
  230. """Generate vm_impl function for Partial"""
  231. def vm_impl(*args):
  232. func = args[0].__call__
  233. partial_func = functools.partial(func, *args[1:])
  234. return partial_func
  235. return vm_impl
  236. @vm_impl_getters.register(P.Depend)
  237. def vm_impl_depend(self):
  238. """Generate vm_impl function for Depend"""
  239. def vm_impl(value, expr):
  240. return value
  241. return vm_impl
  242. @vm_impl_getters.register(P.UpdateState)
  243. def vm_impl_updatestate(self):
  244. """Generate vm_impl function for UpdateState"""
  245. def vm_impl(monad, expr):
  246. return monad
  247. return vm_impl
  248. @vm_impl_getters.register(P.Load)
  249. def vm_impl_load(self):
  250. """Generate vm_impl function for Load"""
  251. def vm_impl(value, u=None):
  252. return value
  253. return vm_impl