You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

math_ops_vm_impl.py 6.7 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate vm_impl function for math ops"""
  16. import copy
  17. import numpy as np
  18. from mindspore.common.dtype import dtype_to_nptype
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.ops import operations as P
  21. from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
  22. from .vm_interface import vm
  23. # pylint: disable=unused-argument
  24. @vm_impl_getters.register(P.Add)
  25. def vm_impl_tensor_add(self):
  26. """Generate vm_impl function for TensorAdd."""
  27. def vm_impl(x, y):
  28. x = x.asnumpy()
  29. y = y.asnumpy()
  30. return Tensor(x + y)
  31. return vm_impl
  32. # pylint: disable=used-before-assignment
  33. @vm_impl_getters.register(P.LogicalNot)
  34. def vm_impl_logical_not(self):
  35. def vm_impl(x):
  36. x = x.asnumpy()
  37. out = vm.logical_not(x)
  38. return Tensor(out)
  39. return vm_impl
  40. @vm_impl_getters.register(P.MatMul)
  41. def vm_impl_mat_mul(self):
  42. """Generate vm_impl function for MatMul."""
  43. def vm_impl(x, w):
  44. x = x.asnumpy()
  45. w = w.asnumpy()
  46. if self.transpose_a:
  47. x = x.transpose()
  48. if self.transpose_b:
  49. w = w.transpose()
  50. z = x @ w
  51. return Tensor(z)
  52. return vm_impl
  53. @vm_impl_getters.register(P.AddN)
  54. def vm_impl_addn(self):
  55. """Generate vm_impl function for AddN."""
  56. def vm_impl(inputs):
  57. added = copy.deepcopy(inputs[0].asnumpy())
  58. for x in inputs[1:]:
  59. added += x.asnumpy()
  60. return Tensor(added)
  61. return vm_impl
  62. @vm_impl_getters.register(P.Neg)
  63. def vm_impl_neg(self):
  64. """Generate vm_impl function for Neg."""
  65. def vm_impl(x):
  66. x = x.asnumpy()
  67. return Tensor(-x)
  68. return vm_impl
  69. @vm_impl_getters.register(P.Sub)
  70. def vm_impl_Sub(self):
  71. """Generate vm_impl function for Sub."""
  72. def vm_impl(x, y):
  73. x = x.asnumpy()
  74. y = y.asnumpy()
  75. return Tensor(x - y)
  76. return vm_impl
  77. @vm_impl_getters.register(P.Mul)
  78. def vm_impl_mul(self):
  79. """Generate vm_impl function for Mul."""
  80. def vm_impl(x, y):
  81. x = x.asnumpy()
  82. y = y.asnumpy()
  83. return Tensor(x * y)
  84. return vm_impl
  85. @vm_impl_getters.register(P.Square)
  86. def vm_impl_square(self):
  87. """Generate vm_impl function for Square."""
  88. def vm_impl(x):
  89. x = x.asnumpy()
  90. return Tensor(x * x)
  91. return vm_impl
  92. @vm_impl_getters.register(P.Sqrt)
  93. def vm_impl_sqrt(self):
  94. """Generate vm_impl function for Sqrt."""
  95. def vm_impl(x):
  96. x = x.asnumpy()
  97. res = vm.sqrt(x)
  98. return Tensor(res)
  99. return vm_impl
  100. @vm_impl_getters.register(P.Pow)
  101. def vm_impl_pow(self):
  102. """Generate vm_impl function for Pow."""
  103. def vm_impl(x, y):
  104. x = x.asnumpy()
  105. y = y.asnumpy()
  106. res = vm.power(x, y)
  107. return Tensor(res)
  108. return vm_impl
  109. @vm_impl_getters.register(P.Exp)
  110. def vm_impl_exp(self):
  111. """Generate vm_impl function for Exp."""
  112. def vm_impl(x):
  113. x = x.asnumpy()
  114. res = vm.exp(x)
  115. return Tensor(res)
  116. return vm_impl
  117. @vm_impl_getters.register(P.RealDiv)
  118. def vm_impl_real_div(self):
  119. """Generate vm_impl function for RealDiv."""
  120. def vm_impl(x, y):
  121. x = x.asnumpy()
  122. y = y.asnumpy()
  123. out = x / y
  124. out = np.array(out, x.dtype)
  125. return Tensor(out)
  126. return vm_impl
  127. @vm_impl_getters.register(P.Div)
  128. def vm_impl_div(self):
  129. """Generate vm_impl function for Div."""
  130. def vm_impl(x, y):
  131. x = x.asnumpy()
  132. y = y.asnumpy()
  133. return Tensor(x / y)
  134. return vm_impl
  135. @vm_impl_getters.register(P.ReduceMean)
  136. def vm_impl_reduce_mean(self):
  137. """Generate vm_impl function for ReduceMean."""
  138. def vm_impl(x, axis):
  139. x = x.asnumpy()
  140. out = vm.mean(x, axis)
  141. return Tensor(out)
  142. return vm_impl
  143. @vm_impl_getters.register(P.ReduceMax)
  144. def vm_impl_reduce_max(self):
  145. """Generate vm_impl function for ReduceMean."""
  146. def vm_impl(x, axis):
  147. x = x.asnumpy()
  148. if axis == ():
  149. axis = None
  150. out = np.amax(x, axis)
  151. return Tensor(out)
  152. return vm_impl
  153. @vm_impl_getters.register(P.Equal)
  154. def vm_impl_equal(self):
  155. """Generate vm_impl function for Equal."""
  156. def vm_impl(x, y):
  157. x = x.asnumpy()
  158. y = y.asnumpy()
  159. out = vm.equal(x, y)
  160. return Tensor(np.array(out))
  161. return vm_impl
  162. @vm_impl_getters.register(P.NotEqual)
  163. def vm_impl_not_equal(self):
  164. """Generate vm_impl function for NotEqual."""
  165. def vm_impl(x, y):
  166. x = x.asnumpy()
  167. y = y.asnumpy()
  168. out = vm.not_equal(x, y)
  169. return Tensor(np.array(out))
  170. return vm_impl
  171. @vm_impl_getters.register(P.Greater)
  172. def vm_impl_greater(self):
  173. """Generate vm_impl function for Greater."""
  174. def vm_impl(x, y):
  175. x = x.asnumpy()
  176. y = y.asnumpy()
  177. out = vm.greater(x, y)
  178. return Tensor(np.array(out))
  179. return vm_impl
  180. @vm_impl_getters.register(P.Maximum)
  181. def vm_impl_maximum(self):
  182. """Generate vm_impl function for Maximum."""
  183. def vm_impl(x, y):
  184. x = x.asnumpy()
  185. y = y.asnumpy()
  186. out = vm.maximum(x, y)
  187. return Tensor(out)
  188. return vm_impl
  189. @vm_impl_getters.register(P.Minimum)
  190. def vm_impl_minimum(self):
  191. """Generate vm_impl function for Minimum."""
  192. def vm_impl(x, y):
  193. x = x.asnumpy()
  194. y = y.asnumpy()
  195. out = vm.minimum(x, y)
  196. return Tensor(out)
  197. return vm_impl
  198. @vm_impl_getters.register(P.Less)
  199. def vm_impl_less(self):
  200. """Generate vm_impl function for Less"""
  201. def vm_impl(x, y):
  202. x = x.asnumpy()
  203. y = y.asnumpy()
  204. out = vm.less(x, y)
  205. return Tensor(np.array(out))
  206. return vm_impl
  207. @vm_impl_getters.register(P.ScalarCast)
  208. def vm_impl_scalar_cast(self):
  209. """Generate vm_impl function for ScalarCast"""
  210. def vm_impl(x, t):
  211. np_type = dtype_to_nptype(t)
  212. value = np_type(x)
  213. cast_value = value.item()
  214. return cast_value
  215. return vm_impl