|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- import numpy as np
-
- import mindspore.dataset as ds
- import mindspore.dataset.vision.c_transforms as vision
-
# Fixture locations used by the tests below. They are relative paths, so they
# resolve against the current working directory the tests are launched from.
_DATASET_ROOT = "../data/dataset"

CELEBA_DIR = _DATASET_ROOT + "/testCelebAData"
CIFAR10_DIR = _DATASET_ROOT + "/testCifar10Data"
CIFAR100_DIR = _DATASET_ROOT + "/testCifar100Data"
CLUE_DIR = _DATASET_ROOT + "/testCLUE/afqmc/train.json"
COCO_DIR = _DATASET_ROOT + "/testCOCO/train"
COCO_ANNOTATION = _DATASET_ROOT + "/testCOCO/annotations/train.json"
CSV_DIR = _DATASET_ROOT + "/testCSV/1.csv"
IMAGE_FOLDER_DIR = _DATASET_ROOT + "/testPK/data/"
MANIFEST_DIR = _DATASET_ROOT + "/testManifestData/test.manifest"
MNIST_DIR = _DATASET_ROOT + "/testMnistData"
# TFRecordDataset takes a list of files here, hence the single-element list.
TFRECORD_DIR = [_DATASET_ROOT + "/testTFTestAllTypes/test.data"]
TFRECORD_SCHEMA = _DATASET_ROOT + "/testTFTestAllTypes/datasetSchema.json"
VOC_DIR = _DATASET_ROOT + "/testVOC2012"
-
-
def test_get_column_name_celeba():
    """CelebADataset should report the "image" and "attr" columns."""
    expected_columns = ["image", "attr"]
    celeba = ds.CelebADataset(CELEBA_DIR)
    assert celeba.get_col_names() == expected_columns
-
-
def test_get_column_name_cifar10():
    """Cifar10Dataset should report an image column followed by a label column."""
    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    expected_columns = ["image", "label"]
    assert cifar10.get_col_names() == expected_columns
-
-
def test_get_column_name_cifar100():
    """Cifar100Dataset should report the image plus both coarse and fine label columns."""
    cifar100 = ds.Cifar100Dataset(CIFAR100_DIR)
    expected_columns = ["image", "coarse_label", "fine_label"]
    assert cifar100.get_col_names() == expected_columns
-
-
def test_get_column_name_clue():
    """CLUEDataset for the AFQMC task (train usage) should report label and two sentence columns."""
    afqmc = ds.CLUEDataset(CLUE_DIR, task="AFQMC", usage="train")
    expected_columns = ["label", "sentence1", "sentence2"]
    assert afqmc.get_col_names() == expected_columns
-
-
def test_get_column_name_coco():
    """CocoDataset with task="Detection" should report image, bbox, category_id and iscrowd."""
    coco = ds.CocoDataset(
        COCO_DIR,
        annotation_file=COCO_ANNOTATION,
        task="Detection",
        decode=True,
        shuffle=False,
    )
    detection_columns = ["image", "bbox", "category_id", "iscrowd"]
    assert coco.get_col_names() == detection_columns
-
-
def test_get_column_name_csv():
    """CSVDataset column names: taken from the file by default, or from column_names when given."""
    # No column_names: the names come from the CSV file itself. The expected
    # values suggest its first row is 1,2,3,4 -- confirm against testCSV/1.csv.
    from_file = ds.CSVDataset(CSV_DIR)
    assert from_file.get_col_names() == ["1", "2", "3", "4"]

    # Explicit column_names override whatever the file provides.
    custom_names = ("col1", "col2", "col3", "col4")
    overridden = ds.CSVDataset(CSV_DIR, column_names=list(custom_names))
    assert overridden.get_col_names() == list(custom_names)
-
-
def test_get_column_name_generator():
    """GeneratorDataset should report exactly the column names passed to it."""
    def generator():
        # 64 single-column rows, each a one-element numpy array holding the row index.
        yield from ((np.array([index]),) for index in range(64))

    gen_dataset = ds.GeneratorDataset(generator, ["data"])
    assert gen_dataset.get_col_names() == ["data"]
-
-
def test_get_column_name_imagefolder():
    """ImageFolderDataset should report an image column and a label column."""
    folder_dataset = ds.ImageFolderDataset(IMAGE_FOLDER_DIR)
    expected_columns = ["image", "label"]
    assert folder_dataset.get_col_names() == expected_columns
-
-
def test_get_column_name_iterator():
    """Tuple and dict iterators should both report the columns of the dataset they iterate."""
    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    # Tuple iterator first, then dict iterator. The loop variable is rebound
    # just as the original reassigned a single name.
    for make_iterator in (cifar10.create_tuple_iterator, cifar10.create_dict_iterator):
        iterator = make_iterator(num_epochs=1)
        assert iterator.get_col_names() == ["image", "label"]
-
-
def test_get_column_name_manifest():
    """ManifestDataset should report an image column and a label column."""
    manifest = ds.ManifestDataset(MANIFEST_DIR)
    expected_columns = ["image", "label"]
    assert manifest.get_col_names() == expected_columns
-
-
def test_get_column_name_map():
    """map() should report column names according to output_columns and column_order."""
    center_crop_op = vision.CenterCrop(10)
    # Each case: extra keyword arguments for map(), then the column names expected afterwards.
    cases = [
        # No output_columns: the mapped column keeps its name and position.
        ({}, ["image", "label"]),
        # output_columns equal to the input name: unchanged.
        ({"output_columns": ["image"]}, ["image", "label"]),
        # A new output name takes the place of the consumed input column.
        ({"output_columns": ["col1"]}, ["col1", "label"]),
        # column_order decides the final layout.
        # NOTE(review): CenterCrop presumably yields a single output, so two
        # output_columns may only be valid for this static name check -- confirm.
        ({"output_columns": ["col1", "col2"], "column_order": ["col2", "col1"]}, ["col2", "col1"]),
    ]
    for extra_kwargs, expected_columns in cases:
        # A fresh source dataset per case, as each map() builds its own pipeline.
        source = ds.Cifar10Dataset(CIFAR10_DIR)
        mapped = source.map(operations=center_crop_op, input_columns=["image"], **extra_kwargs)
        assert mapped.get_col_names() == expected_columns
-
-
def test_get_column_name_mnist():
    """MnistDataset should report an image column and a label column."""
    mnist = ds.MnistDataset(MNIST_DIR)
    expected_columns = ["image", "label"]
    assert mnist.get_col_names() == expected_columns
-
-
def test_get_column_name_numpy_slices():
    """NumpySlicesDataset names columns after dict keys, or generates a name for a plain list."""
    # A dict source: its keys become the column names, in insertion order.
    keyed = ds.NumpySlicesDataset({"a": [1, 2], "b": [3, 4]}, shuffle=False)
    assert keyed.get_col_names() == ["a", "b"]

    # A plain list has no names, so a generated "column_0" is reported.
    unnamed = ds.NumpySlicesDataset([1, 2, 3], shuffle=False)
    assert unnamed.get_col_names() == ["column_0"]
-
-
def test_get_column_name_tfrecord():
    """TFRecordDataset column names should follow the schema, columns_list, or the file contents."""
    # Named like the module-level fixture paths instead of an inline literal.
    text_tfrecord_file = "../data/dataset/testTextTFRecord/text.tfrecord"
    # Columns declared by TFRECORD_SCHEMA. col_sint8 is absent from this list,
    # although it appears when no schema is given (see below).
    schema_columns = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
                      "col_sint16", "col_sint32", "col_sint64"]

    # With the JSON schema file: exactly the schema's columns, in schema order.
    data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA)
    assert data.get_col_names() == schema_columns

    # columns_list selects a subset and fixes the reported order.
    selected_columns = ["col_sint16", "col_sint64", "col_2d", "col_binary"]
    data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA, columns_list=list(selected_columns))
    assert data.get_col_names() == selected_columns

    # Without a schema the names are presumably inferred from the records,
    # which also include col_sint8.
    data = ds.TFRecordDataset(TFRECORD_DIR)
    assert data.get_col_names() == schema_columns + ["col_sint8"]

    # An in-memory Schema object defines both the names and their order.
    s = ds.Schema()
    s.add_column("line", "string", [])
    s.add_column("words", "string", [-1])
    s.add_column("chinese", "string", [])

    data = ds.TFRecordDataset(text_tfrecord_file, shuffle=False, schema=s)
    assert data.get_col_names() == ["line", "words", "chinese"]
-
-
def test_get_column_name_to_device():
    """to_device() should keep the column names of the dataset it wraps."""
    transferred = ds.Cifar10Dataset(CIFAR10_DIR).to_device()
    assert transferred.get_col_names() == ["image", "label"]
-
-
def test_get_column_name_voc():
    """VOCDataset with task="Segmentation" should report image and target columns."""
    voc = ds.VOCDataset(VOC_DIR, task="Segmentation", usage="train", decode=True, shuffle=False)
    segmentation_columns = ["image", "target"]
    assert voc.get_col_names() == segmentation_columns
-
-
def test_get_column_name_project():
    """project() should narrow the reported columns to the requested subset."""
    source = ds.Cifar10Dataset(CIFAR10_DIR)
    assert source.get_col_names() == ["image", "label"]

    image_only = source.project(columns=["image"])
    assert image_only.get_col_names() == ["image"]
-
-
def test_get_column_name_rename():
    """rename() should replace the reported column names position by position."""
    before = ("image", "label")
    after = ("test1", "test2")

    source = ds.Cifar10Dataset(CIFAR10_DIR)
    assert source.get_col_names() == list(before)

    renamed = source.rename(list(before), list(after))
    assert renamed.get_col_names() == list(after)
-
-
def test_get_column_name_zip():
    """zip() should report the columns of its inputs concatenated in input order."""
    cifar_columns = ["image", "label"]
    csv_columns = ["1", "2", "3", "4"]

    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    assert cifar10.get_col_names() == cifar_columns
    csv_data = ds.CSVDataset(CSV_DIR)
    assert csv_data.get_col_names() == csv_columns

    zipped = ds.zip((cifar10, csv_data))
    assert zipped.get_col_names() == cifar_columns + csv_columns
-
-
if __name__ == "__main__":
    # When run as a script, execute every test once, in this fixed order.
    for run_test in (
            test_get_column_name_celeba,
            test_get_column_name_cifar10,
            test_get_column_name_cifar100,
            test_get_column_name_clue,
            test_get_column_name_coco,
            test_get_column_name_csv,
            test_get_column_name_generator,
            test_get_column_name_imagefolder,
            test_get_column_name_iterator,
            test_get_column_name_manifest,
            test_get_column_name_map,
            test_get_column_name_mnist,
            test_get_column_name_numpy_slices,
            test_get_column_name_tfrecord,
            test_get_column_name_to_device,
            test_get_column_name_voc,
            test_get_column_name_project,
            test_get_column_name_rename,
            test_get_column_name_zip,
    ):
        run_test()
|