You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

math_ops_vm_impl.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate vm_impl function for math ops"""
  16. import copy
  17. import numpy as np
  18. from mindspore.common.dtype import dtype_to_nptype
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.ops import operations as P
  21. from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
  22. from .vm_interface import vm
  23. # pylint: disable=unused-argument
  24. @vm_impl_getters.register(P.TensorAdd)
  25. def vm_impl_tensor_add(self):
  26. """Generate vm_impl function for TensorAdd."""
  27. def vm_impl(x, y):
  28. x = x.asnumpy()
  29. y = y.asnumpy()
  30. return Tensor(x + y)
  31. return vm_impl
  32. @vm_impl_getters.register(P.LogicalNot)
  33. def vm_impl_logical_not(self):
  34. x = x.asnumpy()
  35. out = vm.logical_not(x)
  36. return Tensor(out)
  37. @vm_impl_getters.register(P.MatMul)
  38. def vm_impl_mat_mul(self):
  39. """Generate vm_impl function for MatMul."""
  40. def vm_impl(x, w):
  41. x = x.asnumpy()
  42. w = w.asnumpy()
  43. if self.transpose_a:
  44. x = x.transpose()
  45. if self.transpose_b:
  46. w = w.transpose()
  47. z = x @ w
  48. return Tensor(z)
  49. return vm_impl
  50. @vm_impl_getters.register(P.AddN)
  51. def vm_impl_addn(self):
  52. """Generate vm_impl function for AddN."""
  53. def vm_impl(inputs):
  54. added = copy.deepcopy(inputs[0].asnumpy())
  55. for x in inputs[1:]:
  56. added += x.asnumpy()
  57. return Tensor(added)
  58. return vm_impl
  59. @vm_impl_getters.register(P.Neg)
  60. def vm_impl_neg(self):
  61. """Generate vm_impl function for Neg."""
  62. def vm_impl(x):
  63. x = x.asnumpy()
  64. return Tensor(-x)
  65. return vm_impl
  66. @vm_impl_getters.register(P.Sub)
  67. def vm_impl_Sub(self):
  68. """Generate vm_impl function for Sub."""
  69. def vm_impl(x, y):
  70. x = x.asnumpy()
  71. y = y.asnumpy()
  72. return Tensor(x - y)
  73. return vm_impl
  74. @vm_impl_getters.register(P.Mul)
  75. def vm_impl_mul(self):
  76. """Generate vm_impl function for Mul."""
  77. def vm_impl(x, y):
  78. x = x.asnumpy()
  79. y = y.asnumpy()
  80. return Tensor(x * y)
  81. return vm_impl
  82. @vm_impl_getters.register(P.Square)
  83. def vm_impl_square(self):
  84. """Generate vm_impl function for Square."""
  85. def vm_impl(x):
  86. x = x.asnumpy()
  87. return Tensor(x * x)
  88. return vm_impl
  89. @vm_impl_getters.register(P.Sqrt)
  90. def vm_impl_sqrt(self):
  91. """Generate vm_impl function for Sqrt."""
  92. def vm_impl(x):
  93. x = x.asnumpy()
  94. res = vm.sqrt(x)
  95. return Tensor(res)
  96. return vm_impl
  97. @vm_impl_getters.register(P.Pow)
  98. def vm_impl_pow(self):
  99. """Generate vm_impl function for Pow."""
  100. def vm_impl(x, y):
  101. x = x.asnumpy()
  102. y = y.asnumpy()
  103. res = vm.power(x, y)
  104. return Tensor(res)
  105. return vm_impl
  106. @vm_impl_getters.register(P.Exp)
  107. def vm_impl_exp(self):
  108. """Generate vm_impl function for Exp."""
  109. def vm_impl(x):
  110. x = x.asnumpy()
  111. res = vm.exp(x)
  112. return Tensor(res)
  113. return vm_impl
  114. @vm_impl_getters.register(P.RealDiv)
  115. def vm_impl_real_div(self):
  116. """Generate vm_impl function for RealDiv."""
  117. def vm_impl(x, y):
  118. x = x.asnumpy()
  119. y = y.asnumpy()
  120. out = x / y
  121. out = np.array(out, x.dtype)
  122. return Tensor(out)
  123. return vm_impl
  124. @vm_impl_getters.register(P.Div)
  125. def vm_impl_div(self):
  126. """Generate vm_impl function for Div."""
  127. def vm_impl(x, y):
  128. x = x.asnumpy()
  129. y = y.asnumpy()
  130. return Tensor(x / y)
  131. return vm_impl
  132. @vm_impl_getters.register(P.ReduceMean)
  133. def vm_impl_reduce_mean(self):
  134. """Generate vm_impl function for ReduceMean."""
  135. def vm_impl(x, axis):
  136. x = x.asnumpy()
  137. out = vm.mean(x, axis)
  138. return Tensor(out)
  139. return vm_impl
  140. @vm_impl_getters.register(P.Equal)
  141. def vm_impl_equal(self):
  142. """Generate vm_impl function for Equal."""
  143. def vm_impl(x, y):
  144. x = x.asnumpy()
  145. y = y.asnumpy()
  146. out = vm.equal(x, y)
  147. return Tensor(np.array(out))
  148. return vm_impl
  149. @vm_impl_getters.register(P.NotEqual)
  150. def vm_impl_not_equal(self):
  151. """Generate vm_impl function for NotEqual."""
  152. def vm_impl(x, y):
  153. x = x.asnumpy()
  154. y = y.asnumpy()
  155. out = vm.not_equal(x, y)
  156. return Tensor(np.array(out))
  157. return vm_impl
  158. @vm_impl_getters.register(P.Greater)
  159. def vm_impl_greater(self):
  160. """Generate vm_impl function for Greater."""
  161. def vm_impl(x, y):
  162. x = x.asnumpy()
  163. y = y.asnumpy()
  164. out = vm.greater(x, y)
  165. return Tensor(np.array(out))
  166. return vm_impl
  167. @vm_impl_getters.register(P.Maximum)
  168. def vm_impl_maximum(self):
  169. """Generate vm_impl function for Maximum."""
  170. def vm_impl(x, y):
  171. x = x.asnumpy()
  172. y = y.asnumpy()
  173. out = vm.maximum(x, y)
  174. return Tensor(out)
  175. return vm_impl
  176. @vm_impl_getters.register(P.Minimum)
  177. def vm_impl_minimum(self):
  178. """Generate vm_impl function for Minimum."""
  179. def vm_impl(x, y):
  180. x = x.asnumpy()
  181. y = y.asnumpy()
  182. out = vm.minimum(x, y)
  183. return Tensor(out)
  184. return vm_impl
  185. @vm_impl_getters.register(P.Less)
  186. def vm_impl_less(self):
  187. """Generate vm_impl function for Less"""
  188. def vm_impl(x, y):
  189. x = x.asnumpy()
  190. y = y.asnumpy()
  191. out = vm.less(x, y)
  192. return Tensor(np.array(out))
  193. return vm_impl
  194. @vm_impl_getters.register(P.ScalarCast)
  195. def vm_impl_scalar_cast(self):
  196. """Generate vm_impl function for ScalarCast"""
  197. def vm_impl(x, t):
  198. np_type = dtype_to_nptype(t)
  199. value = np_type(x)
  200. cast_value = value.item()
  201. return cast_value
  202. return vm_impl