You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_manifestop.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import mindspore.dataset.transforms.vision.c_transforms as vision
  16. import mindspore.dataset.transforms.c_transforms as data_trans
  17. import numpy as np
  18. import mindspore.dataset as ds
  19. from mindspore import log as logger
# Manifest fixture shared by every test in this module.
# NOTE(review): this is a relative path, so it resolves against the current
# working directory; presumably the tests must be launched from this file's
# directory -- confirm against the test runner configuration.
DATA_FILE = "../data/dataset/testManifestData/test.manifest"
  21. def test_manifest_dataset_train():
  22. data = ds.ManifestDataset(DATA_FILE, decode=True)
  23. count = 0
  24. cat_count = 0
  25. dog_count = 0
  26. for item in data.create_dict_iterator():
  27. logger.info("item[image] is {}".format(item["image"]))
  28. count = count + 1
  29. if item["label"].size == 1 and item["label"] == 0:
  30. cat_count = cat_count + 1
  31. elif item["label"].size == 1 and item["label"] == 1:
  32. dog_count = dog_count + 1
  33. assert (cat_count == 2)
  34. assert (dog_count == 1)
  35. assert (count == 4)
  36. def test_manifest_dataset_eval():
  37. data = ds.ManifestDataset(DATA_FILE, "eval", decode=True)
  38. count = 0
  39. for item in data.create_dict_iterator():
  40. logger.info("item[image] is {}".format(item["image"]))
  41. count = count + 1
  42. if item["label"] != 0 and item["label"] != 1:
  43. assert (0)
  44. assert (count == 2)
  45. def test_manifest_dataset_class_index():
  46. class_indexing = {"dog": 11}
  47. data = ds.ManifestDataset(DATA_FILE, decode=True, class_indexing=class_indexing)
  48. out_class_indexing = data.get_class_indexing()
  49. assert (out_class_indexing == {"dog": 11})
  50. count = 0
  51. for item in data.create_dict_iterator():
  52. logger.info("item[image] is {}".format(item["image"]))
  53. count = count + 1
  54. if item["label"] != 11:
  55. assert (0)
  56. assert (count == 1)
  57. def test_manifest_dataset_get_class_index():
  58. data = ds.ManifestDataset(DATA_FILE, decode=True)
  59. class_indexing = data.get_class_indexing()
  60. assert (class_indexing == {'cat': 0, 'dog': 1, 'flower': 2})
  61. data = data.shuffle(4)
  62. class_indexing = data.get_class_indexing()
  63. assert (class_indexing == {'cat': 0, 'dog': 1, 'flower': 2})
  64. count = 0
  65. for item in data.create_dict_iterator():
  66. logger.info("item[image] is {}".format(item["image"]))
  67. count = count + 1
  68. assert (count == 4)
  69. def test_manifest_dataset_multi_label():
  70. data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
  71. count = 0
  72. expect_label = [1, 0, 0, [0, 2]]
  73. for item in data.create_dict_iterator():
  74. assert (item["label"].tolist() == expect_label[count])
  75. logger.info("item[image] is {}".format(item["image"]))
  76. count = count + 1
  77. assert (count == 4)
  78. def multi_label_hot(x):
  79. result = np.zeros(x.size // x.ndim, dtype=int)
  80. if x.ndim > 1:
  81. for i in range(x.ndim):
  82. result = np.add(result, x[i])
  83. else:
  84. result = np.add(result, x)
  85. return result
  86. def test_manifest_dataset_multi_label_onehot():
  87. data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
  88. expect_label = [[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [1, 0, 1]]]
  89. one_hot_encode = data_trans.OneHot(3)
  90. data = data.map(input_columns=["label"], operations=one_hot_encode)
  91. data = data.map(input_columns=["label"], operations=multi_label_hot)
  92. data = data.batch(2)
  93. count = 0
  94. for item in data.create_dict_iterator():
  95. assert (item["label"].tolist() == expect_label[count])
  96. logger.info("item[image] is {}".format(item["image"]))
  97. count = count + 1
  98. if __name__ == '__main__':
  99. test_manifest_dataset_train()
  100. test_manifest_dataset_eval()
  101. test_manifest_dataset_class_index()
  102. test_manifest_dataset_get_class_index()
  103. test_manifest_dataset_multi_label()
  104. test_manifest_dataset_multi_label_onehot()