You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_list.py 6.1 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import functools
  16. import numpy as np
  17. import mindspore.nn as nn
  18. import mindspore.context as context
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. from ..ut_filter import non_graph_engine
  22. from ....mindspore_test_framework.mindspore_test import mindspore_test
  23. from ....mindspore_test_framework.pipeline.forward.compile_forward \
  24. import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
  def test_list_equal():
      """A list attribute equal to the literal [1, 2, 3] selects the first input.

      ``Net.construct`` compares the list stored in ``self.z`` with a list
      literal and returns ``x`` when they are equal, ``y`` otherwise.
      """
      class Net(nn.Cell):
          def __init__(self, z: list):
              super(Net, self).__init__()
              self.z = z

          def construct(self, x, y):
              if self.z == [1, 2, 3]:
                  ret = x
              else:
                  ret = y
              return ret

      x = Tensor(np.ones([6, 8, 10], np.int32))
      y = Tensor(np.zeros([3, 4, 5], np.int32))
      z = [1, 2, 3]
      net = Net(z)
      # `Tensor == Tensor` is not a plain bool (depending on the MindSpore
      # version it is an element-wise result), so asserting on it does not
      # reliably check anything. x and y differ in both shape and values, so
      # comparing the raw data tells unambiguously which input was returned.
      assert np.array_equal(net(x, y).asnumpy(), x.asnumpy())
  def test_list_not_equal():
      """A list attribute that differs from the literal [3, 4, 5] selects the second input.

      ``self.z`` is ``[1, 2, 3]``, so the equality test in ``construct``
      fails and ``y`` is returned.
      """
      class Net(nn.Cell):
          def __init__(self, z: list):
              super(Net, self).__init__()
              self.z = z

          def construct(self, x, y):
              if self.z == [3, 4, 5]:
                  ret = x
              else:
                  ret = y
              return ret

      x = Tensor(np.ones([6, 8, 10], np.int32))
      y = Tensor(np.zeros([3, 4, 5], np.int32))
      z = [1, 2, 3]
      net = Net(z)
      # Compare the underlying data rather than asserting on a Tensor `==`
      # result, which is not a reliable boolean. array_equal also returns
      # False (instead of raising) if the wrong, differently-shaped input
      # comes back.
      assert np.array_equal(net(x, y).asnumpy(), y.asnumpy())
  def test_list_expansion():
      """Unpacking a three-element list attribute yields its elements in order.

      ``construct`` unpacks ``self.z`` into ``a, b, c`` and returns ``x``
      only when they are 1, 2 and 3 respectively.
      """
      class Net(nn.Cell):
          def __init__(self, z: list):
              super(Net, self).__init__()
              self.z = z

          def construct(self, x, y):
              a, b, c = self.z
              if a == 1 and b == 2 and c == 3:
                  ret = x
              else:
                  ret = y
              return ret

      x = Tensor(np.ones([6, 8, 10], np.int32))
      y = Tensor(np.zeros([3, 4, 5], np.int32))
      z = [1, 2, 3]
      net = Net(z)
      # Compare the underlying data; asserting on a Tensor `==` result is not
      # a reliable boolean check.
      assert np.array_equal(net(x, y).asnumpy(), x.asnumpy())
  def test_list_append():
      """Appending to a nested list built inside ``construct`` is visible by index.

      ``construct`` builds ``[[1, 2], 3]`` locally and appends 88 then 99 to
      the inner list, so its element at index 3 is 99 and ``y`` is returned.
      The ``z`` constructor argument is stored but not read by ``construct``.
      """
      class Net(nn.Cell):
          def __init__(self, z: list):
              super(Net, self).__init__()
              self.z = z

          def construct(self, x, y):
              z = [[1, 2], 3]
              z[0].append(88)
              z[0].append(99)
              if z[0][3] == 99:
                  ret = y
              else:
                  ret = x
              return ret

      x = Tensor(np.ones([6, 8, 10], np.int32))
      y = Tensor(np.zeros([3, 4, 5], np.int32))
      z = [1, 2, 3]
      net = Net(z)
      # Compare the underlying data; asserting on a Tensor `==` result is not
      # a reliable boolean check.
      assert np.array_equal(net(x, y).asnumpy(), y.asnumpy())
  def test_list_append_2():
      """Appending to a nested list held in a Cell attribute is visible by index.

      ``construct`` appends 88 then 99 to ``self.z[0]`` (initially ``[1, 2]``),
      so index 3 holds 99, not 88, and ``x`` is returned.
      """
      class Net(nn.Cell):
          def __init__(self, z: list):
              super(Net, self).__init__()
              self.z = z
              self.x = 9

          def construct(self, x, y):
              # NOTE(review): under plain Python list semantics this mutates
              # the caller's list in place, so repeated calls keep growing
              # self.z[0]; whether graph mode does the same is not shown here.
              self.z[0].append(88)
              self.z[0].append(99)
              if self.z[0][3] == 88:
                  ret = y
              else:
                  ret = x
              return ret

      x = Tensor(np.ones([6, 8, 10], np.int32))
      y = Tensor(np.zeros([3, 4, 5], np.int32))
      z = [[1, 2], 3]
      net = Net(z)
      # Compare the underlying data; asserting on a Tensor `==` result is not
      # a reliable boolean check.
      assert np.array_equal(net(x, y).asnumpy(), x.asnumpy())
  class ListOperate(nn.Cell):
      """Exercise list indexing, item assignment and ``append`` inside ``construct``.

      ``construct`` builds a local six-element list, rewrites several
      elements using values read back from the same list, appends two
      constants and the inputs, and returns the resulting list.
      """
      def __init__(self,):
          super(ListOperate, self).__init__()
      def construct(self, t, l):
          # In test_case_ops, `t` is a float32 Tensor and `l` is the list
          # [2, 3, 4].
          x = [1, 2, 3, 4, 5, 6]
          # Each assignment below reads elements written by earlier ones, so
          # the statement order is significant.
          x[2] = 9
          x[1] = x[3] + 11
          x[3] = x[1] + x[0]
          x[0] = x[2] * x[4]
          x[5] = x[1] - x[2]
          # NOTE(review): int / int; whether graph mode produces a float here
          # (as plain Python 3 would) is not shown in this file; confirm.
          x[4] = x[3] / x[2]
          x.append(8)
          x.append(8)
          # The inputs themselves become list elements; `l` is appended twice.
          x.append(t)
          x.append(l)
          x.append(l)
          return x
  class AxisListNet(nn.Cell):
      """Run sum/mean/max/min reductions over a list-valued axis and combine them.

      The axis is kept as the Python list ``[0, 1, 2]`` (not a tuple), so each
      reduce primitive receives a list argument. The four reduced results are
      added together by ``AddN``, and the sum reduction is then added once more.
      """

      def __init__(self):
          super(AxisListNet, self).__init__()
          self.reduce_sum = P.ReduceSum()
          self.reduce_mean = P.ReduceMean()
          self.reduce_max = P.ReduceMax()
          self.reduce_min = P.ReduceMin()
          self.add_n = P.AddN()
          # Deliberately a list: this cell tests list-valued axis arguments.
          self.axis = [0, 1, 2]

      def construct(self, x):
          total = self.reduce_sum(x, self.axis)
          average = self.reduce_mean(x, self.axis)
          largest = self.reduce_max(x, self.axis)
          smallest = self.reduce_min(x, self.axis)
          # AddN takes its operands as a Python list.
          return self.add_n([total, average, largest, smallest]) + total
  class AxisListEmptyNet(nn.Cell):
      """Call ``ReduceSum`` with an explicitly empty list as the axis argument."""

      def __init__(self):
          super(AxisListEmptyNet, self).__init__()
          self.reduce_sum = P.ReduceSum()
          # Empty list axis; presumably treated like the default "reduce all
          # dimensions" by ReduceSum. Confirm against the operator docs.
          self.axis = []

      def construct(self, x):
          summed = self.reduce_sum(x, self.axis)
          return summed
  class AxisListDefaultNet(nn.Cell):
      """Call ``ReduceSum`` without an axis argument, relying on its default."""

      def __init__(self):
          super(AxisListDefaultNet, self).__init__()
          self.reduce_sum = P.ReduceSum()

      def construct(self, x):
          summed = self.reduce_sum(x)
          return summed
  # Forward-compile cases. Each entry is a (case name, config) pair, where the
  # config holds the cell under test ('block') and its inputs ('desc_inputs').
  test_case_ops = [
      ('ListOperate', {
          'block': ListOperate(),
          'desc_inputs': [
              Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32)),
              [2, 3, 4],
          ]}),
      ('AxisList', {
          'block': AxisListNet(),
          'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
      ('AxisListEmpty', {
          'block': AxisListEmptyNet(),
          'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
      ('AxisListDefault', {
          'block': AxisListDefaultNet(),
          'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
  ]

  test_case_lists = [test_case_ops]
  # Concatenate all case lists into the single list returned by test_exec.
  test_exec_case = functools.reduce(lambda merged, cases: merged + cases, test_case_lists)
  # Use -k to select individual test cases by name, for example:
  #   pytest <path/to>/test_list.py::test_exec -k AxisList
  @non_graph_engine
  @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
  def test_exec():
      """Switch to graph mode and return ``test_exec_case`` for the test pipeline.

      The returned list is handed to the ``mindspore_test`` decorator; it
      presumably compiles each (name, config) case forward. Confirm against
      the test framework.
      """
      context.set_context(mode=context.GRAPH_MODE)
      return test_exec_case