You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_kitti.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import re
  16. import pytest
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.vision.c_transforms as vision
  19. from mindspore import log as logger
  20. DATA_DIR = "../data/dataset/testKITTI"
  21. IMAGE_SHAPE = [2268, 642, 2268]
  22. def test_func_kitti_dataset_basic():
  23. """
  24. Feature: KITTI
  25. Description: test basic function of KITTI with default parament
  26. Expectation: the dataset is as expected
  27. """
  28. repeat_count = 2
  29. # apply dataset operations.
  30. data = ds.KITTIDataset(DATA_DIR, shuffle=False)
  31. data = data.repeat(repeat_count)
  32. num_iter = 0
  33. count = [0, 0, 0, 0, 0, 0, 0, 0]
  34. SHAPE = [159109, 176455, 54214, 159109, 176455, 54214]
  35. ANNOTATIONSHAPE = [6, 3, 7, 6, 3, 7]
  36. # each data is a dictionary.
  37. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  38. # in this example, each dictionary has keys "image", "label", "truncated", "occluded", "alpha", "bbox",
  39. # "dimensions", "location", "rotation_y".
  40. assert item["image"].shape[0] == SHAPE[num_iter]
  41. for label in item["label"]:
  42. count[label[0]] += 1
  43. assert item["truncated"].shape[0] == ANNOTATIONSHAPE[num_iter]
  44. assert item["occluded"].shape[0] == ANNOTATIONSHAPE[num_iter]
  45. assert item["alpha"].shape[0] == ANNOTATIONSHAPE[num_iter]
  46. assert item["bbox"].shape[0] == ANNOTATIONSHAPE[num_iter]
  47. assert item["dimensions"].shape[0] == ANNOTATIONSHAPE[num_iter]
  48. assert item["location"].shape[0] == ANNOTATIONSHAPE[num_iter]
  49. assert item["rotation_y"].shape[0] == ANNOTATIONSHAPE[num_iter]
  50. num_iter += 1
  51. logger.info("Number of data in data1: {}".format(num_iter))
  52. assert num_iter == 6
  53. assert count == [8, 20, 2, 2, 0, 0, 0, 0]
  54. def test_kitti_usage_train():
  55. """
  56. Feature: KITTI
  57. Description: test basic usage "train" of KITTI
  58. Expectation: the dataset is as expected
  59. """
  60. data1 = ds.KITTIDataset(DATA_DIR, usage="train")
  61. num = 0
  62. count = [0, 0, 0, 0, 0, 0, 0, 0]
  63. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  64. for label in item["label"]:
  65. count[label[0]] += 1
  66. num += 1
  67. assert num == 3
  68. assert count == [4, 10, 1, 1, 0, 0, 0, 0]
  69. def test_kitti_usage_test():
  70. """
  71. Feature: KITTI
  72. Description: test basic usage "test" of KITTI
  73. Expectation: the dataset is as expected
  74. """
  75. data1 = ds.KITTIDataset(
  76. DATA_DIR, usage="test", shuffle=False, decode=True, num_samples=3)
  77. num = 0
  78. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  79. assert item["image"].shape[0] == IMAGE_SHAPE[num]
  80. num += 1
  81. assert num == 3
  82. def test_kitti_case():
  83. """
  84. Feature: KITTI
  85. Description: test basic usage of KITTI
  86. Expectation: the dataset is as expected
  87. """
  88. data1 = ds.KITTIDataset(DATA_DIR,
  89. usage="train", decode=True, num_samples=3)
  90. resize_op = vision.Resize((224, 224))
  91. data1 = data1.map(operations=resize_op, input_columns=["image"])
  92. repeat_num = 4
  93. data1 = data1.repeat(repeat_num)
  94. batch_size = 2
  95. data1 = data1.batch(batch_size, drop_remainder=True, pad_info={})
  96. num = 0
  97. for _ in data1.create_dict_iterator(num_epochs=1):
  98. num += 1
  99. assert num == 6
  100. def test_func_kitti_dataset_numsamples_num_parallel_workers():
  101. """
  102. Feature: KITTI
  103. Description: test numsamples and num_parallel_workers of KITTI
  104. Expectation: the dataset is as expected
  105. """
  106. # define parameters.
  107. repeat_count = 2
  108. # apply dataset operations.
  109. data1 = ds.KITTIDataset(DATA_DIR, num_samples=2, num_parallel_workers=2)
  110. data1 = data1.repeat(repeat_count)
  111. num_iter = 0
  112. # each data is a dictionary.
  113. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  114. num_iter += 1
  115. logger.info("Number of data in data1: {}".format(num_iter))
  116. assert num_iter == 4
  117. random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
  118. data1 = ds.KITTIDataset(DATA_DIR, num_parallel_workers=2,
  119. sampler=random_sampler)
  120. num_iter = 0
  121. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  122. num_iter += 1
  123. assert num_iter == 3
  124. random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
  125. data1 = ds.KITTIDataset(DATA_DIR, num_parallel_workers=2,
  126. sampler=random_sampler)
  127. num_iter = 0
  128. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  129. num_iter += 1
  130. assert num_iter == 3
  131. def test_func_kitti_dataset_extrashuffle():
  132. """
  133. Feature: KITTI
  134. Description: test extrashuffle of KITTI
  135. Expectation: the dataset is as expected
  136. """
  137. # define parameters.
  138. repeat_count = 2
  139. # apply dataset operations.
  140. data1 = ds.KITTIDataset(DATA_DIR, shuffle=True)
  141. data1 = data1.shuffle(buffer_size=3)
  142. data1 = data1.repeat(repeat_count)
  143. num_iter = 0
  144. # each data is a dictionary.
  145. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  146. num_iter += 1
  147. logger.info("Number of data in data1: {}".format(num_iter))
  148. assert num_iter == 6
  149. def test_func_kitti_dataset_no_para():
  150. """
  151. Feature: KITTI
  152. Description: test no para of KITTI
  153. Expectation: throw exception correctly
  154. """
  155. with pytest.raises(TypeError, match="missing a required argument: 'dataset_dir'"):
  156. dataset = ds.KITTIDataset()
  157. num_iter = 0
  158. for data in dataset.create_dict_iterator(output_numpy=True):
  159. assert "image" in str(data.keys())
  160. num_iter += 1
  161. def test_func_kitti_dataset_distributed_sampler():
  162. """
  163. Feature: KITTI
  164. Description: test DistributedSampler of KITTI
  165. Expectation: throw exception correctly
  166. """
  167. # define parameters.
  168. repeat_count = 2
  169. # apply dataset operations.
  170. sampler = ds.DistributedSampler(3, 1)
  171. data1 = ds.KITTIDataset(DATA_DIR, sampler=sampler)
  172. data1 = data1.repeat(repeat_count)
  173. num_iter = 0
  174. # each data is a dictionary.
  175. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  176. num_iter += 1
  177. logger.info("Number of data in data1: {}".format(num_iter))
  178. assert num_iter == 2
  179. def test_func_kitti_dataset_decode():
  180. """
  181. Feature: KITTI
  182. Description: test decode of KITTI
  183. Expectation: throw exception correctly
  184. """
  185. # define parameters.
  186. repeat_count = 2
  187. # apply dataset operations.
  188. data1 = ds.KITTIDataset(DATA_DIR, decode=True)
  189. data1 = data1.repeat(repeat_count)
  190. num_iter = 0
  191. # each data is a dictionary.
  192. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  193. # in this example, each dictionary has keys "image" and "label".
  194. num_iter += 1
  195. logger.info("Number of data in data1: {}".format(num_iter))
  196. assert num_iter == 6
  197. def test_kitti_numshards():
  198. """
  199. Feature: KITTI
  200. Description: test numShards of KITTI
  201. Expectation: throw exception correctly
  202. """
  203. # define parameters.
  204. repeat_count = 2
  205. # apply dataset operations.
  206. data1 = ds.KITTIDataset(DATA_DIR, num_shards=3, shard_id=2)
  207. data1 = data1.repeat(repeat_count)
  208. num_iter = 0
  209. # each data is a dictionary.
  210. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  211. num_iter += 1
  212. logger.info("Number of data in data1: {}".format(num_iter))
  213. assert num_iter == 2
  214. def test_func_kitti_dataset_more_para():
  215. """
  216. Feature: KITTI
  217. Description: test more para of KITTI
  218. Expectation: throw exception correctly
  219. """
  220. with pytest.raises(TypeError, match="got an unexpected keyword argument 'more_para'"):
  221. dataset = ds.KITTIDataset(DATA_DIR, usage="train", num_samples=6, num_parallel_workers=None,
  222. shuffle=True, sampler=None, decode=True, num_shards=3,
  223. shard_id=2, cache=None, more_para=None)
  224. num_iter = 0
  225. for data in dataset.create_dict_iterator(output_numpy=True):
  226. num_iter += 1
  227. assert "image" in str(data.keys())
  228. def test_kitti_exception():
  229. """
  230. Feature: KITTI
  231. Description: test error cases of KITTI
  232. Expectation: throw exception correctly
  233. """
  234. logger.info("Test error cases for KITTIDataset")
  235. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  236. with pytest.raises(RuntimeError, match=error_msg_1):
  237. ds.KITTIDataset(DATA_DIR, shuffle=False, decode=True, sampler=ds.SequentialSampler(1))
  238. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  239. with pytest.raises(RuntimeError, match=error_msg_2):
  240. ds.KITTIDataset(DATA_DIR, sampler=ds.SequentialSampler(1), decode=True, num_shards=2, shard_id=0)
  241. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  242. with pytest.raises(RuntimeError, match=error_msg_3):
  243. ds.KITTIDataset(DATA_DIR, decode=True, num_shards=10)
  244. error_msg_4 = "shard_id is specified but num_shards is not"
  245. with pytest.raises(RuntimeError, match=error_msg_4):
  246. ds.KITTIDataset(DATA_DIR, decode=True, shard_id=0)
  247. error_msg_5 = "Input shard_id is not within the required interval"
  248. with pytest.raises(ValueError, match=error_msg_5):
  249. ds.KITTIDataset(DATA_DIR, decode=True, num_shards=5, shard_id=-1)
  250. with pytest.raises(ValueError, match=error_msg_5):
  251. ds.KITTIDataset(DATA_DIR, decode=True, num_shards=5, shard_id=5)
  252. error_msg_6 = "num_parallel_workers exceeds"
  253. with pytest.raises(ValueError, match=error_msg_6):
  254. ds.KITTIDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=0)
  255. with pytest.raises(ValueError, match=error_msg_6):
  256. ds.KITTIDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=256)
  257. error_msg_7 = "Argument shard_id"
  258. with pytest.raises(TypeError, match=error_msg_7):
  259. ds.KITTIDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0")
  260. error_msg_8 = "does not exist or is not a directory or permission denied!"
  261. with pytest.raises(ValueError, match=error_msg_8):
  262. all_data = ds.KITTIDataset("../data/dataset/testKITTI2", decode=True)
  263. for _ in all_data.create_dict_iterator(num_epochs=1):
  264. pass
  265. error_msg_9 = "Input usage is not within the valid set of ['train', 'test']."
  266. with pytest.raises(ValueError, match=re.escape(error_msg_9)):
  267. all_data = ds.KITTIDataset(DATA_DIR, usage="all")
  268. for _ in all_data.create_dict_iterator(num_epochs=1):
  269. pass
  270. error_msg_10 = "Argument decode with value 123 is not of type [<class 'bool'>], but got <class 'int'>."
  271. with pytest.raises(TypeError, match=re.escape(error_msg_10)):
  272. all_data = ds.KITTIDataset(DATA_DIR, decode=123)
  273. for _ in all_data.create_dict_iterator(num_epochs=1):
  274. pass
  275. if __name__ == '__main__':
  276. test_func_kitti_dataset_basic()
  277. test_kitti_usage_train()
  278. test_kitti_usage_test()
  279. test_kitti_case()
  280. test_func_kitti_dataset_numsamples_num_parallel_workers()
  281. test_func_kitti_dataset_extrashuffle()
  282. test_func_kitti_dataset_no_para()
  283. test_func_kitti_dataset_distributed_sampler()
  284. test_func_kitti_dataset_decode()
  285. test_kitti_numshards()
  286. test_func_kitti_dataset_more_para()
  287. test_kitti_exception()