# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import re import pytest import mindspore.dataset as ds from mindspore import log as logger DATA_DIR = "../data/dataset/testLFW" def test_lfw_basic(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case basic") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR) data1 = data1.repeat(repeat_count) num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 8 def test_lfw_task(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case basic") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR, task="pairs", usage="all") data1 = data1.repeat(repeat_count) num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image1 is {}".format(item["image1"])) logger.info("image2 is {}".format(item["image2"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 16 def test_lfw_usage(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case basic") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR, usage="test") data1 = data1.repeat(repeat_count) num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 6 def test_lfw_image_set(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case basic") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR, image_set="funneled") data1 = data1.repeat(repeat_count) num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 8 def test_lfw_num_samples(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case numSamples") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR, num_samples=4, num_parallel_workers=2) data1 = data1.repeat(repeat_count) num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 8 random_sampler = ds.RandomSampler(num_samples=2, replacement=True) data1 = ds.LFWDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler) num_iter = 0 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): num_iter += 1 assert num_iter == 2 random_sampler = ds.RandomSampler(num_samples=3, replacement=False) data1 = ds.LFWDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler) num_iter = 0 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): num_iter += 1 assert num_iter == 3 def test_lfw_num_shards(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case numShards") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR, num_shards=5, shard_id=1) data1 = data1.repeat(repeat_count) num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 2 def test_lfw_shard_id(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case withShardID") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR, num_shards=4, shard_id=1) data1 = data1.repeat(repeat_count) num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 2 def test_lfw_no_shuffle(): """ Feature: LFW Description: test dataset of LFW Expectation: the dataset is as expected """ logger.info("Test Case noShuffle") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR, shuffle=False) data1 = data1.repeat(repeat_count) num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 8 def test_lfw_decode(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case decode") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR, decode=True) data1 = data1.repeat(repeat_count) num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 8 def test_sequential_sampler(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case SequentialSampler") # define parameters. repeat_count = 2 # apply dataset operations. sampler = ds.SequentialSampler(num_samples=3) data1 = ds.LFWDataset(DATA_DIR, sampler=sampler) data1 = data1.repeat(repeat_count) result = [] num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". result.append(item["label"]) num_iter += 1 assert num_iter == 6 logger.info("Result: {}".format(result)) def test_random_and_sequentialchained_sampler(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case Chained Sampler - Random and Sequential, with repeat") # Create chained sampler, random and sequential. sampler = ds.RandomSampler() child_sampler = ds.SequentialSampler() sampler.add_child(child_sampler) # Create LFWDataset with sampler. data1 = ds.LFWDataset(DATA_DIR, sampler=sampler) data1 = data1.repeat(count=3) # Verify dataset size. data1_size = data1.get_dataset_size() logger.info("dataset size is: {}".format(data1_size)) assert data1_size == 12 # Verify number of iterations. num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 12 def test_lfw_rename(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case rename") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR, num_samples=4) data1 = data1.repeat(repeat_count) num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 8 data1 = data1.rename(input_columns=["image"], output_columns="image2") num_iter = 0 # each data is a dictionary. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image2"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 8 def test_lfw_zip(): """ Feature: LFW Description: test basic usage of LFW Expectation: the dataset is as expected """ logger.info("Test Case zip") # define parameters. repeat_count = 2 # apply dataset operations. data1 = ds.LFWDataset(DATA_DIR, num_samples=3) data2 = ds.LFWDataset(DATA_DIR, num_samples=3) data1 = data1.repeat(repeat_count) # rename dataset2 for no conflict. data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"]) data3 = ds.zip((data1, data2)) num_iter = 0 # each data is a dictionary. for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True): # in this example, each dictionary has keys "image" and "label". logger.info("image is {}".format(item["image"])) logger.info("label is {}".format(item["label"])) num_iter += 1 logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 3 def test_lfw_exception(): """ Feature: LFW Description: test error cases of LFW Expectation: throw exception correctly """ error_msg_1 = "sampler and shuffle cannot be specified at the same time" with pytest.raises(RuntimeError, match=error_msg_1): ds.LFWDataset(DATA_DIR, shuffle=False, decode=True, sampler=ds.SequentialSampler(1)) error_msg_2 = "sampler and sharding cannot be specified at the same time" with pytest.raises(RuntimeError, match=error_msg_2): ds.LFWDataset(DATA_DIR, sampler=ds.SequentialSampler(1), decode=True, num_shards=2, shard_id=0) error_msg_3 = "num_shards is specified and currently requires shard_id as well" with pytest.raises(RuntimeError, match=error_msg_3): ds.LFWDataset(DATA_DIR, decode=True, num_shards=10) error_msg_4 = "shard_id is specified but num_shards is not" with pytest.raises(RuntimeError, match=error_msg_4): ds.LFWDataset(DATA_DIR, decode=True, shard_id=0) error_msg_5 = "Input shard_id is not within the required interval" with pytest.raises(ValueError, match=error_msg_5): ds.LFWDataset(DATA_DIR, decode=True, num_shards=5, shard_id=-1) with pytest.raises(ValueError, match=error_msg_5): ds.LFWDataset(DATA_DIR, decode=True, num_shards=5, shard_id=5) error_msg_6 = "num_parallel_workers exceeds" with pytest.raises(ValueError, match=error_msg_6): ds.LFWDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=0) with pytest.raises(ValueError, match=error_msg_6): ds.LFWDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=256) error_msg_7 = "Argument shard_id" with pytest.raises(TypeError, match=error_msg_7): ds.LFWDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0") error_msg_8 = "does not exist or is not a directory or permission denied!" with pytest.raises(ValueError, match=error_msg_8): all_data = ds.LFWDataset("../data/dataset/testLFW2", decode=True) for _ in all_data.create_dict_iterator(num_epochs=1): pass error_msg_9 = "Input task is not within the valid set of ['people', 'pairs']." with pytest.raises(ValueError, match=re.escape(error_msg_9)): all_data = ds.LFWDataset(DATA_DIR, task="all") for _ in all_data.create_dict_iterator(num_epochs=1): pass error_msg_10 = "Input usage is not within the valid set of ['10fold', 'train', 'test', 'all']." with pytest.raises(ValueError, match=re.escape(error_msg_10)): all_data = ds.LFWDataset(DATA_DIR, usage="many") for _ in all_data.create_dict_iterator(num_epochs=1): pass error_msg_11 = "Input image_set is not within the valid set of ['original', 'funneled', 'deepfunneled']." with pytest.raises(ValueError, match=re.escape(error_msg_11)): all_data = ds.LFWDataset(DATA_DIR, image_set="all") for _ in all_data.create_dict_iterator(num_epochs=1): pass error_msg_12 = "Argument decode with value 123 is not of type [], but got ." with pytest.raises(TypeError, match=re.escape(error_msg_12)): all_data = ds.LFWDataset(DATA_DIR, decode=123) for _ in all_data.create_dict_iterator(num_epochs=1): pass if __name__ == '__main__': test_lfw_basic() test_lfw_task() test_lfw_usage() test_lfw_image_set() test_lfw_num_samples() test_sequential_sampler() test_random_and_sequentialchained_sampler() test_lfw_num_shards() test_lfw_shard_id() test_lfw_no_shuffle() test_lfw_decode() test_lfw_rename() test_lfw_zip() test_lfw_exception()