You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_take.py 9.6 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.dataset as ds
  18. from mindspore import log as logger
  19. # In generator dataset: Number of rows is 3, its value is 0, 1, 2
  20. def generator():
  21. for i in range(3):
  22. yield (np.array([i]),)
  23. # In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 10
  24. def generator_10():
  25. for i in range(10):
  26. yield (np.array([i]),)
  27. def filter_func_ge(data):
  28. if data > 3:
  29. return False
  30. return True
  31. def test_take_01():
  32. """
  33. Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
  34. """
  35. logger.info("test_take_01")
  36. data1 = ds.GeneratorDataset(generator, ["data"])
  37. data1 = data1.take(1)
  38. data1 = data1.repeat(2)
  39. # Here i refers to index, d refers to data element
  40. for _, d in enumerate(data1):
  41. assert d[0].asnumpy()[0] == 0
  42. assert sum([1 for _ in data1]) == 2
  43. def test_take_02():
  44. """
  45. Test take: origin there are 3 row, and take 2 row, in this case: will meet eoe
  46. """
  47. logger.info("test_take_02")
  48. data1 = ds.GeneratorDataset(generator, ["data"])
  49. data1 = data1.take(2)
  50. data1 = data1.repeat(2)
  51. # Here i refers to index, d refers to data element
  52. for i, d in enumerate(data1):
  53. assert i % 2 == d[0].asnumpy()[0]
  54. assert sum([1 for _ in data1]) == 4
  55. def test_take_03():
  56. """
  57. Test take: origin there are 3 row, and take 3 row, in this case: will meet eoe and eof
  58. """
  59. logger.info("test_take_03")
  60. data1 = ds.GeneratorDataset(generator, ["data"])
  61. data1 = data1.take(3)
  62. data1 = data1.repeat(2)
  63. # Here i refers to index, d refers to data elements
  64. for i, d in enumerate(data1):
  65. assert i % 3 == d[0].asnumpy()[0]
  66. assert sum([1 for _ in data1]) == 6
  67. def test_take_04():
  68. """
  69. Test take: origin there are 3 row, and take 4 row, this is more than the total rows
  70. """
  71. logger.info("test_take_04")
  72. data1 = ds.GeneratorDataset(generator, ["data"])
  73. data1 = data1.take(4)
  74. data1 = data1.repeat(2)
  75. # Here i refers to index, d refers to data element
  76. for i, d in enumerate(data1):
  77. assert i % 3 == d[0].asnumpy()[0]
  78. assert sum([1 for _ in data1]) == 6
  79. def test_take_05():
  80. """
  81. Test take: there is no repeat op
  82. """
  83. logger.info("test_take_05")
  84. data1 = ds.GeneratorDataset(generator, ["data"])
  85. data1 = data1.take(2)
  86. # Here i refers to index, d refers to data element
  87. for i, d in enumerate(data1):
  88. assert i == d[0].asnumpy()[0]
  89. assert sum([1 for _ in data1]) == 2
  90. def test_take_06():
  91. """
  92. Test take: repeat is before take
  93. """
  94. logger.info("test_take_06")
  95. data1 = ds.GeneratorDataset(generator, ["data"])
  96. data1 = data1.repeat(2)
  97. data1 = data1.take(4)
  98. # Here i refers to index, d refers to data element
  99. for i, d in enumerate(data1):
  100. assert i % 3 == d[0].asnumpy()[0]
  101. assert sum([1 for _ in data1]) == 4
  102. def test_take_07():
  103. """
  104. Test take: take is before batch, that mean take(N), N refer to rows num
  105. """
  106. logger.info("test_take_07")
  107. data1 = ds.GeneratorDataset(generator, ["data"])
  108. data1 = data1.take(2)
  109. data1 = data1.batch(2)
  110. assert sum([1 for _ in data1]) == 1
  111. def test_take_08():
  112. """
  113. Test take: take is after batch, that mean take(N), N refer to batches num
  114. """
  115. logger.info("test_take_08")
  116. data1 = ds.GeneratorDataset(generator, ["data"])
  117. data1 = data1.batch(2)
  118. data1 = data1.take(2)
  119. assert sum([1 for _ in data1]) == 2
  120. def test_take_09():
  121. """
  122. Test take: take count is -1, and read the whole dataset, take after repeat
  123. """
  124. logger.info("test_take_09")
  125. data1 = ds.GeneratorDataset(generator, ["data"])
  126. data1 = data1.repeat(2)
  127. data1 = data1.take(-1)
  128. # Here i refers to index, d refers to data element
  129. for i, d in enumerate(data1):
  130. assert i % 3 == d[0].asnumpy()[0]
  131. assert sum([1 for _ in data1]) == 6
  132. def test_take_10():
  133. """
  134. Test take: take count is -1, and read the whole dataset, take before repeat
  135. """
  136. logger.info("test_take_10")
  137. data1 = ds.GeneratorDataset(generator, ["data"])
  138. data1 = data1.take(-1)
  139. data1 = data1.repeat(2)
  140. # Here i refers to index, d refers to data element
  141. for i, d in enumerate(data1):
  142. assert i % 3 == d[0].asnumpy()[0]
  143. assert sum([1 for _ in data1]) == 6
  144. def test_take_11():
  145. """
  146. Test take: batch first, then do repeat and take operation
  147. """
  148. logger.info("test_take_11")
  149. data1 = ds.GeneratorDataset(generator, ["data"])
  150. data1 = data1.batch(2)
  151. data1 = data1.repeat(2)
  152. data1 = data1.take(-1)
  153. # Here i refers to index, d refers to data element
  154. for i, d in enumerate(data1):
  155. assert 2 * (i % 2) == d[0].asnumpy()[0]
  156. assert sum([1 for _ in data1]) == 4
  157. def test_take_12():
  158. """
  159. Test take: take first, then do batch and repeat operation
  160. """
  161. logger.info("test_take_12")
  162. data1 = ds.GeneratorDataset(generator, ["data"])
  163. data1 = data1.take(2)
  164. data1 = data1.batch(2)
  165. data1 = data1.repeat(2)
  166. # Here i refers to index, d refers to data element
  167. for _, d in enumerate(data1):
  168. assert d[0].asnumpy()[0] == 0
  169. assert sum([1 for _ in data1]) == 2
  170. def test_take_13():
  171. """
  172. Test take: skip first, then do take, batch and repeat operation
  173. """
  174. logger.info("test_take_13")
  175. data1 = ds.GeneratorDataset(generator, ["data"])
  176. data1 = data1.skip(2)
  177. data1 = data1.take(-1)
  178. data1 = data1.batch(2)
  179. data1 = data1.repeat(2)
  180. # Here i refers to index, d refers to data element
  181. for _, d in enumerate(data1):
  182. assert d[0].asnumpy()[0] == 2
  183. assert sum([1 for _ in data1]) == 2
  184. def test_take_14():
  185. """
  186. Test take: take first, then do batch, skip and repeat operation
  187. """
  188. logger.info("test_take_14")
  189. data1 = ds.GeneratorDataset(generator, ["data"])
  190. data1 = data1.take(-1)
  191. data1 = data1.batch(2)
  192. data1 = data1.skip(1)
  193. data1 = data1.repeat(2)
  194. # Here i refers to index, d refers to data element
  195. for _, d in enumerate(data1):
  196. assert d[0].asnumpy()[0] == 2
  197. assert sum([1 for _ in data1]) == 2
  198. def test_take_15():
  199. """
  200. Test take: large amount data, take a part, then do skip operation
  201. """
  202. logger.info("test_take_15")
  203. data1 = ds.GeneratorDataset(generator_10, ["data"])
  204. data1 = data1.take(6)
  205. data1 = data1.skip(2)
  206. # Here i refers to index, d refers to data element
  207. for i, d in enumerate(data1):
  208. assert (i + 2) == d[0].asnumpy()[0]
  209. assert sum([1 for _ in data1]) == 4
  210. def test_take_16():
  211. """
  212. Test take: large amount data, skip a part, then do take operation
  213. """
  214. logger.info("test_take_16")
  215. data1 = ds.GeneratorDataset(generator_10, ["data"])
  216. data1 = data1.skip(3)
  217. data1 = data1.take(5)
  218. # Here i refers to index, d refers to data element
  219. for i, d in enumerate(data1):
  220. assert (i + 3) == d[0].asnumpy()[0]
  221. assert sum([1 for _ in data1]) == 5
  222. def test_take_17():
  223. """
  224. Test take: take first, then do fiter operation
  225. """
  226. logger.info("test_take_17")
  227. data1 = ds.GeneratorDataset(generator_10, ["data"])
  228. data1 = data1.take(8)
  229. data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
  230. # Here i refers to index, d refers to data element
  231. for i, d in enumerate(data1):
  232. assert i == d[0].asnumpy()[0]
  233. assert sum([1 for _ in data1]) == 4
  234. def test_take_18():
  235. """
  236. Test take: take first, then do fiter, skip, batch and repeat operation
  237. """
  238. logger.info("test_take_18")
  239. data1 = ds.GeneratorDataset(generator_10, ["data"])
  240. data1 = data1.take(8)
  241. data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
  242. data1 = data1.skip(2)
  243. data1 = data1.batch(2)
  244. data1 = data1.repeat(2)
  245. # Here i refers to index, d refers to data element
  246. for _, d in enumerate(data1):
  247. assert d[0].asnumpy()[0] == 2
  248. assert sum([1 for _ in data1]) == 2
  249. def test_take_19():
  250. """
  251. Test take: take is after batch, that mean take(N), N refer to batches num
  252. """
  253. logger.info("test_take_19")
  254. with pytest.raises(ValueError) as info:
  255. data1 = ds.GeneratorDataset(generator, ["data"])
  256. data1 = data1.batch(2)
  257. data1 = data1.take(0)
  258. assert "positive integer" in str(info.value)
  259. if __name__ == '__main__':
  260. test_take_01()
  261. test_take_02()
  262. test_take_03()
  263. test_take_04()
  264. test_take_05()
  265. test_take_06()
  266. test_take_07()
  267. test_take_08()
  268. test_take_09()
  269. test_take_10()
  270. test_take_11()
  271. test_take_12()
  272. test_take_13()
  273. test_take_14()
  274. test_take_15()
  275. test_take_16()
  276. test_take_17()
  277. test_take_18()
  278. test_take_19()
  279. logger.info('== test take operation finished ==')