
test_split.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest

import mindspore.dataset as ds
from util import config_get_set_num_parallel_workers

# test5trainimgs.json contains 5 images; each image can be uniquely identified
# by its (un-decoded image size, label) pair via the manifest_map lookup table below.
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
                  "End of file.", "Good luck to everyone."]


def split_with_invalid_inputs(d):
    # empty sizes list
    with pytest.raises(ValueError) as info:
        _, _ = d.split([])
    assert "sizes cannot be empty" in str(info.value)

    # mixed int/float sizes
    with pytest.raises(ValueError) as info:
        _, _ = d.split([5, 0.6])
    assert "sizes should be list of int or list of float" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([-1, 6])
    assert "there should be no negative or zero numbers" in str(info.value)

    # row counts that do not sum to the dataset size
    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([3, 1])
    assert "Sum of split sizes 4 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([5, 1])
    assert "Sum of split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
    assert "Sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)

    # percentages outside the range (0, 1]
    with pytest.raises(ValueError) as info:
        _, _ = d.split([-0.5, 0.5])
    assert "there should be no numbers outside the range (0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([1.5, 0.5])
    assert "there should be no numbers outside the range (0, 1]" in str(info.value)

    # percentages that do not sum to 1
    with pytest.raises(ValueError) as info:
        _, _ = d.split([0.5, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([0.3, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    # a percentage that maps to zero rows
    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([0.05, 0.95])
    assert "percentage 0.05 is too small" in str(info.value)


def test_unmappable_invalid_input():
    d = ds.TextFileDataset(text_file_dataset_path)
    split_with_invalid_inputs(d)

    d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([4, 1])
    assert "Dataset should not be sharded before split" in str(info.value)


def test_unmappable_split():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(item["text"].item().decode("utf8"))
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(item["text"].item().decode("utf8"))
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(item["text"].item().decode("utf8"))
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:2]
    assert s2_output == text_file_data[2:]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_randomize_deterministic():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    # the row order produced by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
    ds.config.set_seed(53)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    for _ in range(10):
        s1_output = []
        for item in s1.create_dict_iterator(num_epochs=1):
            s1_output.append(item["text"].item().decode("utf8"))
        s2_output = []
        for item in s2.create_dict_iterator(num_epochs=1):
            s2_output.append(item["text"].item().decode("utf8"))

        # note: the two splits share no rows
        assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
        assert s2_output == [text_file_data[3]]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_randomize_repeatable():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    # the row order produced by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
    ds.config.set_seed(53)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    num_epochs = 5
    s1 = s1.repeat(num_epochs)
    s2 = s2.repeat(num_epochs)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(item["text"].item().decode("utf8"))
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(item["text"].item().decode("utf8"))

    # note: the two splits share no rows
    assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs
    assert s2_output == [text_file_data[3]] * num_epochs

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_get_dataset_size():
    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    assert d.get_dataset_size() == 5
    assert s1.get_dataset_size() == 4
    assert s2.get_dataset_size() == 1


def test_unmappable_multi_split():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    # the row order produced by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
    ds.config.set_seed(53)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([4, 1])
    s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(item["text"].item().decode("utf8"))
    assert s1_output == s1_correct_output

    # no randomization in the second split
    s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1):
        s1s1_output.append(item["text"].item().decode("utf8"))
    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1):
        s1s2_output.append(item["text"].item().decode("utf8"))
    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1):
        s1s3_output.append(item["text"].item().decode("utf8"))

    assert s1s1_output == [s1_correct_output[0]]
    assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
    assert s1s3_output == [s1_correct_output[3]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(item["text"].item().decode("utf8"))
    assert s2_output == [text_file_data[3]]

    # randomization in the second split;
    # the row order produced by ShuffleOp for seed 53 is [2, 3, 1, 0]
    shuffled_ids = [2, 3, 1, 0]
    s1s1, s1s2, s1s3 = s1.split([1, 2, 1])

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1):
        s1s1_output.append(item["text"].item().decode("utf8"))
    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1):
        s1s2_output.append(item["text"].item().decode("utf8"))
    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1):
        s1s3_output.append(item["text"].item().decode("utf8"))

    assert s1s1_output == [s1_correct_output[shuffled_ids[0]]]
    assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]]
    assert s1s3_output == [s1_correct_output[shuffled_ids[3]]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(item["text"].item().decode("utf8"))
    assert s2_output == [text_file_data[3]]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_mappable_invalid_input():
    d = ds.ManifestDataset(manifest_file)
    split_with_invalid_inputs(d)

    d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([4, 1])
    assert "Dataset should not be sharded before split" in str(info.value)


def test_mappable_split_general():
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    d = d.take(5)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]


def test_mappable_split_optimized():
    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]


def test_mappable_randomize_deterministic():
    # the row order produced by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    for _ in range(10):
        s1_output = []
        for item in s1.create_dict_iterator(num_epochs=1):
            s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
        s2_output = []
        for item in s2.create_dict_iterator(num_epochs=1):
            s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        # note: the two splits share no rows
        assert s1_output == [0, 1, 3, 4]
        assert s2_output == [2]


def test_mappable_randomize_repeatable():
    # the row order produced by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    num_epochs = 5
    s1 = s1.repeat(num_epochs)
    s2 = s2.repeat(num_epochs)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # note: the two splits share no rows
    assert s1_output == [0, 1, 3, 4] * num_epochs
    assert s2_output == [2] * num_epochs


def test_mappable_sharding():
    # set an arbitrary seed for repeatability when sharding after split;
    # the row order produced by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    num_epochs = 5
    first_split_num_rows = 4

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([first_split_num_rows, 1])
    distributed_sampler = ds.DistributedSampler(2, 0)
    s1.use_sampler(distributed_sampler)
    s1 = s1.repeat(num_epochs)

    # test sharding: a second dataset simulates another training instance
    d2 = ds.ManifestDataset(manifest_file, shuffle=False)
    d2s1, d2s2 = d2.split([first_split_num_rows, 1])
    distributed_sampler = ds.DistributedSampler(2, 1)
    d2s1.use_sampler(distributed_sampler)
    d2s1 = d2s1.repeat(num_epochs)

    # shard 0
    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # shard 1
    d2s1_output = []
    for item in d2s1.create_dict_iterator(num_epochs=1):
        d2s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    rows_per_shard_per_epoch = 2
    assert len(s1_output) == rows_per_shard_per_epoch * num_epochs
    assert len(d2s1_output) == rows_per_shard_per_epoch * num_epochs

    # verify, for each epoch, that
    # 1. the shards contain no common elements
    # 2. the data was split the same way, and the union of the shards equals the split
    correct_sorted_split_result = [0, 1, 3, 4]
    for i in range(num_epochs):
        combined_data = []
        for j in range(rows_per_shard_per_epoch):
            combined_data.append(s1_output[i * rows_per_shard_per_epoch + j])
            combined_data.append(d2s1_output[i * rows_per_shard_per_epoch + j])
        assert sorted(combined_data) == correct_sorted_split_result

    # test the other split
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    d2s2_output = []
    for item in d2s2.create_dict_iterator(num_epochs=1):
        d2s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s2_output == [2]
    assert d2s2_output == [2]
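

# A sketch of the shard-after-split pattern exercised by test_mappable_sharding
# above, added for illustration (not part of the original suite): attach a
# DistributedSampler to a split so each training instance consumes one shard.
def _shard_after_split_sketch(num_shards=2, shard_id=0):
    data = ds.ManifestDataset(manifest_file, shuffle=False)
    train, _ = data.split([4, 1])
    train.use_sampler(ds.DistributedSampler(num_shards, shard_id))
    return train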


def test_mappable_get_dataset_size():
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([4, 1])

    assert d.get_dataset_size() == 5
    assert s1.get_dataset_size() == 4
    assert s2.get_dataset_size() == 1


def test_mappable_multi_split():
    # the row order produced by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([4, 1])
    s1_correct_output = [0, 1, 3, 4]

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    assert s1_output == s1_correct_output

    # no randomization in the second split
    s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1):
        s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1):
        s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1):
        s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1s1_output == [s1_correct_output[0]]
    assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
    assert s1s3_output == [s1_correct_output[3]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    assert s2_output == [2]

    # randomization in the second split;
    # the row order produced by RandomSampler for seed 53 is [3, 1, 2, 0]
    random_sampler_ids = [3, 1, 2, 0]
    s1s1, s1s2, s1s3 = s1.split([1, 2, 1])

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1):
        s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1):
        s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1):
        s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]]
    assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]]
    assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    assert s2_output == [2]


def test_rounding():
    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # under-rounding
    s1, s2 = d.split([0.5, 0.5], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2]
    assert s2_output == [3, 4]

    # over-rounding
    s1, s2, s3 = d.split([0.15, 0.55, 0.3], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    s3_output = []
    for item in s3.create_dict_iterator(num_epochs=1):
        s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0]
    assert s2_output == [1, 2]
    assert s3_output == [3, 4]
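

# Note (added comment): percentage splits round fractional row counts while
# keeping the total equal to the dataset size, so over 5 rows [0.5, 0.5]
# yields 3 + 2 rows and [0.15, 0.55, 0.3] yields 1 + 2 + 2, as asserted above.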


if __name__ == '__main__':
    test_unmappable_invalid_input()
    test_unmappable_split()
    test_unmappable_randomize_deterministic()
    test_unmappable_randomize_repeatable()
    test_unmappable_get_dataset_size()
    test_unmappable_multi_split()
    test_mappable_invalid_input()
    test_mappable_split_general()
    test_mappable_split_optimized()
    test_mappable_randomize_deterministic()
    test_mappable_randomize_repeatable()
    test_mappable_sharding()
    test_mappable_get_dataset_size()
    test_mappable_multi_split()
    test_rounding()