You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

math_ops_vm_impl.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate vm_impl function for math ops"""
  16. import copy
  17. import numpy as np
  18. from mindspore.ops import operations as P
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
  21. from mindspore.common.dtype import dtype_to_nptype
  22. from .vm_interface import vm
  23. # pylint: disable=unused-argument
  24. @vm_impl_getters.register(P.TensorAdd)
  25. def vm_impl_tensor_add(self):
  26. """Generate vm_impl function for TensorAdd."""
  27. def vm_impl(x, y):
  28. x = x.asnumpy()
  29. y = y.asnumpy()
  30. return Tensor(x + y)
  31. return vm_impl
  32. @vm_impl_getters.register(P.LogicalNot)
  33. def vm_impl_logical_not(self):
  34. x = x.asnumpy()
  35. out = vm.logical_not(x)
  36. return Tensor(out)
  37. @vm_impl_getters.register(P.MatMul)
  38. def vm_impl_mat_mul(self):
  39. """Generate vm_impl function for MatMul."""
  40. def vm_impl(x, w):
  41. x = x.asnumpy()
  42. w = w.asnumpy()
  43. if self.transpose_a:
  44. x = x.transpose()
  45. if self.transpose_b:
  46. w = w.transpose()
  47. z = x@w
  48. return Tensor(z)
  49. return vm_impl
  50. @vm_impl_getters.register(P.AddN)
  51. def vm_impl_addn(self):
  52. """Generate vm_impl function for AddN."""
  53. def vm_impl(inputs):
  54. added = copy.deepcopy(inputs[0].asnumpy())
  55. for x in inputs[1:]:
  56. added += x.asnumpy()
  57. return Tensor(added)
  58. return vm_impl
  59. @vm_impl_getters.register(P.Neg)
  60. def vm_impl_neg(self):
  61. """Generate vm_impl function for Neg."""
  62. def vm_impl(x):
  63. x = x.asnumpy()
  64. return Tensor(-x)
  65. return vm_impl
  66. @vm_impl_getters.register(P.Sub)
  67. def vm_impl_Sub(self):
  68. """Generate vm_impl function for Sub."""
  69. def vm_impl(x, y):
  70. x = x.asnumpy()
  71. y = y.asnumpy()
  72. return Tensor(x - y)
  73. return vm_impl
  74. @vm_impl_getters.register(P.Mul)
  75. def vm_impl_mul(self):
  76. """Generate vm_impl function for Mul."""
  77. def vm_impl(x, y):
  78. x = x.asnumpy()
  79. y = y.asnumpy()
  80. return Tensor(x * y)
  81. return vm_impl
  82. @vm_impl_getters.register(P.Square)
  83. def vm_impl_square(self):
  84. """Generate vm_impl function for Square."""
  85. def vm_impl(x):
  86. x = x.asnumpy()
  87. return Tensor(x * x)
  88. return vm_impl
  89. @vm_impl_getters.register(P.Sqrt)
  90. def vm_impl_sqrt(self):
  91. """Generate vm_impl function for Sqrt."""
  92. def vm_impl(x):
  93. x = x.asnumpy()
  94. res = vm.sqrt(x)
  95. return Tensor(res)
  96. return vm_impl
  97. @vm_impl_getters.register(P.Pow)
  98. def vm_impl_pow(self):
  99. """Generate vm_impl function for Pow."""
  100. def vm_impl(x, y):
  101. x = x.asnumpy()
  102. res = vm.power(x, y)
  103. return Tensor(res)
  104. return vm_impl
  105. @vm_impl_getters.register(P.Exp)
  106. def vm_impl_exp(self):
  107. """Generate vm_impl function for Exp."""
  108. def vm_impl(x):
  109. x = x.asnumpy()
  110. res = vm.exp(x)
  111. return Tensor(res)
  112. return vm_impl
  113. @vm_impl_getters.register(P.RealDiv)
  114. def vm_impl_real_div(self):
  115. """Generate vm_impl function for RealDiv."""
  116. def vm_impl(x, y):
  117. x = x.asnumpy()
  118. y = y.asnumpy()
  119. out = x / y
  120. out = np.array(out, x.dtype)
  121. return Tensor(out)
  122. return vm_impl
  123. @vm_impl_getters.register(P.Div)
  124. def vm_impl_div(self):
  125. """Generate vm_impl function for Div."""
  126. def vm_impl(x, y):
  127. x = x.asnumpy()
  128. y = y.asnumpy()
  129. return Tensor(x / y)
  130. return vm_impl
  131. @vm_impl_getters.register(P.ReduceMean)
  132. def vm_impl_reduce_mean(self):
  133. """Generate vm_impl function for ReduceMean."""
  134. def vm_impl(x, axis):
  135. x = x.asnumpy()
  136. out = vm.mean(x, axis)
  137. return Tensor(out)
  138. return vm_impl
  139. @vm_impl_getters.register(P.Equal)
  140. def vm_impl_equal(self):
  141. """Generate vm_impl function for Equal."""
  142. def vm_impl(x, y):
  143. x = x.asnumpy()
  144. y = y.asnumpy()
  145. out = vm.equal(x, y)
  146. return Tensor(out)
  147. return vm_impl
  148. @vm_impl_getters.register(P.NotEqual)
  149. def vm_impl_not_equal(self):
  150. """Generate vm_impl function for NotEqual."""
  151. def vm_impl(x, y):
  152. x = x.asnumpy()
  153. y = y.asnumpy()
  154. out = vm.not_equal(x, y)
  155. return Tensor(out)
  156. return vm_impl
  157. @vm_impl_getters.register(P.Greater)
  158. def vm_impl_greater(self):
  159. """Generate vm_impl function for Greater."""
  160. def vm_impl(x, y):
  161. x = x.asnumpy()
  162. y = y.asnumpy()
  163. out = vm.greater(x, y)
  164. return Tensor(out)
  165. return vm_impl
  166. @vm_impl_getters.register(P.Maximum)
  167. def vm_impl_maximum(self):
  168. """Generate vm_impl function for Maximum."""
  169. def vm_impl(x, y):
  170. x = x.asnumpy()
  171. y = y.asnumpy()
  172. out = vm.maximum(x, y)
  173. return Tensor(out)
  174. return vm_impl
  175. @vm_impl_getters.register(P.Minimum)
  176. def vm_impl_minimum(self):
  177. """Generate vm_impl function for Minimum."""
  178. def vm_impl(x, y):
  179. x = x.asnumpy()
  180. y = y.asnumpy()
  181. out = vm.minimum(x, y)
  182. return Tensor(out)
  183. return vm_impl
  184. @vm_impl_getters.register(P.Less)
  185. def vm_impl_greater(self):
  186. """Generate vm_impl function for Less"""
  187. def vm_impl(x, y):
  188. x = x.asnumpy()
  189. y = y.asnumpy()
  190. out = vm.less(x, y)
  191. return Tensor(out)
  192. return vm_impl
  193. @vm_impl_getters.register(P.ScalarCast)
  194. def vm_impl_greater(self):
  195. """Generate vm_impl function for ScalarCast"""
  196. def vm_impl(x, t):
  197. np_type = dtype_to_nptype(t)
  198. value = np_type(x)
  199. cast_value = value.item()
  200. return cast_value
  201. return vm_impl

MindSpore is a new open-source deep learning training/inference framework that can be used in mobile, edge, and cloud scenarios.