You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_generator.py 28 kB

5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import copy
  16. import numpy as np
  17. import pytest
  18. import mindspore.common.dtype as mstype
  19. import mindspore.dataset as ds
  20. from mindspore import log as logger
  21. # Generate 1d int numpy array from 0 - 63
  22. def generator_1d():
  23. for i in range(64):
  24. yield (np.array([i]),)
  25. class DatasetGenerator:
  26. def __init__(self):
  27. pass
  28. def __getitem__(self, item):
  29. return (np.array([item]),)
  30. def __len__(self):
  31. return 10
  32. def test_generator_0():
  33. """
  34. Test 1D Generator
  35. """
  36. logger.info("Test 1D Generator : 0 - 63")
  37. # apply dataset operations
  38. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  39. i = 0
  40. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  41. golden = np.array([i])
  42. np.testing.assert_array_equal(item["data"], golden)
  43. i = i + 1
  44. # Generate md int numpy array from [[0, 1], [2, 3]] to [[63, 64], [65, 66]]
  45. def generator_md():
  46. for i in range(64):
  47. yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
  48. def test_generator_1():
  49. """
  50. Test MD Generator
  51. """
  52. logger.info("Test MD Generator : 0 - 63, with shape [2, 2]")
  53. # apply dataset operations
  54. data1 = ds.GeneratorDataset(generator_md, ["data"])
  55. i = 0
  56. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  57. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  58. np.testing.assert_array_equal(item["data"], golden)
  59. i = i + 1
  60. # Generate two columns, the first column is from Generator1D, the second column is from GeneratorMD
  61. def generator_mc(maxid=64):
  62. for i in range(maxid):
  63. yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
  64. def test_generator_2():
  65. """
  66. Test multi column generator
  67. """
  68. logger.info("Test multi column generator")
  69. # apply dataset operations
  70. data1 = ds.GeneratorDataset(generator_mc, ["col0", "col1"])
  71. i = 0
  72. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  73. golden = np.array([i])
  74. np.testing.assert_array_equal(item["col0"], golden)
  75. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  76. np.testing.assert_array_equal(item["col1"], golden)
  77. i = i + 1
  78. def test_generator_3():
  79. """
  80. Test 1D Generator + repeat(4)
  81. """
  82. logger.info("Test 1D Generator : 0 - 63 + Repeat(4)")
  83. # apply dataset operations
  84. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  85. data1 = data1.repeat(4)
  86. i = 0
  87. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  88. golden = np.array([i])
  89. np.testing.assert_array_equal(item["data"], golden)
  90. i = i + 1
  91. if i == 64:
  92. i = 0
  93. def test_generator_4():
  94. """
  95. Test fixed size 1D Generator + batch
  96. """
  97. logger.info("Test 1D Generator : 0 - 63 + batch(4)")
  98. # apply dataset operations
  99. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  100. data1 = data1.batch(4)
  101. i = 0
  102. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  103. golden = np.array([[i], [i + 1], [i + 2], [i + 3]])
  104. np.testing.assert_array_equal(item["data"], golden)
  105. i = i + 4
  106. def generator_with_type(t):
  107. for i in range(64):
  108. yield (np.array([i], dtype=t),)
  109. def type_tester(t):
  110. logger.info("Test with Type {}".format(t.__name__))
  111. # apply dataset operations
  112. data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"])
  113. data1 = data1.batch(4)
  114. i = 0
  115. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  116. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  117. np.testing.assert_array_equal(item["data"], golden)
  118. i = i + 4
  119. def test_generator_5():
  120. """
  121. Test 1D Generator on different data type
  122. """
  123. logger.info("Test 1D Generator on all data types")
  124. types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, np.float64]
  125. for t in types:
  126. type_tester(t)
  127. def type_tester_with_type_check(t, c):
  128. logger.info("Test with Type {}".format(t.__name__))
  129. # apply dataset operations
  130. data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], column_types=[c])
  131. data1 = data1.batch(4)
  132. i = 0
  133. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  134. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  135. np.testing.assert_array_equal(item["data"], golden)
  136. i = i + 4
  137. def test_generator_6():
  138. """
  139. Test 1D Generator on different data type with type check
  140. """
  141. logger.info("Test 1D Generator on all data types with type check")
  142. np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
  143. np.float64]
  144. de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
  145. mstype.uint64, mstype.float32, mstype.float64]
  146. for i, _ in enumerate(np_types):
  147. type_tester_with_type_check(np_types[i], de_types[i])
  148. def generator_with_type_2c(t):
  149. for i in range(64):
  150. yield (np.array([i], dtype=t), np.array([i], dtype=t))
  151. def type_tester_with_type_check_2c(t, c):
  152. logger.info("Test with Type {}".format(t.__name__))
  153. # apply dataset operations
  154. data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), ["data0", "data1"], column_types=c)
  155. data1 = data1.batch(4)
  156. i = 0
  157. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  158. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  159. np.testing.assert_array_equal(item["data0"], golden)
  160. i = i + 4
  161. def test_generator_7():
  162. """
  163. Test 2 column Generator on different data type with type check
  164. """
  165. logger.info("Test 2 column Generator on all data types with type check")
  166. np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
  167. np.float64]
  168. de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
  169. mstype.uint64, mstype.float32, mstype.float64]
  170. for i, _ in enumerate(np_types):
  171. type_tester_with_type_check_2c(np_types[i], [None, de_types[i]])
  172. def test_generator_8():
  173. """
  174. Test multi column generator with few mapops
  175. """
  176. logger.info("Test multi column generator with mapops to check the order too")
  177. # apply dataset operations
  178. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  179. data1 = data1.map(operations=(lambda x: x * 3), input_columns="col0", output_columns="out0",
  180. num_parallel_workers=2)
  181. data1 = data1.map(operations=(lambda x: (x * 7, x)), input_columns="col1", output_columns=["out1", "out2"],
  182. num_parallel_workers=2, column_order=["out0", "out1", "out2"])
  183. data1 = data1.map(operations=(lambda x: x + 1), input_columns="out2", output_columns="out2",
  184. num_parallel_workers=2)
  185. i = 0
  186. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  187. golden = np.array([i * 3])
  188. np.testing.assert_array_equal(item["out0"], golden)
  189. golden = np.array([[i * 7, (i + 1) * 7], [(i + 2) * 7, (i + 3) * 7]])
  190. np.testing.assert_array_equal(item["out1"], golden)
  191. golden = np.array([[i + 1, i + 2], [i + 3, i + 4]])
  192. np.testing.assert_array_equal(item["out2"], golden)
  193. i = i + 1
  194. def test_generator_9():
  195. """
  196. Test map column order when len(input_columns) == len(output_columns).
  197. """
  198. logger.info("Test map column order when len(input_columns) == len(output_columns).")
  199. # apply dataset operations
  200. data1 = ds.GeneratorDataset(generator_mc(2048), ["image", "label"])
  201. data2 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
  202. data1 = data1.map(operations=(lambda x: x * 3), input_columns="label",
  203. num_parallel_workers=4)
  204. data2 = data2.map(operations=(lambda x: x * 3), input_columns="label",
  205. num_parallel_workers=4)
  206. # Expected column order is not changed.
  207. # data1 = data[0] is "image" and data[1] is "label"
  208. # data2 = data[0] is "label" and data[1] is "image"
  209. i = 0
  210. for data1, data2 in zip(data1, data2): # each data is a dictionary
  211. golden = np.array([i])
  212. np.testing.assert_array_equal(data1[0].asnumpy(), golden)
  213. golden = np.array([[i * 3, (i + 1) * 3], [(i + 2) * 3, (i + 3) * 3]])
  214. np.testing.assert_array_equal(data1[1].asnumpy(), golden)
  215. golden = np.array([i * 3])
  216. np.testing.assert_array_equal(data2[0].asnumpy(), golden)
  217. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  218. np.testing.assert_array_equal(data2[1].asnumpy(), golden)
  219. i = i + 1
  220. def test_generator_10():
  221. """
  222. Test map column order when len(input_columns) != len(output_columns).
  223. """
  224. logger.info("Test map column order when len(input_columns) != len(output_columns).")
  225. # apply dataset operations
  226. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  227. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns="col1", output_columns=["out1", "out2"],
  228. column_order=['col0', 'out1', 'out2'], num_parallel_workers=2)
  229. # Expected column order is |col0|out1|out2|
  230. i = 0
  231. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  232. golden = np.array([i])
  233. np.testing.assert_array_equal(item[0], golden)
  234. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  235. np.testing.assert_array_equal(item[1], golden)
  236. golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]])
  237. np.testing.assert_array_equal(item[2], golden)
  238. i = i + 1
  239. def test_generator_11():
  240. """
  241. Test map column order when len(input_columns) != len(output_columns).
  242. """
  243. logger.info("Test map column order when len(input_columns) != len(output_columns), "
  244. "and column_order drops some columns.")
  245. # apply dataset operations
  246. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  247. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns="col1", output_columns=["out1", "out2"],
  248. column_order=['out1', 'out2'], num_parallel_workers=2)
  249. # Expected column order is |out1|out2|
  250. i = 0
  251. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  252. # len should be 2 because col0 is dropped (not included in column_order)
  253. assert len(item) == 2
  254. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  255. np.testing.assert_array_equal(item[0], golden)
  256. golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]])
  257. np.testing.assert_array_equal(item[1], golden)
  258. i = i + 1
  259. def test_generator_12():
  260. """
  261. Test map column order when input_columns and output_columns are None.
  262. """
  263. logger.info("Test map column order when input_columns and output_columns are None.")
  264. # apply dataset operations
  265. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  266. data1 = data1.map(operations=(lambda x: (x * 5)), num_parallel_workers=2)
  267. # Expected column order is |col0|col1|
  268. i = 0
  269. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  270. assert len(item) == 2
  271. golden = np.array([i * 5])
  272. np.testing.assert_array_equal(item[0], golden)
  273. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  274. np.testing.assert_array_equal(item[1], golden)
  275. i = i + 1
  276. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  277. data1 = data1.map(operations=(lambda x: (x * 5)), column_order=["col1", "col0"], num_parallel_workers=2)
  278. # Expected column order is |col0|col1|
  279. i = 0
  280. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  281. assert len(item) == 2
  282. golden = np.array([i * 5])
  283. np.testing.assert_array_equal(item[1], golden)
  284. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  285. np.testing.assert_array_equal(item[0], golden)
  286. i = i + 1
def test_generator_13():
    """
    Test map column order when input_columns is None.
    """
    logger.info("Test map column order when input_columns is None.")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(operations=(lambda x: (x * 5)), output_columns=["out0"], num_parallel_workers=2)
    # Expected column order is |out0|col1|
    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item[1], golden)
        i = i + 1
    # NOTE(review): ``i`` is NOT reset before this second pass, so its goldens
    # start from the post-loop value. Presumably this relies on the one-shot
    # generator_mc(2048) object being exhausted so this loop yields no rows —
    # confirm against GeneratorDataset re-iteration semantics.
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # len should be 2 because col0 is dropped (not included in column_order)
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item["out0"], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item["col1"], golden)
        i = i + 1
  312. def test_generator_14():
  313. """
  314. Test 1D Generator MP + CPP sampler
  315. """
  316. logger.info("Test 1D Generator MP : 0 - 63")
  317. source = [(np.array([x]),) for x in range(256)]
  318. ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_parallel_workers=4).repeat(2)
  319. i = 0
  320. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  321. golden = np.array([i])
  322. np.testing.assert_array_equal(data["data"], golden)
  323. i = i + 1
  324. if i == 256:
  325. i = 0
  326. def test_generator_15():
  327. """
  328. Test 1D Generator MP + Python sampler
  329. """
  330. logger.info("Test 1D Generator MP : 0 - 63")
  331. sampler = [x for x in range(256)]
  332. source = [(np.array([x]),) for x in range(256)]
  333. ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler, num_parallel_workers=4).repeat(2)
  334. i = 0
  335. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  336. golden = np.array([i])
  337. np.testing.assert_array_equal(data["data"], golden)
  338. i = i + 1
  339. if i == 256:
  340. i = 0
  341. def test_generator_16():
  342. """
  343. Test multi column generator Mp + CPP sampler
  344. """
  345. logger.info("Test multi column generator")
  346. source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
  347. # apply dataset operations
  348. data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler())
  349. i = 0
  350. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  351. golden = np.array([i])
  352. np.testing.assert_array_equal(item["col0"], golden)
  353. golden = np.array([i + 1])
  354. np.testing.assert_array_equal(item["col1"], golden)
  355. i = i + 1
  356. def test_generator_17():
  357. """
  358. Test multi column generator Mp + Python sampler
  359. """
  360. logger.info("Test multi column generator")
  361. sampler = [x for x in range(256)]
  362. source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
  363. # apply dataset operations
  364. data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler)
  365. i = 0
  366. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  367. golden = np.array([i])
  368. np.testing.assert_array_equal(item["col0"], golden)
  369. golden = np.array([i + 1])
  370. np.testing.assert_array_equal(item["col1"], golden)
  371. i = i + 1
def test_generator_18():
    """
    Test multiprocessing flag (same as test 13 with python_multiprocessing=True flag)
    """
    logger.info("Test map column order when input_columns is None.")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"], python_multiprocessing=True)
    data1 = data1.map(operations=(lambda x: (x * 5)), output_columns=["out0"], num_parallel_workers=2,
                      python_multiprocessing=True)
    # Expected column order is |out0|col1|
    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item[1], golden)
        i = i + 1
    # NOTE(review): ``i`` is neither reset nor incremented below, so every row
    # (if any) would be compared against the same golden built from the
    # post-loop value of ``i``. Presumably this pass yields no rows because the
    # one-shot generator_mc(2048) is already exhausted — confirm.
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # len should be 2 because col0 is dropped (not included in column_order)
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item["out0"], golden)
  395. def test_generator_error_1():
  396. def generator_np():
  397. for i in range(64):
  398. yield (np.array([{i}]),)
  399. with pytest.raises(RuntimeError) as info:
  400. data1 = ds.GeneratorDataset(generator_np, ["data"])
  401. for _ in data1:
  402. pass
  403. assert "Invalid data type" in str(info.value)
  404. def test_generator_error_2():
  405. def generator_np():
  406. for i in range(64):
  407. yield ({i},)
  408. with pytest.raises(RuntimeError) as info:
  409. data1 = ds.GeneratorDataset(generator_np, ["data"])
  410. for _ in data1:
  411. pass
  412. print("========", str(info.value))
  413. assert "Generator should return a tuple of numpy arrays" in str(info.value)
  414. def test_generator_error_3():
  415. with pytest.raises(ValueError) as info:
  416. # apply dataset operations
  417. data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
  418. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns=["label"], output_columns=["out1", "out2"],
  419. num_parallel_workers=2)
  420. for _ in data1:
  421. pass
  422. assert "When length of input_columns and output_columns are not equal, column_order must be specified." in \
  423. str(info.value)
  424. def test_generator_error_4():
  425. with pytest.raises(RuntimeError) as info:
  426. # apply dataset operations
  427. data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
  428. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns=["label"],
  429. num_parallel_workers=2)
  430. for _ in data1:
  431. pass
  432. assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)
  433. def test_generator_sequential_sampler():
  434. source = [(np.array([x]),) for x in range(64)]
  435. ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler())
  436. i = 0
  437. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  438. golden = np.array([i])
  439. np.testing.assert_array_equal(data["data"], golden)
  440. i = i + 1
  441. def test_generator_random_sampler():
  442. source = [(np.array([x]),) for x in range(64)]
  443. ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
  444. for _ in ds1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  445. pass
  446. def test_generator_distributed_sampler():
  447. source = [(np.array([x]),) for x in range(64)]
  448. for sid in range(8):
  449. ds1 = ds.GeneratorDataset(source, ["data"], shuffle=False, num_shards=8, shard_id=sid)
  450. i = sid
  451. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  452. golden = np.array([i])
  453. np.testing.assert_array_equal(data["data"], golden)
  454. i = i + 8
  455. def test_generator_num_samples():
  456. source = [(np.array([x]),) for x in range(64)]
  457. num_samples = 32
  458. ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples))
  459. ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples)
  460. ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
  461. count = 0
  462. for _ in ds1.create_dict_iterator(num_epochs=1):
  463. count = count + 1
  464. assert count == num_samples
  465. count = 0
  466. for _ in ds2.create_dict_iterator(num_epochs=1):
  467. count = count + 1
  468. assert count == num_samples
  469. count = 0
  470. for _ in ds3.create_dict_iterator(num_epochs=1):
  471. count = count + 1
  472. assert count == num_samples
  473. def test_generator_num_samples_underflow():
  474. source = [(np.array([x]),) for x in range(64)]
  475. num_samples = 256
  476. ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples=num_samples)
  477. ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
  478. count = 0
  479. for _ in ds2.create_dict_iterator(num_epochs=1):
  480. count = count + 1
  481. assert count == 64
  482. count = 0
  483. for _ in ds3.create_dict_iterator(num_epochs=1):
  484. count = count + 1
  485. assert count == 64
  486. def type_tester_with_type_check_2c_schema(t, c):
  487. logger.info("Test with Type {}".format(t.__name__))
  488. schema = ds.Schema()
  489. schema.add_column("data0", c[0])
  490. schema.add_column("data1", c[1])
  491. # apply dataset operations
  492. data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), schema=schema)
  493. data1 = data1.batch(4)
  494. i = 0
  495. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  496. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  497. np.testing.assert_array_equal(item["data0"], golden)
  498. i = i + 4
  499. def test_generator_schema():
  500. """
  501. Test 2 column Generator on different data type with type check with schema input
  502. """
  503. logger.info("Test 2 column Generator on all data types with type check")
  504. np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
  505. np.float64]
  506. de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
  507. mstype.uint64, mstype.float32, mstype.float64]
  508. for i, _ in enumerate(np_types):
  509. type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
  510. def test_generator_dataset_size_0():
  511. """
  512. Test GeneratorDataset get_dataset_size by iterator method.
  513. """
  514. logger.info("Test 1D Generator : 0 - 63 get_dataset_size")
  515. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  516. data_size = data1.get_dataset_size()
  517. num_rows = 0
  518. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  519. num_rows = num_rows + 1
  520. assert data_size == num_rows
  521. def test_generator_dataset_size_1():
  522. """
  523. Test GeneratorDataset get_dataset_size by __len__ method.
  524. """
  525. logger.info("Test DatasetGenerator get_dataset_size")
  526. dataset_generator = DatasetGenerator()
  527. data1 = ds.GeneratorDataset(dataset_generator, ["data"])
  528. data_size = data1.get_dataset_size()
  529. num_rows = 0
  530. for _ in data1.create_dict_iterator(num_epochs=1):
  531. num_rows = num_rows + 1
  532. assert data_size == num_rows
  533. def test_generator_dataset_size_2():
  534. """
  535. Test GeneratorDataset + repeat get_dataset_size
  536. """
  537. logger.info("Test 1D Generator + repeat get_dataset_size")
  538. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  539. data1 = data1.repeat(2)
  540. data_size = data1.get_dataset_size()
  541. num_rows = 0
  542. for _ in data1.create_dict_iterator(num_epochs=1):
  543. num_rows = num_rows + 1
  544. assert data_size == num_rows
  545. def test_generator_dataset_size_3():
  546. """
  547. Test GeneratorDataset + batch get_dataset_size
  548. """
  549. logger.info("Test 1D Generator + batch get_dataset_size")
  550. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  551. data1 = data1.batch(4)
  552. data_size = data1.get_dataset_size()
  553. num_rows = 0
  554. for _ in data1.create_dict_iterator(num_epochs=1):
  555. num_rows += 1
  556. assert data_size == num_rows
  557. def test_generator_dataset_size_4():
  558. """
  559. Test GeneratorDataset + num_shards
  560. """
  561. logger.info("Test 1D Generator : 0 - 63 + num_shards get_dataset_size")
  562. dataset_generator = DatasetGenerator()
  563. data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)
  564. data_size = data1.get_dataset_size()
  565. num_rows = 0
  566. for _ in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  567. num_rows = num_rows + 1
  568. assert data_size == num_rows
  569. def test_generator_dataset_size_5():
  570. """
  571. Test get_dataset_size after create_dict_iterator
  572. """
  573. logger.info("Test get_dataset_size after create_dict_iterator")
  574. dataset_generator = DatasetGenerator()
  575. data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)
  576. num_rows = 0
  577. for _ in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  578. num_rows = num_rows + 1
  579. data_size = data1.get_dataset_size()
  580. assert data_size == num_rows
def manual_test_generator_keyboard_interrupt():
    """
    Test keyboard_interrupt

    Manual-only check (note the non-``test_`` prefix, so pytest skips it):
    run it by hand and press Ctrl-C to verify workers shut down cleanly.
    """
    logger.info("Test 1D Generator MP : 0 - 63")
    class MyDS():
        def __getitem__(self, item):
            # Block forever so iteration can only end via KeyboardInterrupt.
            while True:
                pass
        def __len__(self):
            return 1024
    ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2)
    for _ in ds1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        pass
  595. def test_explicit_deepcopy():
  596. """
  597. Test explicit_deepcopy
  598. """
  599. logger.info("Test explicit_deepcopy")
  600. ds1 = ds.NumpySlicesDataset([1, 2], shuffle=False)
  601. ds2 = copy.deepcopy(ds1)
  602. for d1, d2 in zip(ds1, ds2):
  603. assert d1 == d2
if __name__ == "__main__":
    # Run the whole suite sequentially when executed as a script.
    # manual_test_generator_keyboard_interrupt is intentionally omitted:
    # it blocks forever and must be interrupted by hand.
    test_generator_0()
    test_generator_1()
    test_generator_2()
    test_generator_3()
    test_generator_4()
    test_generator_5()
    test_generator_6()
    test_generator_7()
    test_generator_8()
    test_generator_9()
    test_generator_10()
    test_generator_11()
    test_generator_12()
    test_generator_13()
    test_generator_14()
    test_generator_15()
    test_generator_16()
    test_generator_17()
    test_generator_18()
    test_generator_error_1()
    test_generator_error_2()
    test_generator_error_3()
    test_generator_error_4()
    test_generator_sequential_sampler()
    test_generator_distributed_sampler()
    test_generator_random_sampler()
    test_generator_num_samples()
    test_generator_num_samples_underflow()
    test_generator_schema()
    test_generator_dataset_size_0()
    test_generator_dataset_size_1()
    test_generator_dataset_size_2()
    test_generator_dataset_size_3()
    test_generator_dataset_size_4()
    test_generator_dataset_size_5()
    test_explicit_deepcopy()