You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_datasets_manifestop.py 6.2 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore.dataset as ds
  17. import mindspore.dataset.vision.c_transforms as vision
  18. import mindspore.dataset.transforms.c_transforms as data_trans
  19. from mindspore import log as logger
  20. DATA_FILE = "../data/dataset/testManifestData/test.manifest"
  21. def test_manifest_dataset_train():
  22. data = ds.ManifestDataset(DATA_FILE, decode=True)
  23. count = 0
  24. cat_count = 0
  25. dog_count = 0
  26. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  27. logger.info("item[image] is {}".format(item["image"]))
  28. count = count + 1
  29. if item["label"].size == 1 and item["label"] == 0:
  30. cat_count = cat_count + 1
  31. elif item["label"].size == 1 and item["label"] == 1:
  32. dog_count = dog_count + 1
  33. assert cat_count == 2
  34. assert dog_count == 1
  35. assert count == 4
  36. def test_manifest_dataset_eval():
  37. data = ds.ManifestDataset(DATA_FILE, "eval", decode=True)
  38. count = 0
  39. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  40. logger.info("item[image] is {}".format(item["image"]))
  41. count = count + 1
  42. if item["label"] != 0 and item["label"] != 1:
  43. assert 0
  44. assert count == 2
  45. def test_manifest_dataset_class_index():
  46. class_indexing = {"dog": 11}
  47. data = ds.ManifestDataset(DATA_FILE, decode=True, class_indexing=class_indexing)
  48. out_class_indexing = data.get_class_indexing()
  49. assert out_class_indexing == {"dog": 11}
  50. count = 0
  51. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  52. logger.info("item[image] is {}".format(item["image"]))
  53. count = count + 1
  54. if item["label"] != 11:
  55. assert 0
  56. assert count == 1
  57. def test_manifest_dataset_get_class_index():
  58. data = ds.ManifestDataset(DATA_FILE, decode=True)
  59. class_indexing = data.get_class_indexing()
  60. assert class_indexing == {'cat': 0, 'dog': 1, 'flower': 2}
  61. data = data.shuffle(4)
  62. class_indexing = data.get_class_indexing()
  63. assert class_indexing == {'cat': 0, 'dog': 1, 'flower': 2}
  64. count = 0
  65. for item in data.create_dict_iterator(num_epochs=1):
  66. logger.info("item[image] is {}".format(item["image"]))
  67. count = count + 1
  68. assert count == 4
  69. def test_manifest_dataset_multi_label():
  70. data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
  71. count = 0
  72. expect_label = [1, 0, 0, [0, 2]]
  73. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  74. assert item["label"].tolist() == expect_label[count]
  75. logger.info("item[image] is {}".format(item["image"]))
  76. count = count + 1
  77. assert count == 4
  78. def multi_label_hot(x):
  79. result = np.zeros(x.size // x.ndim, dtype=int)
  80. if x.ndim > 1:
  81. for i in range(x.ndim):
  82. result = np.add(result, x[i])
  83. else:
  84. result = np.add(result, x)
  85. return result
  86. def test_manifest_dataset_multi_label_onehot():
  87. data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
  88. expect_label = [[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [1, 0, 1]]]
  89. one_hot_encode = data_trans.OneHot(3)
  90. data = data.map(operations=one_hot_encode, input_columns=["label"])
  91. data = data.map(operations=multi_label_hot, input_columns=["label"])
  92. data = data.batch(2)
  93. count = 0
  94. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  95. assert item["label"].tolist() == expect_label[count]
  96. logger.info("item[image] is {}".format(item["image"]))
  97. count = count + 1
  98. def test_manifest_dataset_get_num_class():
  99. data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
  100. assert data.num_classes() == 3
  101. padded_samples = [{'image': np.zeros(1, np.uint8), 'label': np.array(1, np.int32)}]
  102. padded_ds = ds.PaddedDataset(padded_samples)
  103. data = data.repeat(2)
  104. padded_ds = padded_ds.repeat(2)
  105. data1 = data + padded_ds
  106. assert data1.num_classes() == 3
  107. def test_manifest_dataset_exception():
  108. def exception_func(item):
  109. raise Exception("Error occur!")
  110. try:
  111. data = ds.ManifestDataset(DATA_FILE)
  112. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  113. for _ in data.__iter__():
  114. pass
  115. assert False
  116. except RuntimeError as e:
  117. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  118. try:
  119. data = ds.ManifestDataset(DATA_FILE)
  120. data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
  121. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  122. for _ in data.__iter__():
  123. pass
  124. assert False
  125. except RuntimeError as e:
  126. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  127. try:
  128. data = ds.ManifestDataset(DATA_FILE)
  129. data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
  130. for _ in data.__iter__():
  131. pass
  132. assert False
  133. except RuntimeError as e:
  134. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  135. if __name__ == '__main__':
  136. test_manifest_dataset_train()
  137. test_manifest_dataset_eval()
  138. test_manifest_dataset_class_index()
  139. test_manifest_dataset_get_class_index()
  140. test_manifest_dataset_multi_label()
  141. test_manifest_dataset_multi_label_onehot()
  142. test_manifest_dataset_get_num_class()
  143. test_manifest_dataset_exception()