|
|
|
@@ -0,0 +1,198 @@ |
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
# You may obtain a copy of the License at |
|
|
|
# |
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
# |
|
|
|
# Unless required by applicable law or agreed to in writing, software |
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================== |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import mindspore.dataset as ds |
|
|
|
import mindspore.dataset.transforms.vision.c_transforms as vision |
|
|
|
|
|
|
|
CELEBA_DIR = "../data/dataset/testCelebAData" |
|
|
|
CIFAR10_DIR = "../data/dataset/testCifar10Data" |
|
|
|
CIFAR100_DIR = "../data/dataset/testCifar100Data" |
|
|
|
CLUE_DIR = "../data/dataset/testCLUE/afqmc/train.json" |
|
|
|
COCO_DIR = "../data/dataset/testCOCO/train" |
|
|
|
COCO_ANNOTATION = "../data/dataset/testCOCO/annotations/train.json" |
|
|
|
CSV_DIR = "../data/dataset/testCSV/1.csv" |
|
|
|
IMAGE_FOLDER_DIR = "../data/dataset/testPK/data/" |
|
|
|
MANIFEST_DIR = "../data/dataset/testManifestData/test.manifest" |
|
|
|
MNIST_DIR = "../data/dataset/testMnistData" |
|
|
|
TFRECORD_DIR = ["../data/dataset/testTFTestAllTypes/test.data"] |
|
|
|
TFRECORD_SCHEMA = "../data/dataset/testTFTestAllTypes/datasetSchema.json" |
|
|
|
VOC_DIR = "../data/dataset/testVOC2012" |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_celeba(): |
|
|
|
data = ds.CelebADataset(CELEBA_DIR) |
|
|
|
assert data.get_col_names() == ["image", "attr"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_cifar10(): |
|
|
|
data = ds.Cifar10Dataset(CIFAR10_DIR) |
|
|
|
assert data.get_col_names() == ["image", "label"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_cifar100(): |
|
|
|
data = ds.Cifar100Dataset(CIFAR100_DIR) |
|
|
|
assert data.get_col_names() == ["image", "coarse_label", "fine_label"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_clue(): |
|
|
|
data = ds.CLUEDataset(CLUE_DIR, task="AFQMC", usage="train") |
|
|
|
assert data.get_col_names() == ["label", "sentence1", "sentence2"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_coco(): |
|
|
|
data = ds.CocoDataset(COCO_DIR, annotation_file=COCO_ANNOTATION, task="Detection", |
|
|
|
decode=True, shuffle=False) |
|
|
|
assert data.get_col_names() == ["image", "bbox", "category_id", "iscrowd"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_csv(): |
|
|
|
data = ds.CSVDataset(CSV_DIR) |
|
|
|
assert data.get_col_names() == ["1", "2", "3", "4"] |
|
|
|
data = ds.CSVDataset(CSV_DIR, column_names=["col1", "col2", "col3", "col4"]) |
|
|
|
assert data.get_col_names() == ["col1", "col2", "col3", "col4"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_generator(): |
|
|
|
def generator(): |
|
|
|
for i in range(64): |
|
|
|
yield (np.array([i]),) |
|
|
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"]) |
|
|
|
assert data.get_col_names() == ["data"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_imagefolder(): |
|
|
|
data = ds.ImageFolderDatasetV2(IMAGE_FOLDER_DIR) |
|
|
|
assert data.get_col_names() == ["image", "label"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_iterator(): |
|
|
|
data = ds.Cifar10Dataset(CIFAR10_DIR) |
|
|
|
itr = data.create_tuple_iterator(num_epochs=1) |
|
|
|
assert itr.get_col_names() == ["image", "label"] |
|
|
|
itr = data.create_dict_iterator(num_epochs=1) |
|
|
|
assert itr.get_col_names() == ["image", "label"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_manifest(): |
|
|
|
data = ds.ManifestDataset(MANIFEST_DIR) |
|
|
|
assert data.get_col_names() == ["image", "label"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_map(): |
|
|
|
data = ds.Cifar10Dataset(CIFAR10_DIR) |
|
|
|
center_crop_op = vision.CenterCrop(10) |
|
|
|
data = data.map(input_columns=["image"], operations=center_crop_op) |
|
|
|
assert data.get_col_names() == ["image", "label"] |
|
|
|
data = ds.Cifar10Dataset(CIFAR10_DIR) |
|
|
|
data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["image"]) |
|
|
|
assert data.get_col_names() == ["image", "label"] |
|
|
|
data = ds.Cifar10Dataset(CIFAR10_DIR) |
|
|
|
data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["col1"]) |
|
|
|
assert data.get_col_names() == ["col1", "label"] |
|
|
|
data = ds.Cifar10Dataset(CIFAR10_DIR) |
|
|
|
data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["col1", "col2"], |
|
|
|
columns_order=["col2", "col1"]) |
|
|
|
assert data.get_col_names() == ["col2", "col1"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_mnist(): |
|
|
|
data = ds.MnistDataset(MNIST_DIR) |
|
|
|
assert data.get_col_names() == ["image", "label"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_numpy_slices(): |
|
|
|
np_data = {"a": [1, 2], "b": [3, 4]} |
|
|
|
data = ds.NumpySlicesDataset(np_data, shuffle=False) |
|
|
|
assert data.get_col_names() == ["a", "b"] |
|
|
|
data = ds.NumpySlicesDataset([1, 2, 3], shuffle=False) |
|
|
|
assert data.get_col_names() == ["column_0"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_tfrecord(): |
|
|
|
data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA) |
|
|
|
assert data.get_col_names() == ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32", |
|
|
|
"col_sint64"] |
|
|
|
data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA, |
|
|
|
columns_list=["col_sint16", "col_sint64", "col_2d", "col_binary"]) |
|
|
|
assert data.get_col_names() == ["col_sint16", "col_sint64", "col_2d", "col_binary"] |
|
|
|
|
|
|
|
data = ds.TFRecordDataset(TFRECORD_DIR) |
|
|
|
assert data.get_col_names() == ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32", |
|
|
|
"col_sint64", "col_sint8"] |
|
|
|
s = ds.Schema() |
|
|
|
s.add_column("line", "string", []) |
|
|
|
s.add_column("words", "string", [-1]) |
|
|
|
s.add_column("chinese", "string", []) |
|
|
|
|
|
|
|
data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s) |
|
|
|
assert data.get_col_names() == ["line", "words", "chinese"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_to_device(): |
|
|
|
data = ds.Cifar10Dataset(CIFAR10_DIR) |
|
|
|
data = data.to_device() |
|
|
|
assert data.get_col_names() == ["image", "label"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_voc(): |
|
|
|
data = ds.VOCDataset(VOC_DIR, task="Segmentation", mode="train", decode=True, shuffle=False) |
|
|
|
assert data.get_col_names() == ["image", "target"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_project(): |
|
|
|
data = ds.Cifar10Dataset(CIFAR10_DIR) |
|
|
|
assert data.get_col_names() == ["image", "label"] |
|
|
|
data = data.project(columns=["image"]) |
|
|
|
assert data.get_col_names() == ["image"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_rename(): |
|
|
|
data = ds.Cifar10Dataset(CIFAR10_DIR) |
|
|
|
assert data.get_col_names() == ["image", "label"] |
|
|
|
data = data.rename(["image", "label"], ["test1", "test2"]) |
|
|
|
assert data.get_col_names() == ["test1", "test2"] |
|
|
|
|
|
|
|
|
|
|
|
def test_get_column_name_zip(): |
|
|
|
data1 = ds.Cifar10Dataset(CIFAR10_DIR) |
|
|
|
assert data1.get_col_names() == ["image", "label"] |
|
|
|
data2 = ds.CSVDataset(CSV_DIR) |
|
|
|
assert data2.get_col_names() == ["1", "2", "3", "4"] |
|
|
|
data = ds.zip((data1, data2)) |
|
|
|
assert data.get_col_names() == ["image", "label", "1", "2", "3", "4"] |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
test_get_column_name_celeba() |
|
|
|
test_get_column_name_cifar10() |
|
|
|
test_get_column_name_cifar100() |
|
|
|
test_get_column_name_clue() |
|
|
|
test_get_column_name_coco() |
|
|
|
test_get_column_name_csv() |
|
|
|
test_get_column_name_generator() |
|
|
|
test_get_column_name_imagefolder() |
|
|
|
test_get_column_name_iterator() |
|
|
|
test_get_column_name_manifest() |
|
|
|
test_get_column_name_map() |
|
|
|
test_get_column_name_mnist() |
|
|
|
test_get_column_name_numpy_slices() |
|
|
|
test_get_column_name_tfrecord() |
|
|
|
test_get_column_name_to_device() |
|
|
|
test_get_column_name_voc() |
|
|
|
test_get_column_name_project() |
|
|
|
test_get_column_name_rename() |
|
|
|
test_get_column_name_zip() |