You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

array_ops_vm_impl.py 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate vm_impl function for array ops"""
  16. import numpy as np
  17. from mindspore.ops import operations as P
  18. from mindspore.common.tensor import Tensor
  19. import mindspore.common.dtype as mstype
  20. from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
  21. from .vm_interface import vm
  22. # pylint: disable=unused-argument
  23. @vm_impl_getters.register(P.ExpandDims)
  24. def vm_impl_expand_dims(self):
  25. """Generate vm_impl function for ExpandDims"""
  26. def vm_impl(x, axis):
  27. if isinstance(x, float):
  28. x = Tensor(np.array([x]))
  29. x = x.asnumpy()
  30. out = vm.expand_dims(x, axis)
  31. return Tensor(out)
  32. return vm_impl
  33. @vm_impl_getters.register(P.DType)
  34. def vm_impl_dType(self):
  35. """Generate vm_impl function for DType"""
  36. def vm_impl(x):
  37. # update the src type
  38. return x.dtype()
  39. return vm_impl
  40. @vm_impl_getters.register(P.Cast)
  41. def vm_impl_cast(self):
  42. """Generate vm_impl function for Cast"""
  43. def vm_impl(x, t):
  44. if isinstance(t, type(mstype.tensor)):
  45. t = t.element_type()
  46. # update the src type
  47. x = x.asnumpy()
  48. out = x.astype(mstype.dtype_to_nptype(t))
  49. return Tensor(out)
  50. return vm_impl
  51. @vm_impl_getters.register(P.Reshape)
  52. def vm_impl_reshape(self):
  53. """Generate vm_impl function for Reshape"""
  54. def vm_impl(x, shp):
  55. x = x.asnumpy()
  56. out = vm.reshape(x, shp)
  57. return Tensor(out)
  58. return vm_impl
  59. @vm_impl_getters.register(P.Shape)
  60. def vm_impl_shape(self):
  61. """Generate vm_impl function for Shape"""
  62. def vm_impl(x):
  63. shp = vm.shape(x.asnumpy())
  64. return shp
  65. return vm_impl
  66. @vm_impl_getters.register(P.Squeeze)
  67. def vm_impl_squeeze(self):
  68. """Generate vm_impl function for Squeeze"""
  69. def vm_impl(x):
  70. x = x.asnumpy()
  71. out = vm.squeeze(x, self.axis)
  72. return Tensor(out)
  73. return vm_impl
  74. @vm_impl_getters.register(P.Transpose)
  75. def vm_impl_transpose(self):
  76. """Generate vm_impl function for Transpose"""
  77. def vm_impl(x, perm=None):
  78. x = x.asnumpy()
  79. if perm is None:
  80. perm = [i for i in reversed(range(len(x.shape)))]
  81. out = vm.transpose(x, perm)
  82. return Tensor(out)
  83. return vm_impl
  84. @vm_impl_getters.register(P.Split)
  85. def vm_impl_split(self):
  86. """Generate vm_impl function for Split"""
  87. def vm_impl(x):
  88. x = x.asnumpy()
  89. output = np.array_split(x, (self.pos,))
  90. return Tensor(output[0]), Tensor(output[1])
  91. return vm_impl
  92. @vm_impl_getters.register(P.Fill)
  93. def vm_impl_fill(self):
  94. """Generate vm_impl function for Fill"""
  95. def vm_impl(dims, x):
  96. if isinstance(x, int):
  97. ret = np.full(dims, x, np.int32)
  98. else:
  99. ret = np.full(dims, x, np.float32)
  100. return Tensor(ret)
  101. return vm_impl
  102. @vm_impl_getters.register(P.Eye)
  103. def vm_impl_eye(self):
  104. """Generate vm_impl function for Eye"""
  105. def vm_impl(n, m, t):
  106. np_type = mstype.dtype_to_nptype(t)
  107. ret = np.eye(n, m, dtype=np_type)
  108. return Tensor(ret)
  109. return vm_impl
  110. @vm_impl_getters.register(P.InvertPermutation)
  111. def vm_impl_invert_permutation(self):
  112. """Generate vm_impl function for InvertPermutation"""
  113. def vm_impl(x):
  114. out = vm.invert_permutation(x)
  115. return out
  116. return vm_impl
  117. @vm_impl_getters.register(P.Argmax)
  118. def vm_impl_argmax(self):
  119. """Generate vm_impl function for Argmax"""
  120. def vm_impl(x):
  121. output = np.argmax(x.asnumpy(), axis=self.axis)
  122. return Tensor(output.ravel())
  123. return vm_impl
  124. @vm_impl_getters.register(P.Tile)
  125. def vm_impl_tile(self):
  126. """Generate vm_impl function for Tile"""
  127. def vm_impl(x, multiples):
  128. x = x.asnumpy()
  129. multiples = multiples.asnumpy()
  130. out = vm.Tile(x, multiples)
  131. return Tensor(out)
  132. return vm_impl
  133. @vm_impl_getters.register(P.ReduceAll)
  134. def vm_impl_all(self):
  135. """Generate vm_impl function for All"""
  136. def vm_impl(x, axis):
  137. x = x.asnumpy()
  138. out = vm.all(x, axis)
  139. return Tensor(out)
  140. return vm_impl
  141. @vm_impl_getters.register(P.Concat)
  142. def vm_impl_concatV2(self):
  143. """Generate vm_impl function for Concat"""
  144. def vm_impl(x):
  145. x = x.asnumpy()
  146. out = vm.Concat(x, self.axis)
  147. return Tensor(out)
  148. return vm_impl
  149. @vm_impl_getters.register(P.Slice)
  150. def vm_impl_slice(self):
  151. """Generate vm_impl function for Slice"""
  152. def vm_impl(x, begin, size):
  153. x = x.asnumpy()
  154. begin = begin.asnumpy()
  155. size = size.asnumpy()
  156. out = vm.Slice(x, begin, size)
  157. return Tensor(out)
  158. return vm_impl
  159. @vm_impl_getters.register(P.ConcatOffset)
  160. def vm_impl_concatOffset(self):
  161. """Generate vm_impl function for ConcatOffset"""
  162. def vm_impl(x):
  163. out = vm.ConcatOffset(x) # out is tuple
  164. return out
  165. return vm_impl
  166. @vm_impl_getters.register(P.ReduceSum)
  167. def vm_impl_sum(self):
  168. """Generate vm_impl function for Sum"""
  169. def vm_impl(x, axis):
  170. x = x.asnumpy()
  171. out = vm.sum(x, axis)
  172. return Tensor(np.array(out))
  173. return vm_impl
  174. @vm_impl_getters.register(P.Select)
  175. def vm_impl_select(self):
  176. """Generate vm_impl function for Select"""
  177. def vm_impl(cond, x, y):
  178. """
  179. Args:
  180. cond: A `Tensor` of type `bool`
  181. x: A Tensor which may have the same shape as `condition`.
  182. y: A `Tensor` with the same shape and type as `x`.
  183. """
  184. cond = cond.asnumpy()
  185. x = x.asnumpy()
  186. y = y.asnumpy()
  187. out = vm.select(cond, x, y)
  188. return Tensor(out)
  189. return vm_impl
  190. @vm_impl_getters.register(P.Square)
  191. def vm_impl_square(self):
  192. """Generate vm_impl function for Square"""
  193. def vm_impl(x):
  194. x = x.asnumpy()
  195. return Tensor(x * x)
  196. return vm_impl