You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_flickr.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import matplotlib.pyplot as plt
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.vision.c_transforms as c_vision
  19. from mindspore import log as logger
# Flickr30k test fixtures. These are relative paths, so they resolve against the
# working directory the tests are launched from, not against this file's location.
FLICKR30K_DATASET_DIR = "../data/dataset/testFlickrData/flickr30k/flickr30k-images"
# Annotation file for the train, annotation-check, repeat and exception cases.
FLICKR30K_ANNOTATION_FILE_1 = "../data/dataset/testFlickrData/flickr30k/test1.token"
# Annotation file for the num_samples and batching cases.
FLICKR30K_ANNOTATION_FILE_2 = "../data/dataset/testFlickrData/flickr30k/test2.token"
  23. def visualize_dataset(images, labels):
  24. """
  25. Helper function to visualize the dataset samples
  26. """
  27. plt.figure(figsize=(10, 10))
  28. for i, item in enumerate(zip(images, labels), start=1):
  29. plt.imshow(item[0])
  30. plt.title('\n'.join([s.decode('utf-8') for s in item[1]]))
  31. plt.savefig('./flickr_' + str(i) + '.jpg')
  32. def test_flickr30k_dataset_train(plot=False):
  33. data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
  34. count = 0
  35. images_list = []
  36. annotation_list = []
  37. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  38. logger.info("item[image] is {}".format(item["image"]))
  39. images_list.append(item['image'])
  40. annotation_list.append(item['annotation'])
  41. count = count + 1
  42. assert count == 2
  43. if plot:
  44. visualize_dataset(images_list, annotation_list)
  45. def test_flickr30k_dataset_annotation_check():
  46. data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True, shuffle=False)
  47. count = 0
  48. expect_annotation_arr = [
  49. np.array([
  50. r'This is \*a banana.',
  51. 'This is a yellow banana.',
  52. 'This is a banana on the table.',
  53. 'The banana is yellow.',
  54. 'The banana is very big.',
  55. ]),
  56. np.array([
  57. 'This is a pen.',
  58. 'This is a red and black pen.',
  59. 'This is a pen on the table.',
  60. 'The color of the pen is red and black.',
  61. 'The pen has two colors.',
  62. ])
  63. ]
  64. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  65. annotation = [s.decode("utf8") for s in item["annotation"]]
  66. np.testing.assert_array_equal(annotation, expect_annotation_arr[count])
  67. logger.info("item[image] is {}".format(item["image"]))
  68. count = count + 1
  69. assert count == 2
  70. def test_flickr30k_dataset_basic():
  71. # case 1: test num_samples
  72. data1 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, num_samples=2, decode=True)
  73. num_iter1 = 0
  74. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  75. num_iter1 += 1
  76. assert num_iter1 == 2
  77. # case 2: test repeat
  78. data2 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
  79. data2 = data2.repeat(5)
  80. num_iter2 = 0
  81. for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
  82. num_iter2 += 1
  83. assert num_iter2 == 10
  84. # case 3: test batch with drop_remainder=False
  85. data3 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, decode=True, shuffle=False)
  86. resize_op = c_vision.Resize((100, 100))
  87. data3 = data3.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)
  88. assert data3.get_dataset_size() == 3
  89. assert data3.get_batch_size() == 1
  90. data3 = data3.batch(batch_size=2) # drop_remainder is default to be False
  91. assert data3.get_dataset_size() == 2
  92. assert data3.get_batch_size() == 2
  93. num_iter3 = 0
  94. for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
  95. num_iter3 += 1
  96. assert num_iter3 == 2
  97. # case 4: test batch with drop_remainder=True
  98. data4 = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, decode=True, shuffle=False)
  99. resize_op = c_vision.Resize((100, 100))
  100. data4 = data4.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)
  101. assert data4.get_dataset_size() == 3
  102. assert data4.get_batch_size() == 1
  103. data4 = data4.batch(batch_size=2, drop_remainder=True) # the rest of incomplete batch will be dropped
  104. assert data4.get_dataset_size() == 1
  105. assert data4.get_batch_size() == 2
  106. num_iter4 = 0
  107. for _ in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
  108. num_iter4 += 1
  109. assert num_iter4 == 1
  110. def test_flickr30k_dataset_exception():
  111. def exception_func(item):
  112. raise Exception("Error occur!")
  113. try:
  114. data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
  115. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  116. num_rows = 0
  117. for _ in data.create_dict_iterator():
  118. num_rows += 1
  119. assert False
  120. except RuntimeError as e:
  121. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  122. try:
  123. data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
  124. data = data.map(operations=exception_func, input_columns=["annotation"], num_parallel_workers=1)
  125. num_rows = 0
  126. for _ in data.create_dict_iterator():
  127. num_rows += 1
  128. assert False
  129. except RuntimeError as e:
  130. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  131. if __name__ == "__main__":
  132. test_flickr30k_dataset_train(False)
  133. test_flickr30k_dataset_annotation_check()
  134. test_flickr30k_dataset_basic()
  135. test_flickr30k_dataset_exception()