
test_split.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest

import mindspore.dataset as ds
from util import config_get_set_num_parallel_workers
# test5trainimgs.json contains 5 images. Each image is uniquely identified by the pair
# (un-decoded image size in bytes, label); manifest_map maps that pair to the image's
# row index in the dataset.
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
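
# Note on the lookup pattern used throughout the mappable tests below: no decode op is
# applied, so the "image" column holds the raw encoded bytes as a 1-D tensor, and
# item["image"].shape[0] is the encoded size in bytes. A minimal sketch of the pattern,
# assuming `item` is one row produced by create_dict_iterator():
#
#     key = (item["image"].shape[0], item["label"].item())
#     row_index = manifest_map[key]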

def split_with_invalid_inputs(d):
    """Helper: verify that split() rejects malformed size lists on dataset d."""
    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([])
    assert "sizes cannot be empty" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([5, 0.6])
    assert "sizes should be list of int or list of float" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([-1, 6])
    assert "there should be no negative numbers" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([3, 1])
    assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([5, 1])
    assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
    assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([-0.5, 0.5])
    assert "there should be no numbers outside the range [0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([1.5, 0.5])
    assert "there should be no numbers outside the range [0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([0.5, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([0.3, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([0.05, 0.95])
    assert "percentage 0.05 is too small" in str(info.value)

def test_unmappable_invalid_input():
    """split() on a non-mappable (TextFile) dataset rejects bad sizes and sharded inputs."""
    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
    d = ds.TextFileDataset(text_file_dataset_path)
    split_with_invalid_inputs(d)

    d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([4, 1])
    assert "dataset should not be sharded before split" in str(info.value)

def test_unmappable_split():
    """Non-randomized split of a non-mappable dataset, by absolute rows and by percentages."""
    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
    text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
                      "End of file.", "Good luck to everyone."]
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # fuzzy percentages (see the note after this function)
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:2]
    assert s2_output == text_file_data[2:]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
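
# How the "fuzzy" percentages above resolve for a 5-row dataset (a worked example that
# matches the asserts, not a statement of a guaranteed rounding rule):
# 0.33 * 5 = 1.65 and 0.67 * 5 = 3.35, which settle to 2 and 3 rows respectively so
# that the split sizes sum to the dataset size.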

def test_mappable_invalid_input():
    """split() on a mappable (Manifest) dataset rejects bad sizes and sharded inputs."""
    d = ds.ManifestDataset(manifest_file)
    split_with_invalid_inputs(d)

    d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([4, 1])
    assert "dataset should not be sharded before split" in str(info.value)

def test_mappable_split_general():
    """Non-randomized split of a mappable dataset with an op (take) applied before the split."""
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    d = d.take(5)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]

def test_mappable_split_optimized():
    """Non-randomized split applied directly to the leaf mappable dataset (see the note after this function)."""
    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]
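
# The "general" and "optimized" cases above differ only in the take(5) applied before
# the split: the test names suggest that a split on a bare mappable source dataset can
# take an optimized (sampler-based) path, while a split after another op goes through
# the general path. Either way, both tests expect identical rows in each split.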

def test_mappable_randomize_deterministic():
    """Randomized split with a fixed seed yields the same subsets on every iteration."""
    # set an arbitrary seed; for seed 53, the rows chosen for the first split are [0, 1, 3, 4]
    ds.config.set_seed(53)
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    for _ in range(10):
        s1_output = []
        for item in s1.create_dict_iterator():
            s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        s2_output = []
        for item in s2.create_dict_iterator():
            s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        # note no overlap between the splits
        assert s1_output == [0, 1, 3, 4]
        assert s2_output == [2]
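
# With randomize left at its default (True), split() shuffles the rows before dividing
# them, so the fixed seed is what makes the tests above and below reproducible. A
# minimal usage sketch (the 0.8/0.2 sizes are just the ones used in these tests):
#
#     ds.config.set_seed(53)
#     train, val = d.split([0.8, 0.2])   # randomized, but deterministic under the seed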

def test_mappable_randomize_repeatable():
    """Randomized split with a fixed seed yields the same subsets across repeat() epochs."""
    # set an arbitrary seed; for seed 53, the rows chosen for the first split are [0, 1, 3, 4]
    ds.config.set_seed(53)
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    num_epochs = 5
    s1 = s1.repeat(num_epochs)
    s2 = s2.repeat(num_epochs)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # note no overlap between the splits, in any epoch
    assert s1_output == [0, 1, 3, 4] * num_epochs
    assert s2_output == [2] * num_epochs

def test_mappable_sharding():
    """Sharding applied after a randomized split keeps shards disjoint and consistent."""
    # set an arbitrary seed for repeatability; for seed 53, the rows chosen for the
    # first split are [0, 1, 3, 4]
    ds.config.set_seed(53)
    num_epochs = 5
    first_split_num_rows = 4

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([first_split_num_rows, 1])
    distributed_sampler = ds.DistributedSampler(2, 0)
    s1.use_sampler(distributed_sampler)
    s1 = s1.repeat(num_epochs)

    # testing sharding: a second dataset simulates another training instance
    d2 = ds.ManifestDataset(manifest_file, shuffle=False)
    d2s1, d2s2 = d2.split([first_split_num_rows, 1])
    distributed_sampler = ds.DistributedSampler(2, 1)
    d2s1.use_sampler(distributed_sampler)
    d2s1 = d2s1.repeat(num_epochs)

    # shard 0
    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # shard 1
    d2s1_output = []
    for item in d2s1.create_dict_iterator():
        d2s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    rows_per_shard_per_epoch = 2
    assert len(s1_output) == rows_per_shard_per_epoch * num_epochs
    assert len(d2s1_output) == rows_per_shard_per_epoch * num_epochs

    # verify for each epoch that
    # 1. the shards contain no common elements
    # 2. the data was split the same way, and the union of the shards equals the split
    correct_sorted_split_result = [0, 1, 3, 4]
    for i in range(num_epochs):
        combined_data = []
        for j in range(rows_per_shard_per_epoch):
            combined_data.append(s1_output[i * rows_per_shard_per_epoch + j])
            combined_data.append(d2s1_output[i * rows_per_shard_per_epoch + j])
        assert sorted(combined_data) == correct_sorted_split_result

    # test the other split
    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    d2s2_output = []
    for item in d2s2.create_dict_iterator():
        d2s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s2_output == [2]
    assert d2s2_output == [2]
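
# The split-then-shard pattern exercised above, as a minimal sketch (in real use,
# num_shards and shard_id would come from the distributed launch rather than being
# hard-coded per instance as in this test):
#
#     train, val = d.split([4, 1])
#     train.use_sampler(ds.DistributedSampler(num_shards, shard_id))
#     train = train.repeat(num_epochs)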

if __name__ == '__main__':
    test_unmappable_invalid_input()
    test_unmappable_split()
    test_mappable_invalid_input()
    test_mappable_split_general()
    test_mappable_split_optimized()
    test_mappable_randomize_deterministic()
    test_mappable_randomize_repeatable()
    test_mappable_sharding()