You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_list.py 6.1 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import functools
  16. import numpy as np
  17. import mindspore.nn as nn
  18. from mindspore import Tensor
  19. from mindspore.ops import operations as P
  20. from ..ut_filter import non_graph_engine
  21. from ....mindspore_test_framework.mindspore_test import mindspore_test
  22. from ....mindspore_test_framework.pipeline.forward.compile_forward \
  23. import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
  24. def test_list_equal():
  25. class Net(nn.Cell):
  26. def __init__(self, z: list):
  27. super(Net, self).__init__()
  28. self.z = z
  29. def construct(self, x, y):
  30. if self.z == [1, 2, 3]:
  31. ret = x
  32. else:
  33. ret = y
  34. return ret
  35. x = Tensor(np.ones([6, 8, 10], np.int32))
  36. y = Tensor(np.zeros([3, 4, 5], np.int32))
  37. z = [1, 2, 3]
  38. net = Net(z)
  39. assert net(x, y) == x
  40. def test_list_not_equal():
  41. class Net(nn.Cell):
  42. def __init__(self, z: list):
  43. super(Net, self).__init__()
  44. self.z = z
  45. def construct(self, x, y):
  46. if self.z == [3, 4, 5]:
  47. ret = x
  48. else:
  49. ret = y
  50. return ret
  51. x = Tensor(np.ones([6, 8, 10], np.int32))
  52. y = Tensor(np.zeros([3, 4, 5], np.int32))
  53. z = [1, 2, 3]
  54. net = Net(z)
  55. assert net(x, y) == y
  56. def test_list_expansion():
  57. class Net(nn.Cell):
  58. def __init__(self, z: list):
  59. super(Net, self).__init__()
  60. self.z = z
  61. def construct(self, x, y):
  62. a, b, c = self.z
  63. if a == 1 and b == 2 and c == 3:
  64. ret = x
  65. else:
  66. ret = y
  67. return ret
  68. x = Tensor(np.ones([6, 8, 10], np.int32))
  69. y = Tensor(np.zeros([3, 4, 5], np.int32))
  70. z = [1, 2, 3]
  71. net = Net(z)
  72. assert net(x, y) == x
  73. def test_list_append():
  74. class Net(nn.Cell):
  75. def __init__(self, z: list):
  76. super(Net, self).__init__()
  77. self.z = z
  78. def construct(self, x, y):
  79. z = [[1, 2], 3]
  80. z[0].append(88)
  81. z[0].append(99)
  82. if z[0][3] == 99:
  83. ret = y
  84. else:
  85. ret = x
  86. return ret
  87. x = Tensor(np.ones([6, 8, 10], np.int32))
  88. y = Tensor(np.zeros([3, 4, 5], np.int32))
  89. z = [1, 2, 3]
  90. net = Net(z)
  91. assert net(x, y) == y
  92. def test_list_append_2():
  93. class Net(nn.Cell):
  94. def __init__(self, z: list):
  95. super(Net, self).__init__()
  96. self.z = z
  97. self.x = 9
  98. def construct(self, x, y):
  99. self.z[0].append(88)
  100. self.z[0].append(99)
  101. if self.z[0][3] == 88:
  102. ret = y
  103. else:
  104. ret = x
  105. return ret
  106. x = Tensor(np.ones([6, 8, 10], np.int32))
  107. y = Tensor(np.zeros([3, 4, 5], np.int32))
  108. z = [[1, 2], 3]
  109. net = Net(z)
  110. assert net(x, y) == x
class ListOperate(nn.Cell):
    """Cell that builds and mutates a Python list inside ``construct``.

    Exercises list item assignment, arithmetic on elements read back by
    index, and ``append`` of both constants and the cell's inputs.
    """

    def __init__(self, ):
        super(ListOperate, self).__init__()

    def construct(self, t, l):
        # Returns an 11-element list: the six entries of the mutated literal
        # list, two 8s, then t once and l twice.
        # In test_case_ops, t is a Tensor and l is the Python list [2, 3, 4].
        x = [1, 2, 3, 4, 5, 6]
        # Each assignment below reads entries written by earlier ones, so the
        # statement order is part of what is being exercised.
        x[2] = 9
        x[1] = x[3] + 11
        x[3] = x[1] + x[0]
        x[0] = x[2] * x[4]
        x[5] = x[1] - x[2]
        x[4] = x[3] / x[2]
        x.append(8)
        x.append(8)
        x.append(t)
        x.append(l)
        x.append(l)
        return x
class AxisListNet(nn.Cell):
    """Apply four reduce primitives using a list-valued axis attribute.

    The four reductions are combined with AddN, and ``ret_sum`` is then
    added once more to that total.
    """

    def __init__(self):
        super(AxisListNet, self).__init__()
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reduce_max = P.ReduceMax()
        self.reduce_min = P.ReduceMin()
        self.add_n = P.AddN()
        # Axes are passed as a Python list (not a tuple) to every reduction.
        self.axis = [0, 1, 2]

    def construct(self, x):
        # Each reduction reads the same list attribute as its axis argument.
        # The inputs in test_case_ops are 3-D ([6, 8, 10]), so these axes
        # cover every dimension.
        ret_sum = self.reduce_sum(x, self.axis)
        ret_mean = self.reduce_mean(x, self.axis)
        ret_max = self.reduce_max(x, self.axis)
        ret_min = self.reduce_min(x, self.axis)
        # AddN receives the results as a list; ret_sum is counted twice overall.
        ret = [ret_sum, ret_mean, ret_max, ret_min]
        return self.add_n(ret) + ret_sum
class AxisListEmptyNet(nn.Cell):
    """Call ReduceSum with an empty list as the axis argument."""

    def __init__(self):
        super(AxisListEmptyNet, self).__init__()
        self.reduce_sum = P.ReduceSum()
        # Empty axis list. NOTE(review): presumably ReduceSum treats this as
        # "reduce over all dimensions"; confirm against the ReduceSum docs.
        self.axis = []

    def construct(self, x):
        return self.reduce_sum(x, self.axis)
class AxisListDefaultNet(nn.Cell):
    """Call ReduceSum without an axis argument, relying on its default axis."""

    def __init__(self):
        super(AxisListDefaultNet, self).__init__()
        self.reduce_sum = P.ReduceSum()

    def construct(self, x):
        # No axis passed: the primitive's own default axis applies.
        return self.reduce_sum(x)
  157. test_case_ops = [
  158. ('ListOperate', {
  159. 'block': ListOperate(),
  160. 'desc_inputs': [Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32)),
  161. [2, 3, 4]]}),
  162. ('AxisList', {
  163. 'block': AxisListNet(),
  164. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
  165. ('AxisListEmpty', {
  166. 'block': AxisListEmptyNet(),
  167. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
  168. ('AxisListDefault', {
  169. 'block': AxisListDefaultNet(),
  170. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
  171. ]
  172. test_case_lists = [test_case_ops]
  173. test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# Use -k to select certain test cases by name, e.g.:
# pytest test_list.py::test_exec -k AxisList
import mindspore.context as context


@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    """Provide test_exec_case to the decorator's forward-compile pipeline.

    Switches the MindSpore context to GRAPH_MODE (it is not reset in this
    function) and returns the case list. NOTE(review): the mindspore_test
    decorator presumably runs each returned case through the configured
    pipeline; confirm in mindspore_test_framework.
    """
    context.set_context(mode=context.GRAPH_MODE)
    return test_exec_case