test_datasets_generator.py

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy

import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore import log as logger
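

# This file tests GeneratorDataset with different source types (generator functions,
# random-accessible objects, iterable objects), samplers, column types and ordering,
# map operations, multiprocessing, error handling, and dataset size computation.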


# Generate 1d int numpy array from 0 - 63
def generator_1d():
    for i in range(64):
        yield (np.array([i]),)


# Random-access source with __len__, returning one small column per row
class DatasetGenerator:
    def __init__(self):
        pass

    def __getitem__(self, item):
        return (np.array([item]),)

    def __len__(self):
        return 10


# Random-access source returning two large (4000-element) columns per row
class DatasetGeneratorLarge:
    def __init__(self):
        self.data = np.array(range(4000))

    def __getitem__(self, item):
        return (self.data + item, self.data * 10)

    def __len__(self):
        return 10


def test_generator_0():
    """
    Test 1D Generator
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1


# Generate md int numpy array from [[0, 1], [2, 3]] to [[63, 64], [65, 66]]
def generator_md():
    for i in range(64):
        yield (np.array([[i, i + 1], [i + 2, i + 3]]),)


def test_generator_1():
    """
    Test MD Generator
    """
    logger.info("Test MD Generator : 0 - 63, with shape [2, 2]")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_md, ["data"])

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1


# Generate two columns, the first column is from Generator1D, the second column is from GeneratorMD
def generator_mc(maxid=64):
    for i in range(maxid):
        yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))


def test_generator_2():
    """
    Test multi column generator
    """
    logger.info("Test multi column generator")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc, ["col0", "col1"])

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(item["col0"], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item["col1"], golden)
        i = i + 1


def test_generator_3():
    """
    Test 1D Generator + repeat(4)
    """
    logger.info("Test 1D Generator : 0 - 63 + Repeat(4)")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(4)

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
        if i == 64:
            i = 0


def test_generator_4():
    """
    Test fixed size 1D Generator + batch
    """
    logger.info("Test 1D Generator : 0 - 63 + batch(4)")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.batch(4)

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 4


def generator_with_type(t):
    for i in range(64):
        yield (np.array([i], dtype=t),)


def type_tester(t):
    logger.info("Test with Type {}".format(t.__name__))

    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"])
    data1 = data1.batch(4)

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 4


def test_generator_5():
    """
    Test 1D Generator on different data types
    """
    logger.info("Test 1D Generator on all data types")

    types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32, np.float64]

    for t in types:
        type_tester(t)


def type_tester_with_type_check(t, c):
    logger.info("Test with Type {}".format(t.__name__))

    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], column_types=[c])
    data1 = data1.batch(4)

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 4


def test_generator_6():
    """
    Test 1D Generator on different data types with type check
    """
    logger.info("Test 1D Generator on all data types with type check")

    np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
                np.float64]
    de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
                mstype.uint64, mstype.float32, mstype.float64]

    for i, _ in enumerate(np_types):
        type_tester_with_type_check(np_types[i], de_types[i])


def generator_with_type_2c(t):
    for i in range(64):
        yield (np.array([i], dtype=t), np.array([i], dtype=t))


def type_tester_with_type_check_2c(t, c):
    logger.info("Test with Type {}".format(t.__name__))

    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), ["data0", "data1"], column_types=c)
    data1 = data1.batch(4)

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        np.testing.assert_array_equal(item["data0"], golden)
        i = i + 4


def test_generator_7():
    """
    Test 2 column Generator on different data types with type check
    """
    logger.info("Test 2 column Generator on all data types with type check")

    np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
                np.float64]
    de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
                mstype.uint64, mstype.float32, mstype.float64]

    for i, _ in enumerate(np_types):
        type_tester_with_type_check_2c(np_types[i], [None, de_types[i]])


def test_generator_8():
    """
    Test multi column generator with a few map operations
    """
    logger.info("Test multi column generator with mapops to check the order too")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(operations=(lambda x: x * 3), input_columns="col0", output_columns="out0",
                      num_parallel_workers=2)
    data1 = data1.map(operations=(lambda x: (x * 7, x)), input_columns="col1", output_columns=["out1", "out2"],
                      num_parallel_workers=2, column_order=["out0", "out1", "out2"])
    data1 = data1.map(operations=(lambda x: x + 1), input_columns="out2", output_columns="out2",
                      num_parallel_workers=2)

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i * 3])
        np.testing.assert_array_equal(item["out0"], golden)
        golden = np.array([[i * 7, (i + 1) * 7], [(i + 2) * 7, (i + 3) * 7]])
        np.testing.assert_array_equal(item["out1"], golden)
        golden = np.array([[i + 1, i + 2], [i + 3, i + 4]])
        np.testing.assert_array_equal(item["out2"], golden)
        i = i + 1


def test_generator_9():
    """
    Test map column order when len(input_columns) == len(output_columns).
    """
    logger.info("Test map column order when len(input_columns) == len(output_columns).")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["image", "label"])
    data2 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
    data1 = data1.map(operations=(lambda x: x * 3), input_columns="label",
                      num_parallel_workers=4)
    data2 = data2.map(operations=(lambda x: x * 3), input_columns="label",
                      num_parallel_workers=4)

    # Expected column order is not changed.
    # data1: data[0] is "image" and data[1] is "label"
    # data2: data[0] is "label" and data[1] is "image"
    i = 0
    for data1, data2 in zip(data1, data2):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(data1[0].asnumpy(), golden)
        golden = np.array([[i * 3, (i + 1) * 3], [(i + 2) * 3, (i + 3) * 3]])
        np.testing.assert_array_equal(data1[1].asnumpy(), golden)

        golden = np.array([i * 3])
        np.testing.assert_array_equal(data2[0].asnumpy(), golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(data2[1].asnumpy(), golden)
        i = i + 1


def test_generator_10():
    """
    Test map column order when len(input_columns) != len(output_columns).
    """
    logger.info("Test map column order when len(input_columns) != len(output_columns).")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns="col1", output_columns=["out1", "out2"],
                      column_order=['col0', 'out1', 'out2'], num_parallel_workers=2)

    # Expected column order is |col0|out1|out2|
    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        golden = np.array([i])
        np.testing.assert_array_equal(item[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item[1], golden)
        golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]])
        np.testing.assert_array_equal(item[2], golden)
        i = i + 1


def test_generator_11():
    """
    Test map column order when len(input_columns) != len(output_columns),
    and column_order drops some columns.
    """
    logger.info("Test map column order when len(input_columns) != len(output_columns), "
                "and column_order drops some columns.")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns="col1", output_columns=["out1", "out2"],
                      column_order=['out1', 'out2'], num_parallel_workers=2)

    # Expected column order is |out1|out2|
    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        # len should be 2 because col0 is dropped (not included in column_order)
        assert len(item) == 2
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item[0], golden)
        golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]])
        np.testing.assert_array_equal(item[1], golden)
        i = i + 1


def test_generator_12():
    """
    Test map column order when input_columns and output_columns are None.
    """
    logger.info("Test map column order when input_columns and output_columns are None.")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(operations=(lambda x: (x * 5)), num_parallel_workers=2)

    # Expected column order is |col0|col1|
    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item[1], golden)
        i = i + 1

    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(operations=(lambda x: (x * 5)), column_order=["col1", "col0"], num_parallel_workers=2)

    # Expected column order is |col1|col0|
    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item[1], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1


def test_generator_13():
    """
    Test map column order when input_columns is None.
    """
    logger.info("Test map column order when input_columns is None.")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(operations=(lambda x: (x * 5)), output_columns=["out0"], num_parallel_workers=2)

    # Expected column order is |out0|col1|
    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item[1], golden)
        i = i + 1

    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # len should be 2 because col0 is replaced by out0 and col1 is kept
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item["out0"], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item["col1"], golden)
        i = i + 1


def test_generator_14():
    """
    Test 1D Generator MP + CPP sampler
    """
    logger.info("Test 1D Generator MP : 0 - 63")
    source = [(np.array([x]),) for x in range(256)]
    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_parallel_workers=4).repeat(2)
    i = 0
    for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(data["data"], golden)
        i = i + 1
        if i == 256:
            i = 0


def test_generator_15():
    """
    Test 1D Generator MP + Python sampler
    """
    logger.info("Test 1D Generator MP : 0 - 63")
    sampler = [x for x in range(256)]
    source = [(np.array([x]),) for x in range(256)]
    ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler, num_parallel_workers=4).repeat(2)
    i = 0
    for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(data["data"], golden)
        i = i + 1
        if i == 256:
            i = 0


def test_generator_16():
    """
    Test multi column generator MP + CPP sampler
    """
    logger.info("Test multi column generator")
    source = [(np.array([x]), np.array([x + 1])) for x in range(256)]

    # apply dataset operations
    data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler())

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(item["col0"], golden)
        golden = np.array([i + 1])
        np.testing.assert_array_equal(item["col1"], golden)
        i = i + 1


def test_generator_17():
    """
    Test multi column generator MP + Python sampler
    """
    logger.info("Test multi column generator")
    sampler = [x for x in range(256)]
    source = [(np.array([x]), np.array([x + 1])) for x in range(256)]

    # apply dataset operations
    data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler)

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(item["col0"], golden)
        golden = np.array([i + 1])
        np.testing.assert_array_equal(item["col1"], golden)
        i = i + 1


def test_generator_18():
    """
    Test multiprocessing flag (same as test 13 with python_multiprocessing=True flag)
    """
    logger.info("Test map column order when input_columns is None.")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"], python_multiprocessing=True)
    data1 = data1.map(operations=(lambda x: (x * 5)), output_columns=["out0"], num_parallel_workers=2,
                      python_multiprocessing=True)

    # Expected column order is |out0|col1|
    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item[1], golden)
        i = i + 1

    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # len should be 2 because col0 is replaced by out0 and col1 is kept
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item["out0"], golden)


def test_generator_19():
    """
    Test multiprocessing flag with 2 different large columns
    """
    logger.info("Test multiprocessing flag with 2 different large columns.")

    # apply dataset operations
    data1 = ds.GeneratorDataset(DatasetGeneratorLarge(), ["col0", "col1"], python_multiprocessing=True, shuffle=False)

    # Expected column order is |col0|col1|
    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 2
        golden = np.array(range(4000)) + i
        np.testing.assert_array_equal(item[0], golden)
        golden = np.array(range(4000)) * 10
        np.testing.assert_array_equal(item[1], golden)
        i = i + 1
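

# Helper datasets for test_generator_20: random-accessible (mappable) sources with and
# without __len__, and a purely iterable source.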
class RandomAccessDataset:
    def __init__(self):
        self.__data = np.random.sample((5, 1))

    def __getitem__(self, item):
        return self.__data[item]

    def __len__(self):
        return 5


class RandomAccessDatasetWithoutLen:
    def __init__(self):
        self.__data = np.random.sample((5, 1))

    def __getitem__(self, item):
        return self.__data[item]


class IterableDataset:
    def __init__(self):
        self.count = 0
        self.max = 10

    def __iter__(self):
        return self

    def __next__(self):
        if self.count >= self.max:
            raise StopIteration
        self.count += 1
        return (np.array(self.count),)


def test_generator_20():
    """
    Test mappable and unmappable dataset as source for GeneratorDataset.
    """
    logger.info("Test mappable and unmappable dataset as source for GeneratorDataset.")

    # Mappable dataset
    data1 = ds.GeneratorDataset(RandomAccessDataset(), ["col0"])
    dataset_size1 = data1.get_dataset_size()
    assert dataset_size1 == 5

    # Mappable dataset without __len__
    data2 = ds.GeneratorDataset(RandomAccessDatasetWithoutLen(), ["col0"])
    try:
        data2.get_dataset_size()
    except RuntimeError as e:
        assert "'__len__' method is required" in str(e)

    # Unmappable dataset
    data3 = ds.GeneratorDataset(IterableDataset(), ["col0"])
    dataset_size3 = data3.get_dataset_size()
    assert dataset_size3 == 10
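

# The following tests check that invalid generator outputs and inconsistent
# map/column configurations raise the expected errors.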
def test_generator_error_1():
    def generator_np():
        for i in range(64):
            yield (np.array([{i}]),)

    with pytest.raises(RuntimeError) as info:
        data1 = ds.GeneratorDataset(generator_np, ["data"])
        for _ in data1:
            pass
    assert "Invalid data type" in str(info.value)


def test_generator_error_2():
    def generator_np():
        for i in range(64):
            yield ({i},)

    with pytest.raises(RuntimeError) as info:
        data1 = ds.GeneratorDataset(generator_np, ["data"])
        for _ in data1:
            pass
    print("========", str(info.value))
    assert "Generator should return a tuple of NumPy arrays" in str(info.value)


def test_generator_error_3():
    with pytest.raises(ValueError) as info:
        # apply dataset operations
        data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
        data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns=["label"], output_columns=["out1", "out2"],
                          num_parallel_workers=2)

        for _ in data1:
            pass
    assert "When length of input_columns and output_columns are not equal, column_order must be specified." in \
           str(info.value)


def test_generator_error_4():
    with pytest.raises(RuntimeError) as info:
        # apply dataset operations
        data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
        data1 = data1.map(operations=(lambda x: (x, x * 5)), input_columns=["label"],
                          num_parallel_workers=2)

        for _ in data1:
            pass
    assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)


def test_generator_sequential_sampler():
    source = [(np.array([x]),) for x in range(64)]
    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler())
    i = 0
    for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(data["data"], golden)
        i = i + 1


def test_generator_random_sampler():
    source = [(np.array([x]),) for x in range(64)]
    ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
    for _ in ds1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        pass


def test_generator_distributed_sampler():
    source = [(np.array([x]),) for x in range(64)]
    for sid in range(8):
        ds1 = ds.GeneratorDataset(source, ["data"], shuffle=False, num_shards=8, shard_id=sid)
        i = sid
        for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
            golden = np.array([i])
            np.testing.assert_array_equal(data["data"], golden)
            i = i + 8


def test_generator_num_samples():
    source = [(np.array([x]),) for x in range(64)]
    num_samples = 32
    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples))
    ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples)
    ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)

    count = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        count = count + 1
    assert count == num_samples

    count = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        count = count + 1
    assert count == num_samples

    count = 0
    for _ in ds3.create_dict_iterator(num_epochs=1):
        count = count + 1
    assert count == num_samples


def test_generator_num_samples_underflow():
    source = [(np.array([x]),) for x in range(64)]
    num_samples = 256
    ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples=num_samples)
    ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)

    count = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        count = count + 1
    assert count == 64

    count = 0
    for _ in ds3.create_dict_iterator(num_epochs=1):
        count = count + 1
    assert count == 64


def type_tester_with_type_check_2c_schema(t, c):
    logger.info("Test with Type {}".format(t.__name__))

    schema = ds.Schema()
    schema.add_column("data0", c[0])
    schema.add_column("data1", c[1])

    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), schema=schema)
    data1 = data1.batch(4)

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        np.testing.assert_array_equal(item["data0"], golden)
        i = i + 4


def test_generator_schema():
    """
    Test 2 column Generator on different data types with type check via schema input
    """
    logger.info("Test 2 column Generator on all data types with type check")

    np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
                np.float64]
    de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
                mstype.uint64, mstype.float32, mstype.float64]

    for i, _ in enumerate(np_types):
        type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])


def test_generator_dataset_size_0():
    """
    Test GeneratorDataset get_dataset_size by iterator method.
    """
    logger.info("Test 1D Generator : 0 - 63 get_dataset_size")

    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data_size = data1.get_dataset_size()

    num_rows = 0
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        num_rows = num_rows + 1
    assert data_size == num_rows


def test_generator_dataset_size_1():
    """
    Test GeneratorDataset get_dataset_size by __len__ method.
    """
    logger.info("Test DatasetGenerator get_dataset_size")

    dataset_generator = DatasetGenerator()
    data1 = ds.GeneratorDataset(dataset_generator, ["data"])

    data_size = data1.get_dataset_size()

    num_rows = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_rows = num_rows + 1
    assert data_size == num_rows


def test_generator_dataset_size_2():
    """
    Test GeneratorDataset + repeat get_dataset_size
    """
    logger.info("Test 1D Generator + repeat get_dataset_size")

    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)

    data_size = data1.get_dataset_size()

    num_rows = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_rows = num_rows + 1
    assert data_size == num_rows


def test_generator_dataset_size_3():
    """
    Test GeneratorDataset + batch get_dataset_size
    """
    logger.info("Test 1D Generator + batch get_dataset_size")

    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.batch(4)

    data_size = data1.get_dataset_size()

    num_rows = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_rows += 1
    assert data_size == num_rows


def test_generator_dataset_size_4():
    """
    Test GeneratorDataset + num_shards
    """
    logger.info("Test 1D Generator : 0 - 63 + num_shards get_dataset_size")

    dataset_generator = DatasetGenerator()
    data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)

    data_size = data1.get_dataset_size()

    num_rows = 0
    for _ in data1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        num_rows = num_rows + 1
    assert data_size == num_rows


def test_generator_dataset_size_5():
    """
    Test get_dataset_size after create_dict_iterator
    """
    logger.info("Test get_dataset_size after create_dict_iterator")

    dataset_generator = DatasetGenerator()
    data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)

    num_rows = 0
    for _ in data1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        num_rows = num_rows + 1
    data_size = data1.get_dataset_size()
    assert data_size == num_rows


def manual_test_generator_keyboard_interrupt():
    """
    Test keyboard_interrupt
    """
    logger.info("Test 1D Generator MP : 0 - 63")

    class MyDS():
        def __getitem__(self, item):
            while True:
                pass

        def __len__(self):
            return 1024

    ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2)
    for _ in ds1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        pass


def test_explicit_deepcopy():
    """
    Test explicit_deepcopy
    """
    logger.info("Test explicit_deepcopy")

    ds1 = ds.NumpySlicesDataset([1, 2], shuffle=False)
    ds2 = copy.deepcopy(ds1)
    for d1, d2 in zip(ds1, ds2):
        assert d1 == d2


if __name__ == "__main__":
    test_generator_0()
    test_generator_1()
    test_generator_2()
    test_generator_3()
    test_generator_4()
    test_generator_5()
    test_generator_6()
    test_generator_7()
    test_generator_8()
    test_generator_9()
    test_generator_10()
    test_generator_11()
    test_generator_12()
    test_generator_13()
    test_generator_14()
    test_generator_15()
    test_generator_16()
    test_generator_17()
    test_generator_18()
    test_generator_19()
    test_generator_error_1()
    test_generator_error_2()
    test_generator_error_3()
    test_generator_error_4()
    test_generator_sequential_sampler()
    test_generator_distributed_sampler()
    test_generator_random_sampler()
    test_generator_num_samples()
    test_generator_num_samples_underflow()
    test_generator_schema()
    test_generator_dataset_size_0()
    test_generator_dataset_size_1()
    test_generator_dataset_size_2()
    test_generator_dataset_size_3()
    test_generator_dataset_size_4()
    test_generator_dataset_size_5()
    test_explicit_deepcopy()