You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_operator.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_operator """
  16. import numpy as np
  17. from mindspore.ops import operations as P
  18. from mindspore.nn import ReLU
  19. from mindspore.nn import Cell
  20. from mindspore import Tensor, Model, context
  21. from ...ut_filter import non_graph_engine
class arithmetic_Net(Cell):
    """ arithmetic_Net definition

    Test network whose ``construct`` exercises one Python arithmetic operator
    (or ``not``) on the pair ``loop_count`` and applies ReLU to the input
    inside loops whose headers mention the computed value.

    Args:
        symbol: Branch code read by ``construct``: 1 -> ``a += b``,
            2 -> ``b -= a``, 3 -> ``a + b``, 4 -> ``b - a``, 5 -> ``a * b``,
            6 -> ``b / a``, 7 -> ``b % a + 1``; any other value exercises
            ``not a``.
        loop_count: Pair ``(a, b)`` of operands. Defaults to ``(1, 3)``.
    """
    def __init__(self, symbol, loop_count=(1, 3)):
        super().__init__()
        self.symbol = symbol
        self.loop_count = loop_count
        self.relu = ReLU()

    def construct(self, x):
        # Returns x after the ReLU applications of the branch selected by
        # self.symbol (comments only: construct is parsed by the graph
        # compiler, so the statements below must stay exactly as written).
        a, b = self.loop_count
        y = self.symbol
        # NOTE(review): each ``for _ in (p, q)`` below iterates over a
        # two-element tuple literal, not ``range(p, q)``, so every loop body
        # runs exactly twice whatever the operand values are. Presumably the
        # goal is just to use the computed value inside the graph; confirm.
        if y == 1:
            a += b
            for _ in (b, a):
                x = self.relu(x)
        elif y == 2:
            b -= a
            for _ in (a, b):
                x = self.relu(x)
        elif y == 3:
            z = a + b
            for _ in (b, z):
                x = self.relu(x)
        elif y == 4:
            z = b - a
            for _ in (z, b):
                x = self.relu(x)
        elif y == 5:
            z = a * b
            for _ in (a, z):
                x = self.relu(x)
        elif y == 6:
            # Under Python semantics this is true division, so z is a float
            # (3.0 with the default loop_count).
            z = b / a
            for _ in (a, z):
                x = self.relu(x)
        elif y == 7:
            z = b % a + 1
            for _ in (a, z):
                x = self.relu(x)
        else:
            # Exercises ``not``. With the default loop_count, a == 1, so the
            # condition is False and x is returned unchanged.
            if not a:
                x = self.relu(x)
        return x
class logical_Net(Cell):
    """ logical_Net definition

    Test network whose ``construct`` exercises a Python boolean operator
    (``and`` or ``or``) on the pair ``loop_count`` and picks between ReLU and
    Flatten depending on the result.

    Args:
        symbol: Branch code read by ``construct``: 1 -> ``b and a``; any other
            value -> ``b or a``.
        loop_count: Pair ``(a, b)`` of operands. Defaults to ``(1, 3)``.
    """
    def __init__(self, symbol, loop_count=(1, 3)):
        super().__init__()
        self.symbol = symbol
        self.loop_count = loop_count
        self.fla = P.Flatten()
        self.relu = ReLU()

    def construct(self, x):
        # Returns ReLU(x) when the selected boolean expression is truthy and
        # Flatten(x) otherwise (comments only: construct is parsed by the
        # graph compiler, so the statements below must stay as written).
        a, b = self.loop_count
        y = self.symbol
        if y == 1:
            # With the default loop_count (1, 3) both operands are truthy,
            # so the ReLU branch is taken.
            if b and a:
                x = self.relu(x)
            else:
                x = self.fla(x)
        else:
            if b or a:
                x = self.relu(x)
            else:
                x = self.fla(x)
        return x
  86. def arithmetic_operator_base(symbol):
  87. """ arithmetic_operator_base """
  88. input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
  89. input_me = Tensor(input_np)
  90. logical_operator = {"++": 1, "--": 2, "+": 3, "-": 4, "*": 5, "/": 6, "%": 7, "not": 8}
  91. x = logical_operator[symbol]
  92. net = arithmetic_Net(x)
  93. context.set_context(mode=context.GRAPH_MODE)
  94. model = Model(net)
  95. model.predict(input_me)
  96. def logical_operator_base(symbol):
  97. """ logical_operator_base """
  98. input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
  99. input_me = Tensor(input_np)
  100. logical_operator = {"and": 1, "or": 2}
  101. x = logical_operator[symbol]
  102. net = logical_Net(x)
  103. context.set_context(mode=context.GRAPH_MODE)
  104. model = Model(net)
  105. model.predict(input_me)
  106. @non_graph_engine
  107. def test_ME_arithmetic_operator_0080():
  108. """ test_ME_arithmetic_operator_0080 """
  109. arithmetic_operator_base('not')
  110. @non_graph_engine
  111. def test_ME_arithmetic_operator_0070():
  112. """ test_ME_arithmetic_operator_0070 """
  113. logical_operator_base('and')
  114. @non_graph_engine
  115. def test_ME_logical_operator_0020():
  116. """ test_ME_logical_operator_0020 """
  117. logical_operator_base('or')