You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_generator.py 27 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import copy
  16. import numpy as np
  17. import pytest
  18. import mindspore.common.dtype as mstype
  19. import mindspore.dataset as ds
  20. from mindspore import log as logger
  21. # Generate 1d int numpy array from 0 - 63
  22. def generator_1d():
  23. for i in range(64):
  24. yield (np.array([i]),)
  25. class DatasetGenerator:
  26. def __init__(self):
  27. pass
  28. def __getitem__(self, item):
  29. return (np.array([item]),)
  30. def __len__(self):
  31. return 10
  32. def test_generator_0():
  33. """
  34. Test 1D Generator
  35. """
  36. logger.info("Test 1D Generator : 0 - 63")
  37. # apply dataset operations
  38. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  39. i = 0
  40. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  41. golden = np.array([i])
  42. np.testing.assert_array_equal(item["data"], golden)
  43. i = i + 1
  44. # Generate md int numpy array from [[0, 1], [2, 3]] to [[63, 64], [65, 66]]
  45. def generator_md():
  46. for i in range(64):
  47. yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
  48. def test_generator_1():
  49. """
  50. Test MD Generator
  51. """
  52. logger.info("Test MD Generator : 0 - 63, with shape [2, 2]")
  53. # apply dataset operations
  54. data1 = ds.GeneratorDataset(generator_md, ["data"])
  55. i = 0
  56. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  57. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  58. np.testing.assert_array_equal(item["data"], golden)
  59. i = i + 1
  60. # Generate two columns, the first column is from Generator1D, the second column is from GeneratorMD
  61. def generator_mc(maxid=64):
  62. for i in range(maxid):
  63. yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
  64. def test_generator_2():
  65. """
  66. Test multi column generator
  67. """
  68. logger.info("Test multi column generator")
  69. # apply dataset operations
  70. data1 = ds.GeneratorDataset(generator_mc, ["col0", "col1"])
  71. i = 0
  72. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  73. golden = np.array([i])
  74. np.testing.assert_array_equal(item["col0"], golden)
  75. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  76. np.testing.assert_array_equal(item["col1"], golden)
  77. i = i + 1
  78. def test_generator_3():
  79. """
  80. Test 1D Generator + repeat(4)
  81. """
  82. logger.info("Test 1D Generator : 0 - 63 + Repeat(4)")
  83. # apply dataset operations
  84. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  85. data1 = data1.repeat(4)
  86. i = 0
  87. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  88. golden = np.array([i])
  89. np.testing.assert_array_equal(item["data"], golden)
  90. i = i + 1
  91. if i == 64:
  92. i = 0
  93. def test_generator_4():
  94. """
  95. Test fixed size 1D Generator + batch
  96. """
  97. logger.info("Test 1D Generator : 0 - 63 + batch(4)")
  98. # apply dataset operations
  99. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  100. data1 = data1.batch(4)
  101. i = 0
  102. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  103. golden = np.array([[i], [i + 1], [i + 2], [i + 3]])
  104. np.testing.assert_array_equal(item["data"], golden)
  105. i = i + 4
  106. def generator_with_type(t):
  107. for i in range(64):
  108. yield (np.array([i], dtype=t),)
  109. def type_tester(t):
  110. logger.info("Test with Type {}".format(t.__name__))
  111. # apply dataset operations
  112. data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"])
  113. data1 = data1.batch(4)
  114. i = 0
  115. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  116. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  117. np.testing.assert_array_equal(item["data"], golden)
  118. i = i + 4
  119. def test_generator_5():
  120. """
  121. Test 1D Generator on different data type
  122. """
  123. logger.info("Test 1D Generator on all data types")
  124. types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, np.float64]
  125. for t in types:
  126. type_tester(t)
  127. def type_tester_with_type_check(t, c):
  128. logger.info("Test with Type {}".format(t.__name__))
  129. # apply dataset operations
  130. data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], column_types=[c])
  131. data1 = data1.batch(4)
  132. i = 0
  133. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  134. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  135. np.testing.assert_array_equal(item["data"], golden)
  136. i = i + 4
  137. def test_generator_6():
  138. """
  139. Test 1D Generator on different data type with type check
  140. """
  141. logger.info("Test 1D Generator on all data types with type check")
  142. np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
  143. np.float64]
  144. de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
  145. mstype.uint64, mstype.float32, mstype.float64]
  146. for i, _ in enumerate(np_types):
  147. type_tester_with_type_check(np_types[i], de_types[i])
  148. def generator_with_type_2c(t):
  149. for i in range(64):
  150. yield (np.array([i], dtype=t), np.array([i], dtype=t))
  151. def type_tester_with_type_check_2c(t, c):
  152. logger.info("Test with Type {}".format(t.__name__))
  153. # apply dataset operations
  154. data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), ["data0", "data1"], column_types=c)
  155. data1 = data1.batch(4)
  156. i = 0
  157. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  158. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  159. np.testing.assert_array_equal(item["data0"], golden)
  160. i = i + 4
  161. def test_generator_7():
  162. """
  163. Test 2 column Generator on different data type with type check
  164. """
  165. logger.info("Test 2 column Generator on all data types with type check")
  166. np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
  167. np.float64]
  168. de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
  169. mstype.uint64, mstype.float32, mstype.float64]
  170. for i, _ in enumerate(np_types):
  171. type_tester_with_type_check_2c(np_types[i], [None, de_types[i]])
  172. def test_generator_8():
  173. """
  174. Test multi column generator with few mapops
  175. """
  176. logger.info("Test multi column generator with mapops to check the order too")
  177. # apply dataset operations
  178. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  179. data1 = data1.map(operations=(lambda x: x * 3), input_columns="col0", output_columns="out0",
  180. num_parallel_workers=2)
  181. data1 = data1.map(operations=(lambda x: (x * 7, x)), input_columns="col1", output_columns=["out1", "out2"],
  182. num_parallel_workers=2, column_order=["out0", "out1", "out2"])
  183. data1 = data1.map(operations=(lambda x: x + 1), input_columns="out2", output_columns="out2",
  184. num_parallel_workers=2)
  185. i = 0
  186. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  187. golden = np.array([i * 3])
  188. np.testing.assert_array_equal(item["out0"], golden)
  189. golden = np.array([[i * 7, (i + 1) * 7], [(i + 2) * 7, (i + 3) * 7]])
  190. np.testing.assert_array_equal(item["out1"], golden)
  191. golden = np.array([[i + 1, i + 2], [i + 3, i + 4]])
  192. np.testing.assert_array_equal(item["out2"], golden)
  193. i = i + 1
  194. def test_generator_9():
  195. """
  196. Test map column order when len(input_columns) == len(output_columns).
  197. """
  198. logger.info("Test map column order when len(input_columns) == len(output_columns).")
  199. # apply dataset operations
  200. data1 = ds.GeneratorDataset(generator_mc(2048), ["image", "label"])
  201. data2 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
  202. data1 = data1.map(operations=(lambda x: x * 3), input_columns="label",
  203. num_parallel_workers=4)
  204. data2 = data2.map(operations=(lambda x: x * 3), input_columns="label",
  205. num_parallel_workers=4)
  206. # Expected column order is not changed.
  207. # data1 = data[0] is "image" and data[1] is "label"
  208. # data2 = data[0] is "label" and data[1] is "image"
  209. i = 0
  210. for data1, data2 in zip(data1, data2): # each data is a dictionary
  211. golden = np.array([i])
  212. np.testing.assert_array_equal(data1[0].asnumpy(), golden)
  213. golden = np.array([[i * 3, (i + 1) * 3], [(i + 2) * 3, (i + 3) * 3]])
  214. np.testing.assert_array_equal(data1[1].asnumpy(), golden)
  215. golden = np.array([i * 3])
  216. np.testing.assert_array_equal(data2[0].asnumpy(), golden)
  217. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  218. np.testing.assert_array_equal(data2[1].asnumpy(), golden)
  219. i = i + 1
  220. def test_generator_10():
  221. """
  222. Test map column order when len(input_columns) != len(output_columns).
  223. """
  224. logger.info("Test map column order when len(input_columns) != len(output_columns).")
  225. # apply dataset operations
  226. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  227. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns="col1", output_columns=["out1", "out2"],
  228. column_order=['col0', 'out1', 'out2'], num_parallel_workers=2)
  229. # Expected column order is |col0|out1|out2|
  230. i = 0
  231. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  232. golden = np.array([i])
  233. np.testing.assert_array_equal(item[0], golden)
  234. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  235. np.testing.assert_array_equal(item[1], golden)
  236. golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]])
  237. np.testing.assert_array_equal(item[2], golden)
  238. i = i + 1
  239. def test_generator_11():
  240. """
  241. Test map column order when len(input_columns) != len(output_columns).
  242. """
  243. logger.info("Test map column order when len(input_columns) != len(output_columns), "
  244. "and column_order drops some columns.")
  245. # apply dataset operations
  246. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  247. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns="col1", output_columns=["out1", "out2"],
  248. column_order=['out1', 'out2'], num_parallel_workers=2)
  249. # Expected column order is |out1|out2|
  250. i = 0
  251. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  252. # len should be 2 because col0 is dropped (not included in column_order)
  253. assert len(item) == 2
  254. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  255. np.testing.assert_array_equal(item[0], golden)
  256. golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]])
  257. np.testing.assert_array_equal(item[1], golden)
  258. i = i + 1
  259. def test_generator_12():
  260. """
  261. Test map column order when input_columns and output_columns are None.
  262. """
  263. logger.info("Test map column order when input_columns and output_columns are None.")
  264. # apply dataset operations
  265. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  266. data1 = data1.map(operations=(lambda x: (x * 5)), num_parallel_workers=2)
  267. # Expected column order is |col0|col1|
  268. i = 0
  269. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  270. assert len(item) == 2
  271. golden = np.array([i * 5])
  272. np.testing.assert_array_equal(item[0], golden)
  273. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  274. np.testing.assert_array_equal(item[1], golden)
  275. i = i + 1
  276. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  277. data1 = data1.map(operations=(lambda x: (x * 5)), column_order=["col1", "col0"], num_parallel_workers=2)
  278. # Expected column order is |col0|col1|
  279. i = 0
  280. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  281. assert len(item) == 2
  282. golden = np.array([i * 5])
  283. np.testing.assert_array_equal(item[1], golden)
  284. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  285. np.testing.assert_array_equal(item[0], golden)
  286. i = i + 1
  287. def test_generator_13():
  288. """
  289. Test map column order when input_columns is None.
  290. """
  291. logger.info("Test map column order when input_columns is None.")
  292. # apply dataset operations
  293. data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
  294. data1 = data1.map(operations=(lambda x: (x * 5)), output_columns=["out0"], num_parallel_workers=2)
  295. # Expected column order is |out0|col1|
  296. i = 0
  297. for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
  298. assert len(item) == 2
  299. golden = np.array([i * 5])
  300. np.testing.assert_array_equal(item[0], golden)
  301. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  302. np.testing.assert_array_equal(item[1], golden)
  303. i = i + 1
  304. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  305. # len should be 2 because col0 is dropped (not included in column_order)
  306. assert len(item) == 2
  307. golden = np.array([i * 5])
  308. np.testing.assert_array_equal(item["out0"], golden)
  309. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  310. np.testing.assert_array_equal(item["col1"], golden)
  311. i = i + 1
  312. def test_generator_14():
  313. """
  314. Test 1D Generator MP + CPP sampler
  315. """
  316. logger.info("Test 1D Generator MP : 0 - 63")
  317. source = [(np.array([x]),) for x in range(256)]
  318. ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_parallel_workers=4).repeat(2)
  319. i = 0
  320. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  321. golden = np.array([i])
  322. np.testing.assert_array_equal(data["data"], golden)
  323. i = i + 1
  324. if i == 256:
  325. i = 0
  326. def test_generator_15():
  327. """
  328. Test 1D Generator MP + Python sampler
  329. """
  330. logger.info("Test 1D Generator MP : 0 - 63")
  331. sampler = [x for x in range(256)]
  332. source = [(np.array([x]),) for x in range(256)]
  333. ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler, num_parallel_workers=4).repeat(2)
  334. i = 0
  335. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  336. golden = np.array([i])
  337. np.testing.assert_array_equal(data["data"], golden)
  338. i = i + 1
  339. if i == 256:
  340. i = 0
  341. def test_generator_16():
  342. """
  343. Test multi column generator Mp + CPP sampler
  344. """
  345. logger.info("Test multi column generator")
  346. source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
  347. # apply dataset operations
  348. data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler())
  349. i = 0
  350. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  351. golden = np.array([i])
  352. np.testing.assert_array_equal(item["col0"], golden)
  353. golden = np.array([i + 1])
  354. np.testing.assert_array_equal(item["col1"], golden)
  355. i = i + 1
  356. def test_generator_17():
  357. """
  358. Test multi column generator Mp + Python sampler
  359. """
  360. logger.info("Test multi column generator")
  361. sampler = [x for x in range(256)]
  362. source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
  363. # apply dataset operations
  364. data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler)
  365. i = 0
  366. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  367. golden = np.array([i])
  368. np.testing.assert_array_equal(item["col0"], golden)
  369. golden = np.array([i + 1])
  370. np.testing.assert_array_equal(item["col1"], golden)
  371. i = i + 1
  372. def test_generator_error_1():
  373. def generator_np():
  374. for i in range(64):
  375. yield (np.array([{i}]),)
  376. with pytest.raises(RuntimeError) as info:
  377. data1 = ds.GeneratorDataset(generator_np, ["data"])
  378. for _ in data1:
  379. pass
  380. assert "Invalid data type" in str(info.value)
  381. def test_generator_error_2():
  382. def generator_np():
  383. for i in range(64):
  384. yield ({i},)
  385. with pytest.raises(RuntimeError) as info:
  386. data1 = ds.GeneratorDataset(generator_np, ["data"])
  387. for _ in data1:
  388. pass
  389. print("========", str(info.value))
  390. assert "Generator should return a tuple of numpy arrays" in str(info.value)
  391. def test_generator_error_3():
  392. with pytest.raises(ValueError) as info:
  393. # apply dataset operations
  394. data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
  395. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns=["label"], output_columns=["out1", "out2"],
  396. num_parallel_workers=2)
  397. for _ in data1:
  398. pass
  399. assert "When length of input_columns and output_columns are not equal, column_order must be specified." in \
  400. str(info.value)
  401. def test_generator_error_4():
  402. with pytest.raises(RuntimeError) as info:
  403. # apply dataset operations
  404. data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
  405. data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns=["label"],
  406. num_parallel_workers=2)
  407. for _ in data1:
  408. pass
  409. assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)
  410. def test_generator_sequential_sampler():
  411. source = [(np.array([x]),) for x in range(64)]
  412. ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler())
  413. i = 0
  414. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  415. golden = np.array([i])
  416. np.testing.assert_array_equal(data["data"], golden)
  417. i = i + 1
  418. def test_generator_random_sampler():
  419. source = [(np.array([x]),) for x in range(64)]
  420. ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
  421. for _ in ds1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  422. pass
  423. def test_generator_distributed_sampler():
  424. source = [(np.array([x]),) for x in range(64)]
  425. for sid in range(8):
  426. ds1 = ds.GeneratorDataset(source, ["data"], shuffle=False, num_shards=8, shard_id=sid)
  427. i = sid
  428. for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  429. golden = np.array([i])
  430. np.testing.assert_array_equal(data["data"], golden)
  431. i = i + 8
  432. def test_generator_num_samples():
  433. source = [(np.array([x]),) for x in range(64)]
  434. num_samples = 32
  435. ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples))
  436. ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples)
  437. ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
  438. count = 0
  439. for _ in ds1.create_dict_iterator(num_epochs=1):
  440. count = count + 1
  441. assert count == num_samples
  442. count = 0
  443. for _ in ds2.create_dict_iterator(num_epochs=1):
  444. count = count + 1
  445. assert count == num_samples
  446. count = 0
  447. for _ in ds3.create_dict_iterator(num_epochs=1):
  448. count = count + 1
  449. assert count == num_samples
  450. def test_generator_num_samples_underflow():
  451. source = [(np.array([x]),) for x in range(64)]
  452. num_samples = 256
  453. ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples=num_samples)
  454. ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
  455. count = 0
  456. for _ in ds2.create_dict_iterator(num_epochs=1):
  457. count = count + 1
  458. assert count == 64
  459. count = 0
  460. for _ in ds3.create_dict_iterator(num_epochs=1):
  461. count = count + 1
  462. assert count == 64
  463. def type_tester_with_type_check_2c_schema(t, c):
  464. logger.info("Test with Type {}".format(t.__name__))
  465. schema = ds.Schema()
  466. schema.add_column("data0", c[0])
  467. schema.add_column("data1", c[1])
  468. # apply dataset operations
  469. data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), schema=schema)
  470. data1 = data1.batch(4)
  471. i = 0
  472. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  473. golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
  474. np.testing.assert_array_equal(item["data0"], golden)
  475. i = i + 4
  476. def test_generator_schema():
  477. """
  478. Test 2 column Generator on different data type with type check with schema input
  479. """
  480. logger.info("Test 2 column Generator on all data types with type check")
  481. np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
  482. np.float64]
  483. de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
  484. mstype.uint64, mstype.float32, mstype.float64]
  485. for i, _ in enumerate(np_types):
  486. type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
  487. def test_generator_dataset_size_0():
  488. """
  489. Test GeneratorDataset get_dataset_size by iterator method.
  490. """
  491. logger.info("Test 1D Generator : 0 - 63 get_dataset_size")
  492. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  493. data_size = data1.get_dataset_size()
  494. num_rows = 0
  495. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  496. num_rows = num_rows + 1
  497. assert data_size == num_rows
  498. def test_generator_dataset_size_1():
  499. """
  500. Test GeneratorDataset get_dataset_size by __len__ method.
  501. """
  502. logger.info("Test DatasetGenerator get_dataset_size")
  503. dataset_generator = DatasetGenerator()
  504. data1 = ds.GeneratorDataset(dataset_generator, ["data"])
  505. data_size = data1.get_dataset_size()
  506. num_rows = 0
  507. for _ in data1.create_dict_iterator(num_epochs=1):
  508. num_rows = num_rows + 1
  509. assert data_size == num_rows
  510. def test_generator_dataset_size_2():
  511. """
  512. Test GeneratorDataset + repeat get_dataset_size
  513. """
  514. logger.info("Test 1D Generator + repeat get_dataset_size")
  515. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  516. data1 = data1.repeat(2)
  517. data_size = data1.get_dataset_size()
  518. num_rows = 0
  519. for _ in data1.create_dict_iterator(num_epochs=1):
  520. num_rows = num_rows + 1
  521. assert data_size == num_rows
  522. def test_generator_dataset_size_3():
  523. """
  524. Test GeneratorDataset + batch get_dataset_size
  525. """
  526. logger.info("Test 1D Generator + batch get_dataset_size")
  527. data1 = ds.GeneratorDataset(generator_1d, ["data"])
  528. data1 = data1.batch(4)
  529. data_size = data1.get_dataset_size()
  530. num_rows = 0
  531. for _ in data1.create_dict_iterator(num_epochs=1):
  532. num_rows += 1
  533. assert data_size == num_rows
  534. def test_generator_dataset_size_4():
  535. """
  536. Test GeneratorDataset + num_shards
  537. """
  538. logger.info("Test 1D Generator : 0 - 63 + num_shards get_dataset_size")
  539. dataset_generator = DatasetGenerator()
  540. data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)
  541. data_size = data1.get_dataset_size()
  542. num_rows = 0
  543. for _ in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  544. num_rows = num_rows + 1
  545. assert data_size == num_rows
  546. def test_generator_dataset_size_5():
  547. """
  548. Test get_dataset_size after create_dict_iterator
  549. """
  550. logger.info("Test get_dataset_size after create_dict_iterator")
  551. dataset_generator = DatasetGenerator()
  552. data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)
  553. num_rows = 0
  554. for _ in data1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  555. num_rows = num_rows + 1
  556. data_size = data1.get_dataset_size()
  557. assert data_size == num_rows
  558. def manual_test_generator_keyboard_interrupt():
  559. """
  560. Test keyboard_interrupt
  561. """
  562. logger.info("Test 1D Generator MP : 0 - 63")
  563. class MyDS():
  564. def __getitem__(self, item):
  565. while True:
  566. pass
  567. def __len__(self):
  568. return 1024
  569. ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2)
  570. for _ in ds1.create_dict_iterator(num_epochs=1): # each data is a dictionary
  571. pass
  572. def test_explicit_deepcopy():
  573. """
  574. Test explicit_deepcopy
  575. """
  576. logger.info("Test explicit_deepcopy")
  577. ds1 = ds.NumpySlicesDataset([1, 2], shuffle=False)
  578. ds2 = copy.deepcopy(ds1)
  579. for d1, d2 in zip(ds1, ds2):
  580. assert d1 == d2
  581. if __name__ == "__main__":
  582. test_generator_0()
  583. test_generator_1()
  584. test_generator_2()
  585. test_generator_3()
  586. test_generator_4()
  587. test_generator_5()
  588. test_generator_6()
  589. test_generator_7()
  590. test_generator_8()
  591. test_generator_9()
  592. test_generator_10()
  593. test_generator_11()
  594. test_generator_12()
  595. test_generator_13()
  596. test_generator_14()
  597. test_generator_15()
  598. test_generator_16()
  599. test_generator_17()
  600. test_generator_error_1()
  601. test_generator_error_2()
  602. test_generator_error_3()
  603. test_generator_error_4()
  604. test_generator_sequential_sampler()
  605. test_generator_distributed_sampler()
  606. test_generator_random_sampler()
  607. test_generator_num_samples()
  608. test_generator_num_samples_underflow()
  609. test_generator_schema()
  610. test_generator_dataset_size_0()
  611. test_generator_dataset_size_1()
  612. test_generator_dataset_size_2()
  613. test_generator_dataset_size_3()
  614. test_generator_dataset_size_4()
  615. test_generator_dataset_size_5()
  616. test_explicit_deepcopy()