You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_operator.py 6.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_operator """
  16. import numpy as np
  17. from mindspore import Tensor, Model, context
  18. from mindspore.nn import Cell
  19. from mindspore.nn import ReLU
  20. from mindspore.ops import operations as P
  21. from ...ut_filter import non_graph_engine
class arithmetic_Net(Cell):
    """Network whose ``construct`` traces arithmetic on Python ints in graph mode.

    ``symbol`` is an integer branch code: ``arithmetic_operator_base`` maps
    ``"++"``, ``"--"``, ``"+"``, ``"-"``, ``"*"``, ``"/"``, ``"%"`` to 1-7 and
    ``"not"`` to 8. Branches 1-7 each compute a value from ``loop_count`` and
    then loop over a 2-tuple, applying ReLU to ``x`` once per element; any
    other code falls through to the ``else`` branch.

    Args:
        symbol: branch code selecting which arithmetic form is traced.
        loop_count: pair ``(a, b)`` of Python-int operands used by the branches.
    """

    def __init__(self, symbol, loop_count=(1, 3)):
        super().__init__()
        self.symbol = symbol
        self.loop_count = loop_count
        self.relu = ReLU()

    def construct(self, x):
        # Returns x after ReLU has been applied twice (branches 1-7) or, in the
        # else branch, once only when ``a`` is falsy.
        a, b = self.loop_count
        y = self.symbol
        # NOTE: each ``for _ in (p, q)`` iterates the 2-tuple itself (not a
        # ``range``), so its body runs exactly twice whatever the values are;
        # presumably deliberate, to trace tuple iteration — confirm.
        if y == 1:
            # "++": augmented addition on a local.
            a += b
            for _ in (b, a):
                x = self.relu(x)
        elif y == 2:
            # "--": augmented subtraction on a local.
            b -= a
            for _ in (a, b):
                x = self.relu(x)
        elif y == 3:
            # "+"
            z = a + b
            for _ in (b, z):
                x = self.relu(x)
        elif y == 4:
            # "-"
            z = b - a
            for _ in (z, b):
                x = self.relu(x)
        elif y == 5:
            # "*"
            z = a * b
            for _ in (a, z):
                x = self.relu(x)
        elif y == 6:
            # "/": true division, so z is a float inside the tuple.
            z = b / a
            for _ in (a, z):
                x = self.relu(x)
        elif y == 7:
            # "%"
            z = b % a + 1
            for _ in (a, z):
                x = self.relu(x)
        else:
            # "not" (or any unmapped code): with the default loop_count a == 1,
            # so ``not a`` is False and x is returned unchanged.
            if not a:
                x = self.relu(x)
        return x
class logical_Net(Cell):
    """Network whose ``construct`` traces ``and`` / ``or`` on Python ints in graph mode.

    ``symbol`` 1 selects ``b and a``; any other value selects ``b or a``
    (``logical_operator_base`` maps ``"and"`` to 1 and ``"or"`` to 2). A truthy
    condition applies ReLU to ``x``; a falsy one applies ``P.Flatten``.

    Args:
        symbol: branch code.
        loop_count: pair ``(a, b)`` whose truthiness drives the condition.
    """

    def __init__(self, symbol, loop_count=(1, 3)):
        super().__init__()
        self.symbol = symbol
        self.loop_count = loop_count
        self.fla = P.Flatten()
        self.relu = ReLU()

    def construct(self, x):
        # With the default loop_count (1, 3) both operands are truthy, so either
        # branch takes the ReLU path; Flatten is only reached for falsy values.
        a, b = self.loop_count
        y = self.symbol
        if y == 1:
            # "and"
            if b and a:
                x = self.relu(x)
            else:
                x = self.fla(x)
        else:
            # "or" (any code other than 1)
            if b or a:
                x = self.relu(x)
            else:
                x = self.fla(x)
        return x
  86. def arithmetic_operator_base(symbol):
  87. """ arithmetic_operator_base """
  88. input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
  89. input_me = Tensor(input_np)
  90. logical_operator = {"++": 1, "--": 2, "+": 3, "-": 4, "*": 5, "/": 6, "%": 7, "not": 8}
  91. x = logical_operator[symbol]
  92. net = arithmetic_Net(x)
  93. context.set_context(mode=context.GRAPH_MODE)
  94. model = Model(net)
  95. model.predict(input_me)
  96. def logical_operator_base(symbol):
  97. """ logical_operator_base """
  98. input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
  99. input_me = Tensor(input_np)
  100. logical_operator = {"and": 1, "or": 2}
  101. x = logical_operator[symbol]
  102. net = logical_Net(x)
  103. context.set_context(mode=context.GRAPH_MODE)
  104. model = Model(net)
  105. model.predict(input_me)
  106. @non_graph_engine
  107. def test_ME_arithmetic_operator_0080():
  108. """ test_ME_arithmetic_operator_0080 """
  109. arithmetic_operator_base('not')
  110. @non_graph_engine
  111. def test_ME_arithmetic_operator_0070():
  112. """ test_ME_arithmetic_operator_0070 """
  113. logical_operator_base('and')
  114. @non_graph_engine
  115. def test_ME_logical_operator_0020():
  116. """ test_ME_logical_operator_0020 """
  117. logical_operator_base('or')
def test_ops():
    """Trace ``//``, ``**`` and ``%`` on tensors and on Python ints, plus constant conditions, in graph mode.

    The test only calls the network; it does not check the returned value.
    """
    class OpsNet(Cell):
        """Net mixing tensor arithmetic with arithmetic/comparisons on constant attributes.

        Args:
            x: Python int stored as ``self.x`` (9 in this test).
            y: Python int stored as ``self.y`` (2 in this test).
        """

        def __init__(self, x, y):
            super(OpsNet, self).__init__()
            self.x = x
            self.y = y
            # Constant attributes used only by the nested conditions in construct.
            self.int = 4
            self.float = 3.2
            self.str_a = "hello"
            self.str_b = "world"

        def construct(self, x, y):
            # Arithmetic on the tensor inputs.
            h = x // y
            m = x ** y
            n = x % y
            # Arithmetic on the Python-int attributes.
            r = self.x // self.y
            s = self.x ** self.y
            t = self.x % self.y
            p = h + m + n
            # For OpsNet(9, 2): q = 4 + 81 + 1 == 86.
            q = r + s + t
            ret_pow = p ** q + q ** p
            ret_mod = p % q + q % p
            ret_floor = p // q + q // p
            ret = ret_pow + ret_mod + ret_floor
            # For OpsNet(9, 2) every condition below holds (4 > 3.2, a list
            # literal is not None, "hello" + "world" == "helloworld", q == 86),
            # so ``ret`` is returned and the trailing ``return x`` is not reached.
            if self.int > self.float:
                if [1, 2, 3] is not None:
                    if self.str_a + self.str_b == "helloworld":
                        if q == 86:
                            print("hello world")
                        return ret
            return x

    net = OpsNet(9, 2)
    # NOTE(review): with x in [1, 10) and y in [10, 20), x ** y (and p ** q)
    # exceed the int32 range for many draws; the result is never checked, so
    # this is presumably acceptable — confirm.
    x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
    y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
    context.set_context(mode=context.GRAPH_MODE)
    net(x, y)
def test_in_dict():
    """Trace ``in`` / ``not in`` membership tests on a dict literal in graph mode.

    With keys "a" (present) and "c" (absent), both branches are taken, so the
    network computes x + z; the returned value is not checked.
    """
    class InDictNet(Cell):
        """Net that probes a dict built from its inputs with ``in`` and ``not in``.

        Args:
            key_in: key tested with ``in`` against ``{"a": x, "b": y}``.
            key_not_in: key tested with ``not in`` against the same dict.
        """

        def __init__(self, key_in, key_not_in):
            super(InDictNet, self).__init__()
            self.key_in = key_in
            self.key_not_in = key_not_in

        def construct(self, x, y, z):
            d = {"a": x, "b": y}
            # Fallback values, kept only if the corresponding membership test fails.
            ret_in = 1
            ret_not_in = 2
            if self.key_in in d:
                ret_in = d[self.key_in]
            if self.key_not_in not in d:
                ret_not_in = z
            ret = ret_in + ret_not_in
            return ret

    net = InDictNet("a", "c")
    x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
    y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
    z = Tensor(np.random.randint(low=20, high=30, size=(2, 3, 4), dtype=np.int32))
    context.set_context(mode=context.GRAPH_MODE)
    net(x, y, z)