You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_get_col_names.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore.dataset as ds
  17. import mindspore.dataset.vision.c_transforms as vision
  # Fixture locations used by the tests below. They are relative paths, so
  # presumably the suite must be launched from a directory that has
  # ``../data/dataset`` as a sibling -- TODO confirm against the test runner.

  # Image-classification style datasets (directories of fixture data).
  CELEBA_DIR = "../data/dataset/testCelebAData"
  CIFAR10_DIR = "../data/dataset/testCifar10Data"
  CIFAR100_DIR = "../data/dataset/testCifar100Data"
  IMAGE_FOLDER_DIR = "../data/dataset/testPK/data/"
  MNIST_DIR = "../data/dataset/testMnistData"
  VOC_DIR = "../data/dataset/testVOC2012"

  # Datasets described by an annotation or manifest file.
  COCO_DIR = "../data/dataset/testCOCO/train"
  COCO_ANNOTATION = "../data/dataset/testCOCO/annotations/train.json"
  MANIFEST_DIR = "../data/dataset/testManifestData/test.manifest"

  # Text / record based datasets. TFRECORD_DIR is a list of data files.
  CLUE_DIR = "../data/dataset/testCLUE/afqmc/train.json"
  CSV_DIR = "../data/dataset/testCSV/1.csv"
  TFRECORD_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
  TFRECORD_SCHEMA = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
  def test_get_column_name_celeba():
      """Check that CelebADataset reports the columns ["image", "attr"]."""
      expected = ["image", "attr"]
      celeba = ds.CelebADataset(CELEBA_DIR)
      assert celeba.get_col_names() == expected
  def test_get_column_name_cifar10():
      """Check that Cifar10Dataset reports the columns ["image", "label"]."""
      expected = ["image", "label"]
      cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
      assert cifar10.get_col_names() == expected
  def test_get_column_name_cifar100():
      """Check that Cifar100Dataset reports an image column plus both label granularities."""
      expected = ["image", "coarse_label", "fine_label"]
      cifar100 = ds.Cifar100Dataset(CIFAR100_DIR)
      assert cifar100.get_col_names() == expected
  def test_get_column_name_clue():
      """Check the columns reported by CLUEDataset for the AFQMC task, train usage."""
      expected = ["label", "sentence1", "sentence2"]
      clue = ds.CLUEDataset(CLUE_DIR, task="AFQMC", usage="train")
      assert clue.get_col_names() == expected
  def test_get_column_name_coco():
      """Check the columns reported by CocoDataset for the Detection task."""
      expected = ["image", "bbox", "category_id", "iscrowd"]
      coco = ds.CocoDataset(
          COCO_DIR,
          annotation_file=COCO_ANNOTATION,
          task="Detection",
          decode=True,
          shuffle=False,
      )
      assert coco.get_col_names() == expected
  def test_get_column_name_csv():
      """Check CSVDataset column names, with and without an explicit ``column_names``.

      Without ``column_names`` the fixture file yields "1".."4"; with it, the
      supplied names are reported instead.
      """
      cases = (
          ({}, ["1", "2", "3", "4"]),
          ({"column_names": ["col1", "col2", "col3", "col4"]},
           ["col1", "col2", "col3", "col4"]),
      )
      for extra_kwargs, expected in cases:
          csv_data = ds.CSVDataset(CSV_DIR, **extra_kwargs)
          assert csv_data.get_col_names() == expected
  def test_get_column_name_generator():
      """Check that GeneratorDataset reports the column names it was constructed with."""
      def _rows():
          # 64 single-column rows, each a one-element numpy array.
          yield from ((np.array([value]),) for value in range(64))

      source = ds.GeneratorDataset(_rows, ["data"])
      assert source.get_col_names() == ["data"]
  def test_get_column_name_imagefolder():
      """Check that ImageFolderDataset reports the columns ["image", "label"]."""
      expected = ["image", "label"]
      folder = ds.ImageFolderDataset(IMAGE_FOLDER_DIR)
      assert folder.get_col_names() == expected
  def test_get_column_name_iterator():
      """Check that both tuple and dict iterators expose the dataset's column names."""
      expected = ["image", "label"]
      cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
      factories = (cifar10.create_tuple_iterator, cifar10.create_dict_iterator)
      for make_iterator in factories:
          # Rebinding ``itr`` each pass mirrors creating one iterator after another.
          itr = make_iterator(num_epochs=1)
          assert itr.get_col_names() == expected
  def test_get_column_name_manifest():
      """Check that ManifestDataset reports the columns ["image", "label"]."""
      expected = ["image", "label"]
      manifest = ds.ManifestDataset(MANIFEST_DIR)
      assert manifest.get_col_names() == expected
  def test_get_column_name_map():
      """Check how ``map`` affects the reported column list on a fresh Cifar10 pipeline.

      Cases covered:
        * no ``output_columns``: names unchanged;
        * ``output_columns`` equal to the input: names unchanged;
        * one renamed output: it replaces "image" in place, "label" kept;
        * two outputs plus ``column_order``: the order list is what is reported
          (so "label" is not present in this case).
      """
      center_crop_op = vision.CenterCrop(10)
      cases = (
          ({}, ["image", "label"]),
          ({"output_columns": ["image"]}, ["image", "label"]),
          ({"output_columns": ["col1"]}, ["col1", "label"]),
          ({"output_columns": ["col1", "col2"], "column_order": ["col2", "col1"]},
           ["col2", "col1"]),
      )
      for extra_kwargs, expected in cases:
          cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
          mapped = cifar10.map(operations=center_crop_op, input_columns=["image"], **extra_kwargs)
          assert mapped.get_col_names() == expected
  def test_get_column_name_mnist():
      """Check that MnistDataset reports the columns ["image", "label"]."""
      expected = ["image", "label"]
      mnist = ds.MnistDataset(MNIST_DIR)
      assert mnist.get_col_names() == expected
  def test_get_column_name_numpy_slices():
      """Check NumpySlicesDataset column naming for dict and plain-list sources.

      Dict keys become column names; a plain list yields a single "column_0".
      """
      cases = (
          ({"a": [1, 2], "b": [3, 4]}, ["a", "b"]),
          ([1, 2, 3], ["column_0"]),
      )
      for source, expected in cases:
          sliced = ds.NumpySlicesDataset(source, shuffle=False)
          assert sliced.get_col_names() == expected
  def test_get_column_name_tfrecord():
      """Check TFRecordDataset column names under different schema/column options.

      Covers a schema file, a schema file plus ``columns_list`` (the list order is
      reported), no schema (an extra "col_sint8" column appears), and an in-code
      ``ds.Schema`` over a text TFRecord fixture.
      """
      schema_columns = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
                        "col_sint16", "col_sint32", "col_sint64"]

      with_schema = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA)
      assert with_schema.get_col_names() == schema_columns

      subset = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA,
                                  columns_list=["col_sint16", "col_sint64", "col_2d", "col_binary"])
      assert subset.get_col_names() == ["col_sint16", "col_sint64", "col_2d", "col_binary"]

      schemaless = ds.TFRecordDataset(TFRECORD_DIR)
      assert schemaless.get_col_names() == schema_columns + ["col_sint8"]

      text_columns = (("line", []), ("words", [-1]), ("chinese", []))
      text_schema = ds.Schema()
      for column_name, shape in text_columns:
          text_schema.add_column(column_name, "string", shape)
      text_data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False,
                                     schema=text_schema)
      assert text_data.get_col_names() == [column_name for column_name, _ in text_columns]
  def test_get_column_name_to_device():
      """Check that the object returned by ``to_device()`` still reports Cifar10's columns."""
      expected = ["image", "label"]
      transfer = ds.Cifar10Dataset(CIFAR10_DIR).to_device()
      assert transfer.get_col_names() == expected
  def test_get_column_name_voc():
      """Check the columns reported by VOCDataset for the Segmentation task, train usage."""
      expected = ["image", "target"]
      voc = ds.VOCDataset(VOC_DIR, task="Segmentation", usage="train", decode=True, shuffle=False)
      assert voc.get_col_names() == expected
  def test_get_column_name_project():
      """Check that ``project`` narrows the reported columns to the requested subset."""
      source = ds.Cifar10Dataset(CIFAR10_DIR)
      assert source.get_col_names() == ["image", "label"]

      image_only = source.project(columns=["image"])
      assert image_only.get_col_names() == ["image"]
  def test_get_column_name_rename():
      """Check that ``rename`` is reflected in the reported column names."""
      source = ds.Cifar10Dataset(CIFAR10_DIR)
      assert source.get_col_names() == ["image", "label"]

      renamed = source.rename(["image", "label"], ["test1", "test2"])
      assert renamed.get_col_names() == ["test1", "test2"]
  def test_get_column_name_zip():
      """Check that ``ds.zip`` reports the columns of each input, concatenated in order."""
      cifar_columns = ["image", "label"]
      csv_columns = ["1", "2", "3", "4"]

      cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
      assert cifar10.get_col_names() == cifar_columns
      csv_data = ds.CSVDataset(CSV_DIR)
      assert csv_data.get_col_names() == csv_columns

      zipped = ds.zip((cifar10, csv_data))
      assert zipped.get_col_names() == cifar_columns + csv_columns
  if __name__ == "__main__":
      # Run every test in this module, in the same order as before, when
      # executed directly rather than through pytest.
      for _case in (
              test_get_column_name_celeba,
              test_get_column_name_cifar10,
              test_get_column_name_cifar100,
              test_get_column_name_clue,
              test_get_column_name_coco,
              test_get_column_name_csv,
              test_get_column_name_generator,
              test_get_column_name_imagefolder,
              test_get_column_name_iterator,
              test_get_column_name_manifest,
              test_get_column_name_map,
              test_get_column_name_mnist,
              test_get_column_name_numpy_slices,
              test_get_column_name_tfrecord,
              test_get_column_name_to_device,
              test_get_column_name_voc,
              test_get_column_name_project,
              test_get_column_name_rename,
              test_get_column_name_zip,
      ):
          _case()