|
|
|
@@ -25,6 +25,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision |
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
from mindspore._c_dataengine import InterpolationMode |
|
|
|
from mindspore.dataset.transforms.vision import Inter |
|
|
|
from mindspore import log as logger |
|
|
|
|
|
|
|
import mindspore.dataset as ds |
|
|
|
@@ -151,6 +152,51 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file): |
|
|
|
assert data_set.get_dataset_size() == 3 |
|
|
|
|
|
|
|
|
|
|
|
def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file): |
|
|
|
"""tutorial for cv minddataset.""" |
|
|
|
columns_list = ["data", "label"] |
|
|
|
num_readers = 4 |
|
|
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers) |
|
|
|
decode_op = vision.Decode() |
|
|
|
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2) |
|
|
|
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR) |
|
|
|
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2) |
|
|
|
data_set = data_set.batch(2) |
|
|
|
data_set = data_set.repeat(2) |
|
|
|
num_iter = 0 |
|
|
|
labels = [] |
|
|
|
for item in data_set.create_dict_iterator(): |
|
|
|
logger.info("-------------- get dataset size {} -----------------".format(num_iter)) |
|
|
|
logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) |
|
|
|
logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) |
|
|
|
num_iter += 1 |
|
|
|
labels.append(item["label"]) |
|
|
|
assert num_iter == 10 |
|
|
|
logger.info("repeat shuffle: {}".format(labels)) |
|
|
|
assert len(labels) == 10 |
|
|
|
assert labels[0:5] == labels[0:5] |
|
|
|
assert labels[0:5] != labels[5:5] |
|
|
|
|
|
|
|
|
|
|
|
def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file): |
|
|
|
"""tutorial for cv minddataset.""" |
|
|
|
columns_list = ["data", "label"] |
|
|
|
num_readers = 4 |
|
|
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers) |
|
|
|
decode_op = vision.Decode() |
|
|
|
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2) |
|
|
|
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR) |
|
|
|
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2) |
|
|
|
data_set = data_set.batch(32, drop_remainder=True) |
|
|
|
num_iter = 0 |
|
|
|
for item in data_set.create_dict_iterator(): |
|
|
|
logger.info("-------------- get dataset size {} -----------------".format(num_iter)) |
|
|
|
logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) |
|
|
|
logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) |
|
|
|
num_iter += 1 |
|
|
|
assert num_iter == 0 |
|
|
|
|
|
|
|
|
|
|
|
def test_cv_minddataset_issue_888(add_and_remove_cv_file): |
|
|
|
"""issue 888 test.""" |
|
|
|
columns_list = ["data", "label"] |
|
|
|
|