You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_enumerate.py 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test enumerate"""
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore import context
  21. context.set_context(mode=context.GRAPH_MODE)
  22. def test_enumerate_list_const():
  23. class Net(nn.Cell):
  24. def __init__(self):
  25. super(Net, self).__init__()
  26. self.value = [11, 22, 33, 44]
  27. def construct(self):
  28. index_sum = 0
  29. value_sum = 0
  30. for i, j in enumerate(self.value):
  31. index_sum += i
  32. value_sum += j
  33. return index_sum, value_sum
  34. net = Net()
  35. assert net() == (6, 110)
  36. def test_enumerate_tuple_const():
  37. class Net(nn.Cell):
  38. def __init__(self):
  39. super(Net, self).__init__()
  40. self.value = (11, 22, 33, 44)
  41. def construct(self):
  42. index_sum = 0
  43. value_sum = 0
  44. for i, j in enumerate(self.value):
  45. index_sum += i
  46. value_sum += j
  47. return index_sum, value_sum
  48. net = Net()
  49. assert net() == (6, 110)
  50. def test_enumerate_tensor_const():
  51. class Net(nn.Cell):
  52. def __init__(self):
  53. super(Net, self).__init__()
  54. self.value = Tensor(np.arange(2 * 3).reshape(2, 3))
  55. def construct(self):
  56. return enumerate(self.value)
  57. net = Net()
  58. net()
  59. def test_enumerate_list_parameter():
  60. class Net(nn.Cell):
  61. def __init__(self):
  62. super(Net, self).__init__()
  63. def construct(self, x, y):
  64. index_sum = 0
  65. value = [x, y]
  66. ret = ()
  67. for i, j in enumerate(value):
  68. index_sum += i
  69. ret += (j,)
  70. return index_sum, ret
  71. x = Tensor(np.arange(4))
  72. net = Net()
  73. net(x, x)
  74. def test_enumerate_tuple_parameter():
  75. class Net(nn.Cell):
  76. def __init__(self):
  77. super(Net, self).__init__()
  78. def construct(self, x, y):
  79. index_sum = 0
  80. value = (x, y)
  81. ret = ()
  82. for i, j in enumerate(value):
  83. index_sum += i
  84. ret += (j,)
  85. return index_sum, ret
  86. x = Tensor(np.arange(4))
  87. net = Net()
  88. net(x, x)
  89. def test_enumerate_tensor_parameter():
  90. class Net(nn.Cell):
  91. def __init__(self):
  92. super(Net, self).__init__()
  93. def construct(self, x):
  94. index_sum = 0
  95. ret = ()
  96. for i, j in enumerate(x):
  97. index_sum += i
  98. ret += (j,)
  99. return index_sum, ret
  100. x = Tensor(np.arange(2 * 3).reshape(2, 3))
  101. net = Net()
  102. net(x)
  103. def test_enumerate_tuple_const_1():
  104. class Net(nn.Cell):
  105. def __init__(self):
  106. super(Net, self).__init__()
  107. self.value = (11, 22, 33, 44)
  108. def construct(self):
  109. index_sum = 0
  110. value_sum = 0
  111. for i in enumerate(self.value):
  112. index_sum += i[0]
  113. value_sum += i[1]
  114. return index_sum, value_sum
  115. net = Net()
  116. assert net() == (6, 110)
  117. def test_enumerate_tensor_const_1():
  118. class Net(nn.Cell):
  119. def __init__(self):
  120. super(Net, self).__init__()
  121. self.value = Tensor(np.arange(2*3).reshape(2, 3))
  122. def construct(self):
  123. index_sum = 0
  124. ret = ()
  125. for i in enumerate(self.value):
  126. index_sum += i[0]
  127. ret += (i[1],)
  128. return index_sum, ret
  129. net = Net()
  130. net()
  131. def test_enumerate_tuple_parameter_1():
  132. class Net(nn.Cell):
  133. def __init__(self):
  134. super(Net, self).__init__()
  135. def construct(self, x, y):
  136. index_sum = 0
  137. value = (x, y)
  138. ret = ()
  139. for i in enumerate(value):
  140. index_sum += i[0]
  141. ret += (i[1],)
  142. return index_sum, ret
  143. x = Tensor(np.arange(4))
  144. net = Net()
  145. net(x, x)
  146. def test_enumerate_tensor_parameter_1():
  147. class Net(nn.Cell):
  148. def __init__(self):
  149. super(Net, self).__init__()
  150. def construct(self, x):
  151. index_sum = 0
  152. ret = ()
  153. for i in enumerate(x):
  154. index_sum += i[0]
  155. ret += (i[1],)
  156. return index_sum, ret
  157. x = Tensor(np.arange(2 * 3).reshape(2, 3))
  158. net = Net()
  159. net(x)
  160. def test_enumerate_tuple_const_2():
  161. class Net(nn.Cell):
  162. def __init__(self):
  163. super(Net, self).__init__()
  164. self.value = (11, 22, 33, 44)
  165. def construct(self):
  166. index_sum = 0
  167. value_sum = 0
  168. for i in enumerate(self.value, 1):
  169. index_sum += i[0]
  170. value_sum += i[1]
  171. return index_sum, value_sum
  172. net = Net()
  173. assert net() == (10, 110)
  174. def test_enumerate_tensor_const_2():
  175. class Net(nn.Cell):
  176. def __init__(self):
  177. super(Net, self).__init__()
  178. self.value = Tensor(np.arange(2 * 3).reshape(2, 3))
  179. def construct(self):
  180. index_sum = 0
  181. ret = ()
  182. for i in enumerate(self.value, 1):
  183. index_sum += i[0]
  184. ret += (i[1],)
  185. return index_sum, ret
  186. net = Net()
  187. net()
  188. def test_enumerate_tuple_parameter_2():
  189. class Net(nn.Cell):
  190. def __init__(self):
  191. super(Net, self).__init__()
  192. def construct(self, x, y):
  193. index_sum = 0
  194. value = (x, y)
  195. ret = ()
  196. for i in enumerate(value, 1):
  197. index_sum += i[0]
  198. ret += (i[1],)
  199. return index_sum, ret
  200. x = Tensor(np.arange(4))
  201. net = Net()
  202. net(x, x)
  203. def test_enumerate_tensor_parameter_2():
  204. class Net(nn.Cell):
  205. def __init__(self):
  206. super(Net, self).__init__()
  207. def construct(self, x):
  208. index_sum = 0
  209. ret = ()
  210. for i, j in enumerate(x, 1):
  211. index_sum += i
  212. ret += (j,)
  213. return index_sum, ret
  214. x = Tensor(np.arange(2 * 3).reshape(2, 3))
  215. net = Net()
  216. net(x)
  217. def test_enumerate_start_type_error():
  218. class Net(nn.Cell):
  219. def __init__(self):
  220. super(Net, self).__init__()
  221. def construct(self, x):
  222. return enumerate((x, x), start=1.2)
  223. x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
  224. net = Net()
  225. with pytest.raises(TypeError) as ex:
  226. net(x)
  227. assert "For 'enumerate', the 'start'" in str(ex.value)