
test_generator.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
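"""
Unit tests for ds.GeneratorDataset: 1D, multi-dimensional, and multi-column
generators; repeat and batch; per-column data types with and without type
checks; map column ordering; error handling; and samplers (sequential,
random, distributed, num_samples).
"""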
import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore import log as logger


# Generate 1d int numpy array from 0 - 63
def generator_1d():
    for i in range(64):
        yield (np.array([i]),)
def test_case_0():
    """
    Test 1D Generator
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["data"], golden)
        i = i + 1


# Generate md int numpy array from [[0, 1], [2, 3]] to [[63, 64], [65, 66]]
def generator_md():
    for i in range(64):
        yield (np.array([[i, i + 1], [i + 2, i + 3]]),)


def test_case_1():
    """
    Test MD Generator
    """
    logger.info("Test MD Generator : 0 - 63, with shape [2, 2]")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_md, ["data"])

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item["data"], golden)
        i = i + 1
# Generate two columns, the first column is from generator_1d, the second column is from generator_md
def generator_mc(maxid=64):
    for i in range(maxid):
        yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))


def test_case_2():
    """
    Test multi column generator
    """
    logger.info("Test multi column generator")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc, ["col0", "col1"])

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["col0"], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item["col1"], golden)
        i = i + 1


def test_case_3():
    """
    Test 1D Generator + repeat(4)
    """
    logger.info("Test 1D Generator : 0 - 63 + Repeat(4)")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(4)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["data"], golden)
        i = i + 1
        if i == 64:
            i = 0


def test_case_4():
    """
    Test fixed size 1D Generator + batch(4)
    """
    logger.info("Test 1D Generator : 0 - 63 + batch(4)")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.batch(4)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]])
        assert np.array_equal(item["data"], golden)
        i = i + 4
def generator_with_type(t):
    for i in range(64):
        yield (np.array([i], dtype=t),)


def type_tester(t):
    logger.info("Test with Type {}".format(t.__name__))

    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"])
    data1 = data1.batch(4)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        assert np.array_equal(item["data"], golden)
        i = i + 4


def test_case_5():
    """
    Test 1D Generator on different data types
    """
    logger.info("Test 1D Generator on all data types")

    types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
             np.float64]
    for t in types:
        type_tester(t)


def type_tester_with_type_check(t, c):
    logger.info("Test with Type {}".format(t.__name__))

    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], column_types=[c])
    data1 = data1.batch(4)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        assert np.array_equal(item["data"], golden)
        i = i + 4


def test_case_6():
    """
    Test 1D Generator on different data types with type check
    """
    logger.info("Test 1D Generator on all data types with type check")

    np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
                np.float64]
    de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
                mstype.uint64, mstype.float32, mstype.float64]
    for i in range(len(np_types)):
        type_tester_with_type_check(np_types[i], de_types[i])


def generator_with_type_2c(t):
    for i in range(64):
        yield (np.array([i], dtype=t), np.array([i], dtype=t))


def type_tester_with_type_check_2c(t, c):
    logger.info("Test with Type {}".format(t.__name__))

    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), ["data0", "data1"], column_types=c)
    data1 = data1.batch(4)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        assert np.array_equal(item["data0"], golden)
        i = i + 4


def test_case_7():
    """
    Test 2 column Generator on different data types with type check
    """
    logger.info("Test 2 column Generator on all data types with type check")

    np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
                np.float64]
    de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
                mstype.uint64, mstype.float32, mstype.float64]
    for i in range(len(np_types)):
        type_tester_with_type_check_2c(np_types[i], [None, de_types[i]])
def test_case_8():
    """
    Test multi column generator with a few map ops
    """
    logger.info("Test multi column generator with map ops to check the column order")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(input_columns="col0", output_columns="out0", operations=(lambda x: x * 3),
                      num_parallel_workers=2)
    data1 = data1.map(input_columns="col1", output_columns=["out1", "out2"], operations=(lambda x: (x * 7, x)),
                      num_parallel_workers=2, columns_order=["out0", "out1", "out2"])
    data1 = data1.map(input_columns="out2", output_columns="out2", operations=(lambda x: x + 1),
                      num_parallel_workers=2)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i * 3])
        assert np.array_equal(item["out0"], golden)
        golden = np.array([[i * 7, (i + 1) * 7], [(i + 2) * 7, (i + 3) * 7]])
        assert np.array_equal(item["out1"], golden)
        golden = np.array([[i + 1, i + 2], [i + 3, i + 4]])
        assert np.array_equal(item["out2"], golden)
        i = i + 1
def test_case_9():
    """
    Test map column order when len(input_columns) == len(output_columns).
    """
    logger.info("Test map column order when len(input_columns) == len(output_columns).")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["image", "label"])
    data2 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
    data1 = data1.map(input_columns="label", operations=(lambda x: x * 3),
                      num_parallel_workers=4)
    data2 = data2.map(input_columns="label", operations=(lambda x: x * 3),
                      num_parallel_workers=4)

    # Expected column order is not changed:
    # in data1, item1[0] is "image" and item1[1] is "label"
    # in data2, item2[0] is "label" and item2[1] is "image"
    i = 0
    for item1, item2 in zip(data1, data2):  # each item is a tuple of columns
        golden = np.array([i])
        assert np.array_equal(item1[0], golden)
        golden = np.array([[i * 3, (i + 1) * 3], [(i + 2) * 3, (i + 3) * 3]])
        assert np.array_equal(item1[1], golden)
        golden = np.array([i * 3])
        assert np.array_equal(item2[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item2[1], golden)
        i = i + 1
def test_case_10():
    """
    Test map column order when len(input_columns) != len(output_columns).
    """
    logger.info("Test map column order when len(input_columns) != len(output_columns).")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(input_columns="col1", output_columns=["out1", "out2"], operations=(lambda x: (x, x * 5)),
                      columns_order=['col0', 'out1', 'out2'], num_parallel_workers=2)

    # Expected column order is |col0|out1|out2|
    i = 0
    for item in data1.create_tuple_iterator():
        golden = np.array([i])
        assert np.array_equal(item[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item[1], golden)
        golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]])
        assert np.array_equal(item[2], golden)
        i = i + 1


def test_case_11():
    """
    Test map column order when len(input_columns) != len(output_columns)
    and columns_order drops some columns.
    """
    logger.info("Test map column order when len(input_columns) != len(output_columns), "
                "and columns_order drops some columns.")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(input_columns="col1", output_columns=["out1", "out2"], operations=(lambda x: (x, x * 5)),
                      columns_order=['out1', 'out2'], num_parallel_workers=2)

    # Expected column order is |out1|out2|
    i = 0
    for item in data1.create_tuple_iterator():
        # len should be 2 because col0 is dropped (not included in columns_order)
        assert len(item) == 2
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item[0], golden)
        golden = np.array([[i * 5, (i + 1) * 5], [(i + 2) * 5, (i + 3) * 5]])
        assert np.array_equal(item[1], golden)
        i = i + 1
def test_case_12():
    """
    Test map column order when input_columns and output_columns are None.
    """
    logger.info("Test map column order when input_columns and output_columns are None.")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(operations=(lambda x: x * 5), num_parallel_workers=2)

    # Expected column order is |col0|col1|
    i = 0
    for item in data1.create_tuple_iterator():
        assert len(item) == 2
        golden = np.array([i * 5])
        assert np.array_equal(item[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item[1], golden)
        i = i + 1

    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(operations=(lambda x: x * 5), columns_order=["col1", "col0"], num_parallel_workers=2)

    # Expected column order is |col1|col0|
    i = 0
    for item in data1.create_tuple_iterator():
        assert len(item) == 2
        golden = np.array([i * 5])
        assert np.array_equal(item[1], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item[0], golden)
        i = i + 1
def test_case_13():
    """
    Test map column order when input_columns is None.
    """
    logger.info("Test map column order when input_columns is None.")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"])
    data1 = data1.map(operations=(lambda x: x * 5), output_columns=["out0"], num_parallel_workers=2)

    # Expected column order is |out0|col1|
    i = 0
    for item in data1.create_tuple_iterator():
        assert len(item) == 2
        golden = np.array([i * 5])
        assert np.array_equal(item[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item[1], golden)
        i = i + 1

    i = 0  # reset the golden counter for the second pass
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # len should be 2 because "out0" replaced "col0" and "col1" is unchanged
        assert len(item) == 2
        golden = np.array([i * 5])
        assert np.array_equal(item["out0"], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item["col1"], golden)
        i = i + 1
def test_case_error_1():
    def generator_np():
        for i in range(64):
            yield (np.array([{i}]),)

    with pytest.raises(RuntimeError) as info:
        data1 = ds.GeneratorDataset(generator_np, ["data"])
        for _ in data1:
            pass
    assert "Invalid data type" in str(info.value)


def test_case_error_2():
    def generator_np():
        for i in range(64):
            yield ({i},)

    with pytest.raises(RuntimeError) as info:
        data1 = ds.GeneratorDataset(generator_np, ["data"])
        for _ in data1:
            pass
    assert "Generator should return a tuple of numpy arrays" in str(info.value)


def test_case_error_3():
    with pytest.raises(ValueError) as info:
        # apply dataset operations
        data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
        data1 = data1.map(input_columns=["label"], output_columns=["out1", "out2"], operations=(lambda x: (x, x * 5)),
                          num_parallel_workers=2)
        for _ in data1:
            pass
    assert "When (len(input_columns) != len(output_columns)), columns_order must be specified." in str(info.value)


def test_case_error_4():
    with pytest.raises(RuntimeError) as info:
        # apply dataset operations
        data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
        data1 = data1.map(input_columns=["label"], operations=(lambda x: (x, x * 5)),
                          num_parallel_workers=2)
        for _ in data1:
            pass
    assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)
def test_sequential_sampler():
    source = [(np.array([x]),) for x in range(64)]
    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler())
    i = 0
    for data in ds1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(data["data"], golden)
        i = i + 1


def test_random_sampler():
    source = [(np.array([x]),) for x in range(64)]
    ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
    for _ in ds1.create_dict_iterator():  # each data is a dictionary
        pass


def test_distributed_sampler():
    source = [(np.array([x]),) for x in range(64)]
    for sid in range(8):
        ds1 = ds.GeneratorDataset(source, ["data"], shuffle=False, num_shards=8, shard_id=sid)
        i = sid
        for data in ds1.create_dict_iterator():  # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(data["data"], golden)
            i = i + 8
def test_num_samples():
    source = [(np.array([x]),) for x in range(64)]
    num_samples = 32
    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples=num_samples)
    ds2 = ds.GeneratorDataset(source, ["data"], sampler=list(range(32)), num_samples=num_samples)
    ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)

    count = 0
    for _ in ds1.create_dict_iterator():
        count = count + 1
    assert count == num_samples

    count = 0
    for _ in ds2.create_dict_iterator():
        count = count + 1
    assert count == num_samples

    count = 0
    for _ in ds3.create_dict_iterator():
        count = count + 1
    assert count == num_samples


def test_num_samples_underflow():
    source = [(np.array([x]),) for x in range(64)]
    num_samples = 256
    ds2 = ds.GeneratorDataset(source, ["data"], sampler=list(range(64)), num_samples=num_samples)
    ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)

    count = 0
    for _ in ds2.create_dict_iterator():
        count = count + 1
    assert count == 64

    count = 0
    for _ in ds3.create_dict_iterator():
        count = count + 1
    assert count == 64
if __name__ == "__main__":
    test_case_0()
    test_case_1()
    test_case_2()
    test_case_3()
    test_case_4()
    test_case_5()
    test_case_6()
    test_case_7()
    test_case_8()
    test_case_9()
    test_case_10()
    test_case_11()
    test_case_12()
    test_case_13()
    test_case_error_1()
    test_case_error_2()
    test_case_error_3()
    test_case_error_4()
    test_sequential_sampler()
    test_distributed_sampler()
    test_random_sampler()
    test_num_samples()
    test_num_samples_underflow()