You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_generator.py 26 kB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.common.dtype as mstype
  18. import mindspore.dataset as ds
  19. from mindspore import log as logger
  20. # Generate 1d int numpy array from 0 - 63
  21. def generator_1d():
  22. for i in range(64):
  23. yield (np.array([i]),)
  24. class DatasetGenerator:
  25. def __init__(self):
  26. pass
  27. def __getitem__(self, item):
  28. return (np.array([item]),)
  29. def __len__(self):
  30. return 10
  31. def test_generator_0():
  32. """
  33. Test 1D Generator
  34. """
  35. logger.info("Test 1D Generator : 0 - 63")
  36. # apply dataset operations
  37. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  38. i = 0
  39. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  40. golden = np.array([i])
  41. np.testing.assert_array_equal(item["data"], golden)
  42. i = i + 1
  43. # Generate md int numpy array from [[0, 1], [2, 3]] to [[63, 64], [65, 66]]
  44. def generator_md():
  45. for i in range(64):
  46. yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
  47. def test_generator_1():
  48. """
  49. Test MD Generator
  50. """
  51. logger.info("Test MD Generator : 0 - 63, with shape [2, 2]")
  52. # apply dataset operations
  53. data1 = ds.GeneratorDataset(generator_md, ["data"])
  54. i = 0
  55. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  56. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  57. np.testing.assert_array_equal(item["data"], golden)
  58. i = i + 1
  59. # Generate two columns, the first column is from Generator1D, the second column is from GeneratorMD
  60. def generator_mc(maxid=64):
  61. for i in range(maxid):
  62. yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
  63. def test_generator_2():
  64. """
  65. Test multi column generator
  66. """
  67. logger.info("Test multi column generator")
  68. # apply dataset operations
  69. data1 = ds.GeneratorDataset(generator_mc, ["col0", "col1"])
  70. i = 0
  71. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  72. golden = np.array([i])
  73. np.testing.assert_array_equal(item["col0"], golden)
  74. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  75. np.testing.assert_array_equal(item["col1"], golden)
  76. i = i + 1
  77. def test_generator_3():
  78. """
  79. Test 1D Generator + repeat(4)
  80. """
  81. logger.info("Test 1D Generator : 0 - 63 + Repeat(4)")
  82. # apply dataset operations
  83. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  84. data1 = data1.repeat(4)
  85. i = 0
  86. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  87. golden = np.array([i])
  88. np.testing.assert_array_equal(item["data"], golden)
  89. i = i + 1
  90. if i == 64:
  91. i = 0
  92. def test_generator_4():
  93. """
  94. Test fixed size 1D Generator + batch
  95. """
  96. logger.info("Test 1D Generator : 0 - 63 + batch(4)")
  97. # apply dataset operations
  98. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  99. data1 = data1.batch(4)
  100. i = 0
  101. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  102. golden = np.array([[i], [i + 1], [i + 2], [i + 3]])
  103. np.testing.assert_array_equal(item["data"], golden)
  104. i = i + 4
  105. def generator_with_type(t):
  106. for i in range(64):
  107. yield (np.array([i], dtype=t),)
  108. def type_tester(t):
  109. logger.info("Test with Type {}".format(t.__name__))
  110. # apply dataset operations
  111. data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"])
  112. data1 = data1.batch(4)
  113. i = 0
  114. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  115. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  116. np.testing.assert_array_equal(item["data"], golden)
  117. i = i + 4
  118. def test_generator_5():
  119. """
  120. Test 1D Generator on different data type
  121. """
  122. logger.info("Test 1D Generator on all data types")
  123. types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, np.float64]
  124. for t in types:
  125. type_tester(t)
  126. def type_tester_with_type_check(t, c):
  127. logger.info("Test with Type {}".format(t.__name__))
  128. # apply dataset operations
  129. data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], column_types=[c])
  130. data1 = data1.batch(4)
  131. i = 0
  132. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  133. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  134. np.testing.assert_array_equal(item["data"], golden)
  135. i = i + 4
  136. def test_generator_6():
  137. """
  138. Test 1D Generator on different data type with type check
  139. """
  140. logger.info("Test 1D Generator on all data types with type check")
  141. np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
  142. np.float64]
  143. de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
  144. mstype.uint64, mstype.float32, mstype.float64]
  145. for i, _ in enumerate(np_types):
  146. type_tester_with_type_check(np_types[i], de_types[i])
  147. def generator_with_type_2c(t):
  148. for i in range(64):
  149. yield (np.array([i], dtype=t), np.array([i], dtype=t))
  150. def type_tester_with_type_check_2c(t, c):
  151. logger.info("Test with Type {}".format(t.__name__))
  152. # apply dataset operations
  153. data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), ["data0", "data1"], column_types=c)
  154. data1 = data1.batch(4)
  155. i = 0
  156. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  157. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  158. np.testing.assert_array_equal(item["data0"], golden)
  159. i = i + 4
  160. def test_generator_7():
  161. """
  162. Test 2 column Generator on different data type with type check
  163. """
  164. logger.info("Test 2 column Generator on all data types with type check")
  165. np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
  166. np.float64]
  167. de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
  168. mstype.uint64, mstype.float32, mstype.float64]
  169. for i, _ in enumerate(np_types):
  170. type_tester_with_type_check_2c(np_types[i], [None, de_types[i]])
  171. def test_generator_8():
  172. """
  173. Test multi column generator with few mapops
  174. """
  175. logger.info("Test multi column generator with mapops to check the order too")
  176. # apply dataset operations
  177. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  178. data1 = data1.map(operations=(lambda x: x * 3), input_columns="col0", output_columns="out0",
  179. num_parallel_workers=2)
  180. data1 = data1.map(operations=(lambda x: (x * 7, x)), input_columns="col1", output_columns=["out1", "out2"],
  181. num_parallel_workers=2, column_order=["out0", "out1", "out2"])
  182. data1 = data1.map(operations=(lambda x: x + 1), input_columns="out2", output_columns="out2",
  183. num_parallel_workers=2)
  184. i = 0
  185. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  186. golden = np.array([i * 3])
  187. np.testing.assert_array_equal(item["out0"], golden)
  188. golden = np.array([[i * 7, (i + 1) * 7], [(i + 2) * 7, (i + 3) * 7]])
  189. np.testing.assert_array_equal(item["out1"], golden)
  190. golden = np.array([[i + 1, i + 2], [i + 3, i + 4]])
  191. np.testing.assert_array_equal(item["out2"], golden)
  192. i = i + 1
  193. def test_generator_9():
  194. """
  195. Test map column order when len(input_columns) == len(output_columns).
  196. """
  197. logger.info("Test map column order when len(input_columns) == len(output_columns).")
  198. # apply dataset operations
  199. data1 = ds.GeneratorDataset(generator_mc(2048), ["image", "label"])
  200. data2 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
  201. data1 = data1.map(operations=(lambda x: x * 3), input_columns="label",
  202. num_parallel_workers=4)
  203. data2 = data2.map(operations=(lambda x: x * 3), input_columns="label",
  204. num_parallel_workers=4)
  205. # Expected column order is not changed.
  206. # data1 = data[0] is "image" and data[1] is "label"
  207. # data2 = data[0] is "label" and data[1] is "image"
  208. i = 0
  209. for data1, data2 in zip(data1, data2): # each data is a dictionary
  210. golden = np.array([i])
  211. np.testing.assert_array_equal(data1[0].asnumpy(), golden)
  212. golden = np.array([[i * 3, (i + 1) * 3], [(i + 2) * 3, (i + 3) * 3]])
  213. np.testing.assert_array_equal(data1[1].asnumpy(), golden)
  214. golden = np.array([i * 3])
  215. np.testing.assert_array_equal(data2[0].asnumpy(), golden)
  216. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  217. np.testing.assert_array_equal(data2[1].asnumpy(), golden)
  218. i = i + 1
  219. def test_generator_10():
  220. """
  221. Test map column order when len(input_columns) != len(output_columns).
  222. """
  223. logger.info("Test map column order when len(input_columns) != len(output_columns).")
  224. # apply dataset operations
  225. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  226. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns="col1", output_columns=["out1", "out2"],
  227. column_order=['col0', 'out1', 'out2'], num_parallel_workers=2)
  228. # Expected column order is |col0|out1|out2|
  229. i = 0
  230. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  231. golden = np.array([i])
  232. np.testing.assert_array_equal(item[0], golden)
  233. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  234. np.testing.assert_array_equal(item[1], golden)
  235. golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]])
  236. np.testing.assert_array_equal(item[2], golden)
  237. i = i + 1
  238. def test_generator_11():
  239. """
  240. Test map column order when len(input_columns) != len(output_columns).
  241. """
  242. logger.info("Test map column order when len(input_columns) != len(output_columns), "
  243. "and column_order drops some columns.")
  244. # apply dataset operations
  245. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  246. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns="col1", output_columns=["out1", "out2"],
  247. column_order=['out1', 'out2'], num_parallel_workers=2)
  248. # Expected column order is |out1|out2|
  249. i = 0
  250. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  251. # len should be 2 because col0 is dropped (not included in column_order)
  252. assert len(item) == 2
  253. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  254. np.testing.assert_array_equal(item[0], golden)
  255. golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]])
  256. np.testing.assert_array_equal(item[1], golden)
  257. i = i + 1
  258. def test_generator_12():
  259. """
  260. Test map column order when input_columns and output_columns are None.
  261. """
  262. logger.info("Test map column order when input_columns and output_columns are None.")
  263. # apply dataset operations
  264. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  265. data1 = data1.map(operations=(lambda x: (x * 5)), num_parallel_workers=2)
  266. # Expected column order is |col0|col1|
  267. i = 0
  268. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  269. assert len(item) == 2
  270. golden = np.array([i * 5])
  271. np.testing.assert_array_equal(item[0], golden)
  272. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  273. np.testing.assert_array_equal(item[1], golden)
  274. i = i + 1
  275. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  276. data1 = data1.map(operations=(lambda x: (x * 5)), column_order=["col1", "col0"], num_parallel_workers=2)
  277. # Expected column order is |col0|col1|
  278. i = 0
  279. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  280. assert len(item) == 2
  281. golden = np.array([i * 5])
  282. np.testing.assert_array_equal(item[1], golden)
  283. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  284. np.testing.assert_array_equal(item[0], golden)
  285. i = i + 1
  286. def test_generator_13():
  287. """
  288. Test map column order when input_columns is None.
  289. """
  290. logger.info("Test map column order when input_columns is None.")
  291. # apply dataset operations
  292. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  293. data1 = data1.map(operations=(lambda x: (x * 5)), output_columns=["out0"], num_parallel_workers=2)
  294. # Expected column order is |out0|col1|
  295. i = 0
  296. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  297. assert len(item) == 2
  298. golden = np.array([i * 5])
  299. np.testing.assert_array_equal(item[0], golden)
  300. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  301. np.testing.assert_array_equal(item[1], golden)
  302. i = i + 1
  303. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  304. # len should be 2 because col0 is dropped (not included in column_order)
  305. assert len(item) == 2
  306. golden = np.array([i * 5])
  307. np.testing.assert_array_equal(item["out0"], golden)
  308. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  309. np.testing.assert_array_equal(item["col1"], golden)
  310. i = i + 1
  311. def test_generator_14():
  312. """
  313. Test 1D Generator MP + CPP sampler
  314. """
  315. logger.info("Test 1D Generator MP : 0 - 63")
  316. source = [(np.array([x]),) for x in range(256)]
  317. ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_parallel_workers=4).repeat(2)
  318. i = 0
  319. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  320. golden = np.array([i])
  321. np.testing.assert_array_equal(data["data"], golden)
  322. i = i + 1
  323. if i == 256:
  324. i = 0
  325. def test_generator_15():
  326. """
  327. Test 1D Generator MP + Python sampler
  328. """
  329. logger.info("Test 1D Generator MP : 0 - 63")
  330. sampler = [x for x in range(256)]
  331. source = [(np.array([x]),) for x in range(256)]
  332. ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler, num_parallel_workers=4).repeat(2)
  333. i = 0
  334. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  335. golden = np.array([i])
  336. np.testing.assert_array_equal(data["data"], golden)
  337. i = i + 1
  338. if i == 256:
  339. i = 0
  340. def test_generator_16():
  341. """
  342. Test multi column generator Mp + CPP sampler
  343. """
  344. logger.info("Test multi column generator")
  345. source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
  346. # apply dataset operations
  347. data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler())
  348. i = 0
  349. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  350. golden = np.array([i])
  351. np.testing.assert_array_equal(item["col0"], golden)
  352. golden = np.array([i + 1])
  353. np.testing.assert_array_equal(item["col1"], golden)
  354. i = i + 1
  355. def test_generator_17():
  356. """
  357. Test multi column generator Mp + Python sampler
  358. """
  359. logger.info("Test multi column generator")
  360. sampler = [x for x in range(256)]
  361. source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
  362. # apply dataset operations
  363. data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler)
  364. i = 0
  365. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  366. golden = np.array([i])
  367. np.testing.assert_array_equal(item["col0"], golden)
  368. golden = np.array([i + 1])
  369. np.testing.assert_array_equal(item["col1"], golden)
  370. i = i + 1
  371. def test_generator_error_1():
  372. def generator_np():
  373. for i in range(64):
  374. yield (np.array([{i}]),)
  375. with pytest.raises(RuntimeError) as info:
  376. data1 = ds.GeneratorDataset(generator_np, ["data"])
  377. for _ in data1:
  378. pass
  379. assert "Invalid data type" in str(info.value)
  380. def test_generator_error_2():
  381. def generator_np():
  382. for i in range(64):
  383. yield ({i},)
  384. with pytest.raises(RuntimeError) as info:
  385. data1 = ds.GeneratorDataset(generator_np, ["data"])
  386. for _ in data1:
  387. pass
  388. print("========", str(info.value))
  389. assert "Generator should return a tuple of numpy arrays" in str(info.value)
  390. def test_generator_error_3():
  391. with pytest.raises(ValueError) as info:
  392. # apply dataset operations
  393. data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
  394. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns=["label"], output_columns=["out1", "out2"],
  395. num_parallel_workers=2)
  396. for _ in data1:
  397. pass
  398. assert "When (len(input_columns) != len(output_columns)), column_order must be specified." in str(info.value)
  399. def test_generator_error_4():
  400. with pytest.raises(RuntimeError) as info:
  401. # apply dataset operations
  402. data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
  403. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns=["label"],
  404. num_parallel_workers=2)
  405. for _ in data1:
  406. pass
  407. assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)
  408. def test_generator_sequential_sampler():
  409. source = [(np.array([x]),) for x in range(64)]
  410. ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler())
  411. i = 0
  412. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  413. golden = np.array([i])
  414. np.testing.assert_array_equal(data["data"], golden)
  415. i = i + 1
  416. def test_generator_random_sampler():
  417. source = [(np.array([x]),) for x in range(64)]
  418. ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
  419. for _ in ds1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  420. pass
  421. def test_generator_distributed_sampler():
  422. source = [(np.array([x]),) for x in range(64)]
  423. for sid in range(8):
  424. ds1 = ds.GeneratorDataset(source, ["data"], shuffle=False, num_shards=8, shard_id=sid)
  425. i = sid
  426. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  427. golden = np.array([i])
  428. np.testing.assert_array_equal(data["data"], golden)
  429. i = i + 8
  430. def test_generator_num_samples():
  431. source = [(np.array([x]),) for x in range(64)]
  432. num_samples = 32
  433. ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples))
  434. ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples)
  435. ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
  436. count = 0
  437. for _ in ds1.create_dict_iterator(num_epochs=1):
  438. count = count + 1
  439. assert count == num_samples
  440. count = 0
  441. for _ in ds2.create_dict_iterator(num_epochs=1):
  442. count = count + 1
  443. assert count == num_samples
  444. count = 0
  445. for _ in ds3.create_dict_iterator(num_epochs=1):
  446. count = count + 1
  447. assert count == num_samples
  448. def test_generator_num_samples_underflow():
  449. source = [(np.array([x]),) for x in range(64)]
  450. num_samples = 256
  451. ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples=num_samples)
  452. ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
  453. count = 0
  454. for _ in ds2.create_dict_iterator(num_epochs=1):
  455. count = count + 1
  456. assert count == 64
  457. count = 0
  458. for _ in ds3.create_dict_iterator(num_epochs=1):
  459. count = count + 1
  460. assert count == 64
  461. def type_tester_with_type_check_2c_schema(t, c):
  462. logger.info("Test with Type {}".format(t.__name__))
  463. schema = ds.Schema()
  464. schema.add_column("data0", c[0])
  465. schema.add_column("data1", c[1])
  466. # apply dataset operations
  467. data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), schema=schema)
  468. data1 = data1.batch(4)
  469. i = 0
  470. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  471. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  472. np.testing.assert_array_equal(item["data0"], golden)
  473. i = i + 4
  474. def test_generator_schema():
  475. """
  476. Test 2 column Generator on different data type with type check with schema input
  477. """
  478. logger.info("Test 2 column Generator on all data types with type check")
  479. np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
  480. np.float64]
  481. de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
  482. mstype.uint64, mstype.float32, mstype.float64]
  483. for i, _ in enumerate(np_types):
  484. type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
  485. def test_generator_dataset_size_0():
  486. """
  487. Test GeneratorDataset get_dataset_size by iterator method.
  488. """
  489. logger.info("Test 1D Generator : 0 - 63 get_dataset_size")
  490. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  491. data_size = data1.get_dataset_size()
  492. num_rows = 0
  493. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  494. num_rows = num_rows + 1
  495. assert data_size == num_rows
  496. def test_generator_dataset_size_1():
  497. """
  498. Test GeneratorDataset get_dataset_size by __len__ method.
  499. """
  500. logger.info("Test DatasetGenerator get_dataset_size")
  501. dataset_generator = DatasetGenerator()
  502. data1 = ds.GeneratorDataset(dataset_generator, ["data"])
  503. data_size = data1.get_dataset_size()
  504. num_rows = 0
  505. for _ in data1.create_dict_iterator(num_epochs=1):
  506. num_rows = num_rows + 1
  507. assert data_size == num_rows
  508. def test_generator_dataset_size_2():
  509. """
  510. Test GeneratorDataset + repeat get_dataset_size
  511. """
  512. logger.info("Test 1D Generator + repeat get_dataset_size")
  513. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  514. data1 = data1.repeat(2)
  515. data_size = data1.get_dataset_size()
  516. num_rows = 0
  517. for _ in data1.create_dict_iterator(num_epochs=1):
  518. num_rows = num_rows + 1
  519. assert data_size == num_rows
  520. def test_generator_dataset_size_3():
  521. """
  522. Test GeneratorDataset + batch get_dataset_size
  523. """
  524. logger.info("Test 1D Generator + batch get_dataset_size")
  525. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  526. data1 = data1.batch(4)
  527. data_size = data1.get_dataset_size()
  528. num_rows = 0
  529. for _ in data1.create_dict_iterator(num_epochs=1):
  530. num_rows += 1
  531. assert data_size == num_rows
  532. def test_generator_dataset_size_4():
  533. """
  534. Test GeneratorDataset + num_shards
  535. """
  536. logger.info("Test 1D Generator : 0 - 63 + num_shards get_dataset_size")
  537. dataset_generator = DatasetGenerator()
  538. data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)
  539. data_size = data1.get_dataset_size()
  540. num_rows = 0
  541. for _ in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  542. num_rows = num_rows + 1
  543. assert data_size == num_rows
  544. def test_generator_dataset_size_5():
  545. """
  546. Test get_dataset_size after create_dict_iterator
  547. """
  548. logger.info("Test get_dataset_size after create_dict_iterator")
  549. dataset_generator = DatasetGenerator()
  550. data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)
  551. num_rows = 0
  552. for _ in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  553. num_rows = num_rows + 1
  554. data_size = data1.get_dataset_size()
  555. assert data_size == num_rows
  556. def manual_test_generator_keyboard_interrupt():
  557. """
  558. Test keyboard_interrupt
  559. """
  560. logger.info("Test 1D Generator MP : 0 - 63")
  561. class MyDS():
  562. def __getitem__(self, item):
  563. while True:
  564. pass
  565. def __len__(self):
  566. return 1024
  567. ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2)
  568. for _ in ds1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  569. pass
  570. if __name__ == "__main__":
  571. test_generator_0()
  572. test_generator_1()
  573. test_generator_2()
  574. test_generator_3()
  575. test_generator_4()
  576. test_generator_5()
  577. test_generator_6()
  578. test_generator_7()
  579. test_generator_8()
  580. test_generator_9()
  581. test_generator_10()
  582. test_generator_11()
  583. test_generator_12()
  584. test_generator_13()
  585. test_generator_14()
  586. test_generator_15()
  587. test_generator_16()
  588. test_generator_17()
  589. test_generator_error_1()
  590. test_generator_error_2()
  591. test_generator_error_3()
  592. test_generator_error_4()
  593. test_generator_sequential_sampler()
  594. test_generator_distributed_sampler()
  595. test_generator_random_sampler()
  596. test_generator_num_samples()
  597. test_generator_num_samples_underflow()
  598. test_generator_schema()
  599. test_generator_dataset_size_0()
  600. test_generator_dataset_size_1()
  601. test_generator_dataset_size_2()
  602. test_generator_dataset_size_3()
  603. test_generator_dataset_size_4()
  604. test_generator_dataset_size_5()