You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_take.py 8.0 kB


  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import mindspore.dataset as ds
  16. import mindspore.dataset.transforms.vision.c_transforms as vision
  17. from mindspore import log as logger
  18. import numpy as np
  19. # In generator dataset: Number of rows is 3, its value is 0, 1, 2
  20. def generator():
  21. for i in range(3):
  22. yield np.array([i]),
  23. # In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 10
  24. def generator_10():
  25. for i in range(10):
  26. yield np.array([i]),
  27. def test_take_01():
  28. """
  29. Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
  30. """
  31. logger.info("test_take_01")
  32. data1 = ds.GeneratorDataset(generator, ["data"])
  33. data1 = data1.take(1)
  34. data1 = data1.repeat(2)
  35. # Here i refers to index, d refers to data element
  36. for i, d in enumerate(data1):
  37. assert 0 == d[0][0]
  38. assert sum([1 for _ in data1]) == 2
  39. def test_take_02():
  40. """
  41. Test take: origin there are 3 row, and take 2 row, in this case: will meet eoe
  42. """
  43. logger.info("test_take_02")
  44. data1 = ds.GeneratorDataset(generator, ["data"])
  45. data1 = data1.take(2)
  46. data1 = data1.repeat(2)
  47. # Here i refers to index, d refers to data element
  48. for i, d in enumerate(data1):
  49. assert i % 2 == d[0][0]
  50. assert sum([1 for _ in data1]) == 4
  51. def test_take_03():
  52. """
  53. Test take: origin there are 3 row, and take 3 row, in this case: will meet eoe and eof
  54. """
  55. logger.info("test_take_03")
  56. data1 = ds.GeneratorDataset(generator, ["data"])
  57. data1 = data1.take(3)
  58. data1 = data1.repeat(2)
  59. # Here i refers to index, d refers to data element
  60. for i, d in enumerate(data1):
  61. assert i % 3 == d[0][0]
  62. assert sum([1 for _ in data1]) == 6
  63. def test_take_04():
  64. """
  65. Test take: origin there are 3 row, and take 4 row, this is more than the total rows
  66. """
  67. logger.info("test_take_04")
  68. data1 = ds.GeneratorDataset(generator, ["data"])
  69. data1 = data1.take(4)
  70. data1 = data1.repeat(2)
  71. # Here i refers to index, d refers to data element
  72. for i, d in enumerate(data1):
  73. assert i % 3 == d[0][0]
  74. assert sum([1 for _ in data1]) == 6
  75. def test_take_05():
  76. """
  77. Test take: there is no repeat op
  78. """
  79. logger.info("test_take_05")
  80. data1 = ds.GeneratorDataset(generator, ["data"])
  81. data1 = data1.take(2)
  82. # Here i refers to index, d refers to data element
  83. for i, d in enumerate(data1):
  84. assert i == d[0][0]
  85. assert sum([1 for _ in data1]) == 2
  86. def test_take_06():
  87. """
  88. Test take: repeat is before take
  89. """
  90. logger.info("test_take_06")
  91. data1 = ds.GeneratorDataset(generator, ["data"])
  92. data1 = data1.repeat(2)
  93. data1 = data1.take(4)
  94. # Here i refers to index, d refers to data element
  95. for i, d in enumerate(data1):
  96. assert i % 3 == d[0][0]
  97. assert sum([1 for _ in data1]) == 4
  98. def test_take_07():
  99. """
  100. Test take: take is before batch, that mean take(N), N refer to rows num
  101. """
  102. logger.info("test_take_07")
  103. data1 = ds.GeneratorDataset(generator, ["data"])
  104. data1 = data1.take(2)
  105. data1 = data1.batch(2)
  106. assert sum([1 for _ in data1]) == 1
  107. def test_take_08():
  108. """
  109. Test take: take is after batch, that mean take(N), N refer to batches num
  110. """
  111. logger.info("test_take_08")
  112. data1 = ds.GeneratorDataset(generator, ["data"])
  113. data1 = data1.batch(2)
  114. data1 = data1.take(2)
  115. assert sum([1 for _ in data1]) == 2
  116. def test_take_09():
  117. """
  118. Test take: repeat count is -1, and read the whole dataset, take after repeat
  119. """
  120. logger.info("test_take_09")
  121. data1 = ds.GeneratorDataset(generator, ["data"])
  122. data1 = data1.repeat(2)
  123. data1 = data1.take(-1)
  124. # Here i refers to index, d refers to data element
  125. for i, d in enumerate(data1):
  126. assert i % 3 == d[0][0]
  127. assert sum([1 for _ in data1]) == 6
  128. def test_take_10():
  129. """
  130. Test take: repeat count is -1, and read the whole dataset, take before repeat
  131. """
  132. logger.info("test_take_10")
  133. data1 = ds.GeneratorDataset(generator, ["data"])
  134. data1 = data1.take(-1)
  135. data1 = data1.repeat(2)
  136. # Here i refers to index, d refers to data element
  137. for i, d in enumerate(data1):
  138. assert i % 3 == d[0][0]
  139. assert sum([1 for _ in data1]) == 6
  140. def test_take_11():
  141. """
  142. Test take: batch first, then do repeat and take operation
  143. """
  144. logger.info("test_take_11")
  145. data1 = ds.GeneratorDataset(generator, ["data"])
  146. data1 = data1.batch(2)
  147. data1 = data1.repeat(2)
  148. data1 = data1.take(-1)
  149. # Here i refers to index, d refers to data element
  150. for i, d in enumerate(data1):
  151. assert 2 * (i % 2) == d[0][0]
  152. assert sum([1 for _ in data1]) == 4
  153. def test_take_12():
  154. """
  155. Test take: take first, then do batch and repeat operation
  156. """
  157. logger.info("test_take_12")
  158. data1 = ds.GeneratorDataset(generator, ["data"])
  159. data1 = data1.take(2)
  160. data1 = data1.batch(2)
  161. data1 = data1.repeat(2)
  162. # Here i refers to index, d refers to data element
  163. for i, d in enumerate(data1):
  164. assert 0 == d[0][0]
  165. assert sum([1 for _ in data1]) == 2
  166. def test_take_13():
  167. """
  168. Test take: skip first, then do take, batch and repeat operation
  169. """
  170. logger.info("test_take_13")
  171. data1 = ds.GeneratorDataset(generator, ["data"])
  172. data1 = data1.skip(2)
  173. data1 = data1.take(-1)
  174. data1 = data1.batch(2)
  175. data1 = data1.repeat(2)
  176. # Here i refers to index, d refers to data element
  177. for i, d in enumerate(data1):
  178. assert 2 == d[0][0]
  179. assert sum([1 for _ in data1]) == 2
  180. def test_take_14():
  181. """
  182. Test take: take first, then do batch, skip and repeat operation
  183. """
  184. logger.info("test_take_14")
  185. data1 = ds.GeneratorDataset(generator, ["data"])
  186. data1 = data1.take(-1)
  187. data1 = data1.batch(2)
  188. data1 = data1.skip(1)
  189. data1 = data1.repeat(2)
  190. # Here i refers to index, d refers to data element
  191. for i, d in enumerate(data1):
  192. assert 2 == d[0][0]
  193. assert sum([1 for _ in data1]) == 2
  194. def test_take_15():
  195. """
  196. Test take: large amount data, take a part, then do skip operation
  197. """
  198. logger.info("test_take_15")
  199. data1 = ds.GeneratorDataset(generator_10, ["data"])
  200. data1 = data1.take(6)
  201. data1 = data1.skip(2)
  202. # Here i refers to index, d refers to data element
  203. for i, d in enumerate(data1):
  204. assert (i + 2) == d[0][0]
  205. assert sum([1 for _ in data1]) == 4
  206. def test_take_16():
  207. """
  208. Test take: large amount data, skip a part, then do take operation
  209. """
  210. logger.info("test_take_16")
  211. data1 = ds.GeneratorDataset(generator_10, ["data"])
  212. data1 = data1.skip(3)
  213. data1 = data1.take(5)
  214. # Here i refers to index, d refers to data element
  215. for i, d in enumerate(data1):
  216. assert (i + 3) == d[0][0]
  217. assert sum([1 for _ in data1]) == 5
  218. if __name__ == '__main__':
  219. test_take_01()
  220. test_take_02()
  221. test_take_03()
  222. test_take_04()
  223. test_take_05()
  224. test_take_06()
  225. test_take_07()
  226. test_take_08()
  227. test_take_09()
  228. test_take_10()
  229. test_take_11()
  230. test_take_12()
  231. test_take_13()
  232. test_take_14()
  233. test_take_15()
  234. test_take_16()
  235. logger.info('== test take operation finished ==')