# test_datasets_caltech101.py
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test Caltech101 dataset operators
  17. """
  18. import os
  19. import matplotlib.pyplot as plt
  20. import numpy as np
  21. import pytest
  22. from PIL import Image
  23. from scipy.io import loadmat
  24. import mindspore.dataset as ds
  25. import mindspore.dataset.vision.c_transforms as c_vision
  26. from mindspore import log as logger
# Path to the small Caltech101 test fixture (relative to the test working directory).
DATASET_DIR = "../data/dataset/testCaltech101Data"
# Deliberately nonexistent path used by the error-case tests.
WRONG_DIR = "../data/dataset/notExist"
  29. def get_index_info():
  30. dataset_dir = os.path.realpath(DATASET_DIR)
  31. image_dir = os.path.join(dataset_dir, "101_ObjectCategories")
  32. classes = sorted(os.listdir(image_dir))
  33. if "BACKGROUND_Google" in classes:
  34. classes.remove("BACKGROUND_Google")
  35. name_map = {"Faces": "Faces_2",
  36. "Faces_easy": "Faces_3",
  37. "Motorbikes": "Motorbikes_16",
  38. "airplanes": "Airplanes_Side_2"}
  39. annotation_classes = [name_map[class_name] if class_name in name_map else class_name for class_name in classes]
  40. image_index = []
  41. image_label = []
  42. for i, c in enumerate(classes):
  43. sub_dir = os.path.join(image_dir, c)
  44. if not os.path.isdir(sub_dir) or not os.access(sub_dir, os.R_OK):
  45. continue
  46. num_images = len(os.listdir(sub_dir))
  47. image_index.extend(range(1, num_images + 1))
  48. image_label.extend(num_images * [i])
  49. return image_index, image_label, classes, annotation_classes
  50. def load_caltech101(target_type="category", decode=False):
  51. """
  52. load Caltech101 data
  53. """
  54. dataset_dir = os.path.realpath(DATASET_DIR)
  55. image_dir = os.path.join(dataset_dir, "101_ObjectCategories")
  56. annotation_dir = os.path.join(dataset_dir, "Annotations")
  57. image_index, image_label, classes, annotation_classes = get_index_info()
  58. images, categories, annotations = [], [], []
  59. num_images = len(image_index)
  60. for i in range(num_images):
  61. image_file = os.path.join(image_dir, classes[image_label[i]], "image_{:04d}.jpg".format(image_index[i]))
  62. if not os.path.exists(image_file):
  63. raise ValueError("The image file {} does not exist or permission denied!".format(image_file))
  64. if decode:
  65. image = np.asarray(Image.open(image_file).convert("RGB"))
  66. else:
  67. image = np.fromfile(image_file, dtype=np.uint8)
  68. images.append(image)
  69. if target_type == "category":
  70. for i in range(num_images):
  71. categories.append(image_label[i])
  72. return images, categories
  73. for i in range(num_images):
  74. annotation_file = os.path.join(annotation_dir, annotation_classes[image_label[i]],
  75. "annotation_{:04d}.mat".format(image_index[i]))
  76. if not os.path.exists(annotation_file):
  77. raise ValueError("The annotation file {} does not exist or permission denied!".format(annotation_file))
  78. annotation = loadmat(annotation_file)["obj_contour"]
  79. annotations.append(annotation)
  80. if target_type == "annotation":
  81. return images, annotations
  82. for i in range(num_images):
  83. categories.append(image_label[i])
  84. return images, categories, annotations
  85. def visualize_dataset(images, labels):
  86. """
  87. Helper function to visualize the dataset samples
  88. """
  89. num_samples = len(images)
  90. for i in range(num_samples):
  91. plt.subplot(1, num_samples, i + 1)
  92. plt.imshow(images[i].squeeze())
  93. plt.title(labels[i])
  94. plt.show()
  95. def test_caltech101_content_check():
  96. """
  97. Feature: Caltech101Dataset
  98. Description: check if the image data of caltech101 dataset is read correctly
  99. Expectation: the data is processed successfully
  100. """
  101. logger.info("Test Caltech101Dataset Op with content check")
  102. all_data = ds.Caltech101Dataset(DATASET_DIR, target_type="annotation", num_samples=4, shuffle=False, decode=True)
  103. images, annotations = load_caltech101(target_type="annotation", decode=True)
  104. num_iter = 0
  105. for i, data in enumerate(all_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  106. np.testing.assert_array_equal(data["image"], images[i])
  107. np.testing.assert_array_equal(data["annotation"], annotations[i])
  108. num_iter += 1
  109. assert num_iter == 4
  110. all_data = ds.Caltech101Dataset(DATASET_DIR, target_type="all", num_samples=4, shuffle=False, decode=True)
  111. images, categories, annotations = load_caltech101(target_type="all", decode=True)
  112. num_iter = 0
  113. for i, data in enumerate(all_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  114. np.testing.assert_array_equal(data["image"], images[i])
  115. np.testing.assert_array_equal(data["category"], categories[i])
  116. np.testing.assert_array_equal(data["annotation"], annotations[i])
  117. num_iter += 1
  118. assert num_iter == 4
  119. def test_caltech101_basic():
  120. """
  121. Feature: Caltech101Dataset
  122. Description: basic test of Caltech101Dataset
  123. Expectation: the data is processed successfully
  124. """
  125. logger.info("Test Caltech101Dataset Op")
  126. # case 1: test target_type
  127. all_data_1 = ds.Caltech101Dataset(DATASET_DIR, shuffle=False)
  128. all_data_2 = ds.Caltech101Dataset(DATASET_DIR, shuffle=False)
  129. num_iter = 0
  130. for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
  131. all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  132. np.testing.assert_array_equal(item1["category"], item2["category"])
  133. num_iter += 1
  134. assert num_iter == 4
  135. # case 2: test decode
  136. all_data_1 = ds.Caltech101Dataset(DATASET_DIR, decode=True, shuffle=False)
  137. all_data_2 = ds.Caltech101Dataset(DATASET_DIR, decode=True, shuffle=False)
  138. num_iter = 0
  139. for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
  140. all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  141. np.testing.assert_array_equal(item1["image"], item2["image"])
  142. num_iter += 1
  143. assert num_iter == 4
  144. # case 3: test num_samples
  145. all_data = ds.Caltech101Dataset(DATASET_DIR, num_samples=4)
  146. num_iter = 0
  147. for _ in all_data.create_dict_iterator(num_epochs=1):
  148. num_iter += 1
  149. assert num_iter == 4
  150. # case 4: test repeat
  151. all_data = ds.Caltech101Dataset(DATASET_DIR, num_samples=4)
  152. all_data = all_data.repeat(2)
  153. num_iter = 0
  154. for _ in all_data.create_dict_iterator(num_epochs=1):
  155. num_iter += 1
  156. assert num_iter == 8
  157. # case 5: test get_dataset_size, resize and batch
  158. all_data = ds.Caltech101Dataset(DATASET_DIR, num_samples=4)
  159. all_data = all_data.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224))], input_columns=["image"],
  160. num_parallel_workers=1)
  161. assert all_data.get_dataset_size() == 4
  162. assert all_data.get_batch_size() == 1
  163. # drop_remainder is default to be False
  164. all_data = all_data.batch(batch_size=3)
  165. assert all_data.get_batch_size() == 3
  166. assert all_data.get_dataset_size() == 2
  167. num_iter = 0
  168. for _ in all_data.create_dict_iterator(num_epochs=1):
  169. num_iter += 1
  170. assert num_iter == 2
  171. # case 6: test get_class_indexing
  172. all_data = ds.Caltech101Dataset(DATASET_DIR, num_samples=4)
  173. class_indexing = all_data.get_class_indexing()
  174. assert class_indexing["Faces"] == 0
  175. assert class_indexing["yin_yang"] == 100
  176. def test_caltech101_target_type():
  177. """
  178. Feature: Caltech101Dataset
  179. Description: test Caltech101Dataset with target_type
  180. Expectation: the data is processed successfully
  181. """
  182. logger.info("Test Caltech101Dataset Op with target_type")
  183. all_data_1 = ds.Caltech101Dataset(DATASET_DIR, target_type="annotation", shuffle=False)
  184. all_data_2 = ds.Caltech101Dataset(DATASET_DIR, target_type="annotation", shuffle=False)
  185. num_iter = 0
  186. for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
  187. all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  188. np.testing.assert_array_equal(item1["annotation"], item2["annotation"])
  189. num_iter += 1
  190. assert num_iter == 4
  191. all_data_1 = ds.Caltech101Dataset(DATASET_DIR, target_type="all", shuffle=False)
  192. all_data_2 = ds.Caltech101Dataset(DATASET_DIR, target_type="all", shuffle=False)
  193. num_iter = 0
  194. for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
  195. all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  196. np.testing.assert_array_equal(item1["category"], item2["category"])
  197. np.testing.assert_array_equal(item1["annotation"], item2["annotation"])
  198. num_iter += 1
  199. assert num_iter == 4
  200. all_data_1 = ds.Caltech101Dataset(DATASET_DIR, target_type="category", shuffle=False)
  201. all_data_2 = ds.Caltech101Dataset(DATASET_DIR, target_type="category", shuffle=False)
  202. num_iter = 0
  203. for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
  204. all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  205. np.testing.assert_array_equal(item1["category"], item2["category"])
  206. num_iter += 1
  207. assert num_iter == 4
  208. def test_caltech101_sequential_sampler():
  209. """
  210. Feature: Caltech101Dataset
  211. Description: test Caltech101Dataset with SequentialSampler
  212. Expectation: the data is processed successfully
  213. """
  214. logger.info("Test Caltech101Dataset Op with SequentialSampler")
  215. num_samples = 4
  216. sampler = ds.SequentialSampler(num_samples=num_samples)
  217. all_data_1 = ds.Caltech101Dataset(DATASET_DIR, sampler=sampler)
  218. all_data_2 = ds.Caltech101Dataset(DATASET_DIR, shuffle=False, num_samples=num_samples)
  219. label_list_1, label_list_2 = [], []
  220. num_iter = 0
  221. for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1),
  222. all_data_2.create_dict_iterator(num_epochs=1)):
  223. label_list_1.append(item1["category"].asnumpy())
  224. label_list_2.append(item2["category"].asnumpy())
  225. num_iter += 1
  226. np.testing.assert_array_equal(label_list_1, label_list_2)
  227. assert num_iter == num_samples
  228. def test_caltech101_exception():
  229. """
  230. Feature: Caltech101Dataset
  231. Description: test error cases for Caltech101Dataset
  232. Expectation: throw correct error and message
  233. """
  234. logger.info("Test error cases for Caltech101Dataset")
  235. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  236. with pytest.raises(RuntimeError, match=error_msg_1):
  237. ds.Caltech101Dataset(DATASET_DIR, shuffle=False, sampler=ds.SequentialSampler(1))
  238. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  239. with pytest.raises(RuntimeError, match=error_msg_2):
  240. ds.Caltech101Dataset(DATASET_DIR, sampler=ds.SequentialSampler(1), num_shards=2, shard_id=0)
  241. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  242. with pytest.raises(RuntimeError, match=error_msg_3):
  243. ds.Caltech101Dataset(DATASET_DIR, num_shards=10)
  244. error_msg_4 = "shard_id is specified but num_shards is not"
  245. with pytest.raises(RuntimeError, match=error_msg_4):
  246. ds.Caltech101Dataset(DATASET_DIR, shard_id=0)
  247. error_msg_5 = "Input shard_id is not within the required interval"
  248. with pytest.raises(ValueError, match=error_msg_5):
  249. ds.Caltech101Dataset(DATASET_DIR, num_shards=5, shard_id=-1)
  250. with pytest.raises(ValueError, match=error_msg_5):
  251. ds.Caltech101Dataset(DATASET_DIR, num_shards=5, shard_id=5)
  252. with pytest.raises(ValueError, match=error_msg_5):
  253. ds.Caltech101Dataset(DATASET_DIR, num_shards=2, shard_id=5)
  254. error_msg_6 = "num_parallel_workers exceeds"
  255. with pytest.raises(ValueError, match=error_msg_6):
  256. ds.Caltech101Dataset(DATASET_DIR, shuffle=False, num_parallel_workers=0)
  257. with pytest.raises(ValueError, match=error_msg_6):
  258. ds.Caltech101Dataset(DATASET_DIR, shuffle=False, num_parallel_workers=256)
  259. with pytest.raises(ValueError, match=error_msg_6):
  260. ds.Caltech101Dataset(DATASET_DIR, shuffle=False, num_parallel_workers=-2)
  261. error_msg_7 = "Argument shard_id"
  262. with pytest.raises(TypeError, match=error_msg_7):
  263. ds.Caltech101Dataset(DATASET_DIR, num_shards=2, shard_id="0")
  264. error_msg_8 = "does not exist or is not a directory or permission denied!"
  265. with pytest.raises(ValueError, match=error_msg_8):
  266. all_data = ds.Caltech101Dataset(WRONG_DIR, WRONG_DIR)
  267. for _ in all_data.create_dict_iterator(num_epochs=1):
  268. pass
  269. error_msg_9 = "Input target_type is not within the valid set of \\['category', 'annotation', 'all'\\]."
  270. with pytest.raises(ValueError, match=error_msg_9):
  271. all_data = ds.Caltech101Dataset(DATASET_DIR, target_type="cate")
  272. for _ in all_data.create_dict_iterator(num_epochs=1):
  273. pass
  274. def test_caltech101_visualize(plot=False):
  275. """
  276. Feature: Caltech101Dataset
  277. Description: visualize Caltech101Dataset results
  278. Expectation: the data is processed successfully
  279. """
  280. logger.info("Test Caltech101Dataset visualization")
  281. all_data = ds.Caltech101Dataset(DATASET_DIR, num_samples=4, decode=True, shuffle=False)
  282. num_iter = 0
  283. image_list, category_list = [], []
  284. for item in all_data.create_dict_iterator(num_epochs=1, output_numpy=True):
  285. image = item["image"]
  286. category = item["category"]
  287. image_list.append(image)
  288. category_list.append("label {}".format(category))
  289. assert isinstance(image, np.ndarray)
  290. assert len(image.shape) == 3
  291. assert image.shape[-1] == 3
  292. assert image.dtype == np.uint8
  293. assert category.dtype == np.int64
  294. num_iter += 1
  295. assert num_iter == 4
  296. if plot:
  297. visualize_dataset(image_list, category_list)
  298. if __name__ == '__main__':
  299. test_caltech101_content_check()
  300. test_caltech101_basic()
  301. test_caltech101_target_type()
  302. test_caltech101_sequential_sampler()
  303. test_caltech101_exception()
  304. test_caltech101_visualize(plot=True)