You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_datasets_wider_face.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import numpy as np
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.vision.c_transforms as vision
  19. import mindspore.log as logger
  20. DATA_DIR = "../data/dataset/testWIDERFace/"
  def test_wider_face_basic():
      """
      Feature: WIDERFace dataset
      Description: Iterate the dataset with default arguments, with num_samples, and with repeat
      Expectation: The number of rows produced matches the expected count in each case
      """
      logger.info("Test WIDERFaceDataset Op")

      def count_rows(dataset):
          # Exhaust one epoch of the dict iterator and report how many rows came out.
          return sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))

      # case 1: default arguments
      default_dataset = ds.WIDERFaceDataset(DATA_DIR)
      assert count_rows(default_dataset) == 4

      # case 2: num_samples caps the number of rows
      limited_dataset = ds.WIDERFaceDataset(DATA_DIR, num_samples=1)
      assert count_rows(limited_dataset) == 1

      # case 3: 2 samples repeated 5 times
      repeated_dataset = ds.WIDERFaceDataset(DATA_DIR, num_samples=2).repeat(5)
      assert count_rows(repeated_dataset) == 10
  def test_wider_face_noshuffle():
      """
      Feature: WIDERFace dataset
      Description: Iterate the dataset with shuffle disabled
      Expectation: Four rows are produced
      """
      logger.info("Test Case noShuffle")

      repeats = 1

      # Note carried over from the original author: "all" reads both the "train"
      # dataset (2 samples) and the "valid" dataset (2 samples).
      pipeline = ds.WIDERFaceDataset(DATA_DIR, shuffle=False).repeat(repeats)

      # Every yielded row is a dictionary; only the number of rows matters here.
      row_total = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True))
      assert row_total == 4
  def test_wider_face_usage():
      """
      Feature: WIDERFace dataset
      Description: Load the dataset with each usage value, plus one unsupported value
      Expectation: Row counts match per usage; the unsupported value yields the validation message
      """
      logger.info("Test WIDERFaceDataset usage flag")

      def rows_for_usage(usage, wider_face_path=DATA_DIR):
          # Returns the row count on success, or the error text if loading/iterating fails.
          try:
              dataset = ds.WIDERFaceDataset(wider_face_path, usage=usage)
              total = 0
              for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
                  total += 1
          except (ValueError, TypeError, RuntimeError) as err:
              return str(err)
          return total

      # Expected row counts in the bundled test data, checked in this order.
      for usage, expected in (("test", 3), ("train", 2), ("valid", 2), ("all", 4)):
          assert rows_for_usage(usage) == expected

      invalid_usage_msg = "usage is not within the valid set of ['train', 'test', 'valid', 'all']"
      assert invalid_usage_msg in rows_for_usage("invalid")

      # Point this at a folder containing all WIDERFace files to enable the checks below.
      all_wider_face = None
      if all_wider_face is not None:
          full_sizes = (("test", 16097), ("valid", 3226), ("train", 12880), ("all", 16106))
          for usage, expected in full_sizes:
              assert rows_for_usage(usage, all_wider_face) == expected
          for usage, expected in full_sizes:
              assert ds.WIDERFaceDataset(all_wider_face, usage=usage).get_dataset_size() == expected
  def test_wider_face_sequential_sampler():
      """
      Feature: WIDERFace dataset
      Description: Compare a SequentialSampler pipeline with an unshuffled num_samples pipeline
      Expectation: Both pipelines yield the same images and exactly num_samples rows
      """
      num_samples = 1
      sequential = ds.SequentialSampler(num_samples=num_samples)
      sampled_data = ds.WIDERFaceDataset(DATA_DIR, 'test', sampler=sequential)
      unshuffled_data = ds.WIDERFaceDataset(DATA_DIR, 'test', shuffle=False, num_samples=num_samples)

      sampled_images = []
      unshuffled_images = []
      # Walk both pipelines in lockstep and collect the "image" column of each row.
      row_pairs = zip(sampled_data.create_dict_iterator(num_epochs=1),
                      unshuffled_data.create_dict_iterator(num_epochs=1))
      for sampled_row, unshuffled_row in row_pairs:
          sampled_images.append(sampled_row["image"].asnumpy())
          unshuffled_images.append(unshuffled_row["image"].asnumpy())

      np.testing.assert_array_equal(sampled_images, unshuffled_images)
      # One entry is appended per row pair, so the list length is the iteration count.
      assert len(sampled_images) == num_samples
  def test_wider_face_pipeline():
      """
      Feature: Pipeline test
      Description: Read one decoded "valid" sample and resize its image to 100x100
      Expectation: Exactly one row is produced by the pipeline
      """
      pipeline = ds.WIDERFaceDataset(DATA_DIR, "valid", num_samples=1, decode=True).map(
          input_columns=["image"], operations=vision.Resize((100, 100)))

      produced = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True))
      assert produced == 1
  def test_wider_face_exception():
      """
      Feature: WIDERFace dataset
      Description: Construct the dataset with conflicting or out-of-range arguments, and run map
          operations whose Python function always raises
      Expectation: Each case raises the expected exception type carrying the expected message
      """
      logger.info("Test error cases for WIDERFaceDataset")

      # --- argument validation at construction time ---
      error_msg_1 = "sampler and shuffle cannot be specified at the same time"
      with pytest.raises(RuntimeError, match=error_msg_1):
          ds.WIDERFaceDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))

      error_msg_2 = "sampler and sharding cannot be specified at the same time"
      with pytest.raises(RuntimeError, match=error_msg_2):
          ds.WIDERFaceDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

      error_msg_3 = "num_shards is specified and currently requires shard_id as well"
      with pytest.raises(RuntimeError, match=error_msg_3):
          ds.WIDERFaceDataset(DATA_DIR, num_shards=10)

      error_msg_4 = "shard_id is specified but num_shards is not"
      with pytest.raises(RuntimeError, match=error_msg_4):
          ds.WIDERFaceDataset(DATA_DIR, shard_id=0)

      error_msg_5 = "Input shard_id is not within the required interval"
      for num_shards, shard_id in ((5, -1), (5, 5), (2, 5)):
          with pytest.raises(ValueError, match=error_msg_5):
              ds.WIDERFaceDataset(DATA_DIR, num_shards=num_shards, shard_id=shard_id)

      error_msg_6 = "num_parallel_workers exceeds"
      for workers in (0, 256, -2):
          with pytest.raises(ValueError, match=error_msg_6):
              ds.WIDERFaceDataset(DATA_DIR, shuffle=False, num_parallel_workers=workers)

      error_msg_7 = "Argument shard_id"
      with pytest.raises(TypeError, match=error_msg_7):
          ds.WIDERFaceDataset(DATA_DIR, num_shards=2, shard_id="0")

      # --- failures raised from inside a map operation ---
      def exception_func(item):
          raise Exception("Error occur!")

      map_error_msg = "map operation: [PyFunc] failed. The corresponding data files"

      # (extra dataset kwargs, column the failing function is mapped over), in the
      # same order the cases were originally written.
      map_failure_cases = (
          ({"usage": "test"}, "image"),
          ({"usage": "all"}, "image"),
          ({}, "bbox"),
          ({}, "blur"),
          ({}, "expression"),
          ({}, "illumination"),
          ({}, "occlusion"),
          ({}, "pose"),
          ({}, "invalid"),
      )
      for dataset_kwargs, column in map_failure_cases:
          # pytest.raises fails the test if nothing is raised, even under `python -O`
          # where a bare `assert False` guard would be stripped.
          with pytest.raises(RuntimeError) as exc_info:
              data = ds.WIDERFaceDataset(DATA_DIR, shuffle=False, **dataset_kwargs)
              data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
              for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                  pass
          # Plain substring check rather than `match=`: the message contains regex
          # metacharacters ("[", "]", ".") that `match=` would interpret as a pattern.
          assert map_error_msg in str(exc_info.value)
  if __name__ == '__main__':
      # Run every test in this module directly, in a fixed order, without the pytest runner.
      for run_test in (
              test_wider_face_basic,
              test_wider_face_sequential_sampler,
              test_wider_face_noshuffle,
              test_wider_face_usage,
              test_wider_face_pipeline,
              test_wider_face_exception,
      ):
          run_test()
  249. test_wider_face_exception()