Merge pull request !1457 from Peilin/splitOp-after-testingtags/v0.5.0-beta
| @@ -71,7 +71,7 @@ if __name__ == '__main__': | |||
| model = Model(network, loss, opt, {'acc': Accuracy()}) | |||
| print("============== Starting Training ==============") | |||
| ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs) | |||
| ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck) | |||
| @@ -70,21 +70,26 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| } | |||
| Status RandomSampler::InitSampler() { | |||
| num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); | |||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive."); | |||
| rnd_.seed(seed_); | |||
| if (replacement_ == false) { | |||
| num_samples_ = std::min(num_samples_, num_rows_); | |||
| shuffled_ids_.reserve(num_rows_); | |||
| for (int64_t i = 0; i < num_rows_; i++) { | |||
| shuffled_ids_.push_back(i); | |||
| } | |||
| std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); | |||
| } else { | |||
| num_samples_ = std::min(num_samples_, user_num_samples_); | |||
| dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive."); | |||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | |||
| return Status::OK(); | |||
| } | |||
| @@ -32,9 +32,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| } | |||
| // Handshake and init child first. | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op)); | |||
| } | |||
| RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op)); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n"); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -28,9 +28,9 @@ SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size) | |||
| : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {} | |||
| Status SubsetSampler::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n"); | |||
| num_samples_ = subset_size_; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset | |||
| GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \ | |||
| Schema, Shuffle, zip, RandomDataset | |||
| from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ | |||
| WeightedRandomSampler, Sampler | |||
| WeightedRandomSampler, SubsetSampler, Sampler | |||
| from .engine.serializer_deserializer import serialize, deserialize, show | |||
| from .engine.graphdata import GraphData | |||
| @@ -633,9 +633,9 @@ class Dataset: | |||
| Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size | |||
| of the original dataset. If after rounding, any size equals 0, an error will occur. | |||
| All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. | |||
| randomize (bool): determines whether or not to split the data randomly. If true, the data | |||
| will be randomly split. Otherwise, each split will be created with consecutive rows | |||
| from the dataset. | |||
| randomize (bool, optional): determines whether or not to split the data randomly (default=True). | |||
| If true, the data will be randomly split. Otherwise, each split will be created with | |||
| consecutive rows from the dataset. | |||
| Note: | |||
| 1. Dataset cannot be sharded if split is going to be called. | |||
| @@ -678,7 +678,8 @@ class Dataset: | |||
| ds = copy.deepcopy(self) | |||
| if randomize: | |||
| # want to shuffle the same way every epoch before split | |||
| ds = ds.shuffle() | |||
| # in alter_tree, shuffle buffer is minimum 10000, so use 10000 here | |||
| ds = ds.shuffle(10000) | |||
| ds.reshuffle_each_epoch = False | |||
| if rows_to_skip > 0: | |||
| @@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset): | |||
| >>> new_sampler = ds.DistributedSampler(10, 2) | |||
| >>> data.use_sampler(new_sampler) | |||
| """ | |||
| if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): | |||
| raise TypeError("new_sampler is not an instance of a sampler.") | |||
| self.sampler = self.sampler.child_sampler | |||
| self.add_sampler(new_sampler) | |||
| @@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset): | |||
| def is_sharded(self): | |||
| raise NotImplementedError("MappableDataset must implement is_sharded.") | |||
| def _get_sampler_dataset_size(self): | |||
| if self.sampler is not None: | |||
| return self.sampler.get_dataset_size() | |||
| return None | |||
| @check_split | |||
| def split(self, sizes, randomize=True): | |||
| @@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset): | |||
| Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size | |||
| of the original dataset. If after rounding, any size equals 0, an error will occur. | |||
| All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. | |||
| randomize (bool): determines whether or not to split the data randomly. If true, the data | |||
| will be randomly split. Otherwise, each split will be created with consecutive rows | |||
| from the dataset. | |||
| randomize (bool, optional): determines whether or not to split the data randomly (default=True). | |||
| If true, the data will be randomly split. Otherwise, each split will be created with | |||
| consecutive rows from the dataset. | |||
| Note: | |||
| 1. Dataset should not be sharded if split is going to be called. Instead, create a | |||
| @@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp): | |||
| self.iterator = TupleIterator(self) | |||
| class RangeDataset(MappableDataset): | |||
| """ | |||
| A source dataset that reads and parses datasets stored on disk in a range. | |||
| @@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset): | |||
| else: | |||
| num_samples = self.num_samples | |||
| num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0] | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| return get_num_rows(num_rows, self.num_shards) | |||
| if rows_from_sampler is None: | |||
| return rows_per_shard | |||
| return min(rows_from_sampler, rows_per_shard) | |||
| def num_classes(self): | |||
| """ | |||
| @@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset): | |||
| num_samples = self.num_samples | |||
| num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples) | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| if rows_from_sampler is None: | |||
| return rows_per_shard | |||
| return get_num_rows(num_rows, self.num_shards) | |||
| return min(rows_from_sampler, rows_per_shard) | |||
| def is_shuffled(self): | |||
| if self.shuffle_level is None: | |||
| @@ -2926,7 +2944,12 @@ class GeneratorDataset(MappableDataset): | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| return self._dataset_size | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| if rows_from_sampler is None: | |||
| return self._dataset_size | |||
| return min(rows_from_sampler, self._dataset_size) | |||
| # manually set dataset_size as a temporary solution. | |||
| def set_dataset_size(self, value): | |||
| @@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset): | |||
| class_indexing = self.class_indexing | |||
| num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0] | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| if rows_from_sampler is None: | |||
| return rows_per_shard | |||
| return get_num_rows(num_rows, self.num_shards) | |||
| return min(rows_from_sampler, rows_per_shard) | |||
| def num_classes(self): | |||
| """ | |||
| @@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset): | |||
| num_samples = self.num_samples | |||
| num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True) | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| return get_num_rows(num_rows, self.num_shards) | |||
| if rows_from_sampler is None: | |||
| return rows_per_shard | |||
| return min(rows_from_sampler, rows_per_shard) | |||
| def is_shuffled(self): | |||
| if self.shuffle_level is None: | |||
| @@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset): | |||
| num_samples = self.num_samples | |||
| num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False) | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| if rows_from_sampler is None: | |||
| return rows_per_shard | |||
| return get_num_rows(num_rows, self.num_shards) | |||
| return min(rows_from_sampler, rows_per_shard) | |||
| def is_shuffled(self): | |||
| if self.shuffle_level is None: | |||
| @@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset): | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| return num_samples | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| if rows_from_sampler is None: | |||
| return self.num_samples | |||
| return min(rows_from_sampler, self.num_samples) | |||
| def is_shuffled(self): | |||
| return True | |||
| @@ -3871,7 +3914,12 @@ class VOCDataset(MappableDataset): | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| return self.num_samples | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| if rows_from_sampler is None: | |||
| return self.num_samples | |||
| return min(rows_from_sampler, self.num_samples) | |||
| def get_class_indexing(self): | |||
| """ | |||
| @@ -114,6 +114,9 @@ class Sampler: | |||
| return self.child_sampler.is_sharded() | |||
| def get_dataset_size(self): | |||
| return self._get_indices().size | |||
| class BuiltinSampler: | |||
| """ | |||
| @@ -146,6 +149,12 @@ class BuiltinSampler: | |||
| def is_sharded(self): | |||
| raise NotImplementedError("Sampler must implement is_sharded.") | |||
| def get_dataset_size(self): | |||
| if self.child_sampler is not None: | |||
| return self.child_sampler.get_dataset_size() | |||
| return None | |||
| class DistributedSampler(BuiltinSampler): | |||
| """ | |||
| @@ -330,6 +339,9 @@ class RandomSampler(BuiltinSampler): | |||
| return self.child_sampler.is_sharded() | |||
| def get_dataset_size(self): | |||
| return self.num_samples | |||
| class SequentialSampler(BuiltinSampler): | |||
| """ | |||
| @@ -421,6 +433,9 @@ class SubsetSampler(BuiltinSampler): | |||
| return self.child_sampler.is_sharded() | |||
| def get_dataset_size(self): | |||
| return self.subset_size | |||
| class SubsetRandomSampler(BuiltinSampler): | |||
| """ | |||
| @@ -467,6 +482,10 @@ class SubsetRandomSampler(BuiltinSampler): | |||
| return cde.MindrecordSubsetRandomSampler(self.indices) | |||
| def get_dataset_size(self): | |||
| return len(indices) | |||
| class WeightedRandomSampler(BuiltinSampler): | |||
| """ | |||
| Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). | |||
| @@ -522,3 +541,6 @@ class WeightedRandomSampler(BuiltinSampler): | |||
| return False | |||
| return self.child_sampler.is_sharded() | |||
| def get_dataset_size(self): | |||
| return self.num_samples | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| @@ -164,6 +165,35 @@ def test_python_sampler(): | |||
| assert list(sp1.get_indices()) == [0, 1, 2, 3, 4] | |||
| def test_subset_sampler(): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| def test_config(num_samples, start_index, subset_size): | |||
| sampler = ds.SubsetSampler(start_index, subset_size) | |||
| d = ds.ManifestDataset(manifest_file, sampler=sampler) | |||
| res = [] | |||
| for item in d.create_dict_iterator(): | |||
| res.append(map[(item["image"].shape[0], item["label"].item())]) | |||
| return res | |||
| with pytest.raises(RuntimeError) as info: | |||
| test_config(5, 0, 0) | |||
| assert "subset_size <= 0" in str(info.value) | |||
| assert test_config(5, 0, 1) == [0] | |||
| assert test_config(5, 0, 2) == [0, 1] | |||
| assert test_config(5, 0, 3) == [0, 1, 2] | |||
| assert test_config(5, 0, 4) == [0, 1, 2, 3] | |||
| assert test_config(5, 0, 5) == [0, 1, 2, 3, 4] | |||
| assert test_config(5, 1, 1) == [1] | |||
| assert test_config(5, 2, 3) == [2, 3, 4] | |||
| assert test_config(5, 3, 2) == [3, 4] | |||
| assert test_config(5, 4, 1) == [4] | |||
| def test_sampler_chain(): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| @@ -190,10 +220,26 @@ def test_sampler_chain(): | |||
| assert test_config(5, 3) == [3] | |||
| assert test_config(5, 4) == [4] | |||
| def test_add_sampler_invalid_input(): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| data1 = ds.ManifestDataset(manifest_file) | |||
| with pytest.raises(TypeError) as info: | |||
| data1.use_sampler(1) | |||
| assert "not an instance of a sampler" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| data1.use_sampler("sampler") | |||
| assert "not an instance of a sampler" in str(info.value) | |||
| if __name__ == '__main__': | |||
| test_sequential_sampler(True) | |||
| test_random_sampler(True) | |||
| test_random_sampler_multi_iter(True) | |||
| test_sampler_py_api() | |||
| test_python_sampler() | |||
| test_subset_sampler() | |||
| test_sampler_chain() | |||
| test_add_sampler_invalid_input() | |||
| @@ -23,7 +23,11 @@ from util import config_get_set_num_parallel_workers | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| def split_with_invalid_inputs(d): | |||
| text_file_dataset_path = "../data/dataset/testTextFileDataset/*" | |||
| text_file_data = ["This is a text file.", "Another file.", "Be happy every day.", | |||
| "End of file.", "Good luck to everyone."] | |||
| def split_with_invalid_inputs(d): | |||
| with pytest.raises(ValueError) as info: | |||
| s1, s2 = d.split([]) | |||
| assert "sizes cannot be empty" in str(info.value) | |||
| @@ -68,8 +72,8 @@ def split_with_invalid_inputs(d): | |||
| s1, s2 = d.split([0.05, 0.95]) | |||
| assert "percentage 0.05 is too small" in str(info.value) | |||
| def test_unmappable_invalid_input(): | |||
| text_file_dataset_path = "../data/dataset/testTextFileDataset/*" | |||
| d = ds.TextFileDataset(text_file_dataset_path) | |||
| split_with_invalid_inputs(d) | |||
| @@ -78,11 +82,10 @@ def test_unmappable_invalid_input(): | |||
| s1, s2 = d.split([4, 1]) | |||
| assert "dataset should not be sharded before split" in str(info.value) | |||
| def test_unmappable_split(): | |||
| text_file_dataset_path = "../data/dataset/testTextFileDataset/*" | |||
| text_file_data = ["This is a text file.", "Another file.", "Be happy every day.", | |||
| "End of file.", "Good luck to everyone."] | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(4) | |||
| d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) | |||
| s1, s2 = d.split([4, 1], randomize=False) | |||
| @@ -124,6 +127,142 @@ def test_unmappable_split(): | |||
| assert s1_output == text_file_data[0:2] | |||
| assert s2_output == text_file_data[2:] | |||
| # Restore configuration num_parallel_workers | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_unmappable_randomize_deterministic(): | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(4) | |||
| # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3] | |||
| ds.config.set_seed(53) | |||
| d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) | |||
| s1, s2 = d.split([0.8, 0.2]) | |||
| for _ in range(10): | |||
| s1_output = [] | |||
| for item in s1.create_dict_iterator(): | |||
| s1_output.append(item["text"].item().decode("utf8")) | |||
| s2_output = [] | |||
| for item in s2.create_dict_iterator(): | |||
| s2_output.append(item["text"].item().decode("utf8")) | |||
| # note no overlap | |||
| assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] | |||
| assert s2_output == [text_file_data[3]] | |||
| # Restore configuration num_parallel_workers | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_unmappable_randomize_repeatable(): | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(4) | |||
| # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3] | |||
| ds.config.set_seed(53) | |||
| d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) | |||
| s1, s2 = d.split([0.8, 0.2]) | |||
| num_epochs = 5 | |||
| s1 = s1.repeat(num_epochs) | |||
| s2 = s2.repeat(num_epochs) | |||
| s1_output = [] | |||
| for item in s1.create_dict_iterator(): | |||
| s1_output.append(item["text"].item().decode("utf8")) | |||
| s2_output = [] | |||
| for item in s2.create_dict_iterator(): | |||
| s2_output.append(item["text"].item().decode("utf8")) | |||
| # note no overlap | |||
| assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs | |||
| assert s2_output == [text_file_data[3]] * num_epochs | |||
| # Restore configuration num_parallel_workers | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_unmappable_get_dataset_size(): | |||
| d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) | |||
| s1, s2 = d.split([0.8, 0.2]) | |||
| assert d.get_dataset_size() == 5 | |||
| assert s1.get_dataset_size() == 4 | |||
| assert s2.get_dataset_size() == 1 | |||
| def test_unmappable_multi_split(): | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(4) | |||
| # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3] | |||
| ds.config.set_seed(53) | |||
| d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) | |||
| s1, s2 = d.split([4, 1]) | |||
| s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] | |||
| s1_output = [] | |||
| for item in s1.create_dict_iterator(): | |||
| s1_output.append(item["text"].item().decode("utf8")) | |||
| assert s1_output == s1_correct_output | |||
| # no randomize in second split | |||
| s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False) | |||
| s1s1_output = [] | |||
| for item in s1s1.create_dict_iterator(): | |||
| s1s1_output.append(item["text"].item().decode("utf8")) | |||
| s1s2_output = [] | |||
| for item in s1s2.create_dict_iterator(): | |||
| s1s2_output.append(item["text"].item().decode("utf8")) | |||
| s1s3_output = [] | |||
| for item in s1s3.create_dict_iterator(): | |||
| s1s3_output.append(item["text"].item().decode("utf8")) | |||
| assert s1s1_output == [s1_correct_output[0]] | |||
| assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]] | |||
| assert s1s3_output == [s1_correct_output[3]] | |||
| s2_output = [] | |||
| for item in s2.create_dict_iterator(): | |||
| s2_output.append(item["text"].item().decode("utf8")) | |||
| assert s2_output == [text_file_data[3]] | |||
| # randomize in second split | |||
| # the labels outputted by the ShuffleOp for seed 53 is [2, 3, 1, 0] | |||
| shuffled_ids = [2, 3, 1, 0] | |||
| s1s1, s1s2, s1s3 = s1.split([1, 2, 1]) | |||
| s1s1_output = [] | |||
| for item in s1s1.create_dict_iterator(): | |||
| s1s1_output.append(item["text"].item().decode("utf8")) | |||
| s1s2_output = [] | |||
| for item in s1s2.create_dict_iterator(): | |||
| s1s2_output.append(item["text"].item().decode("utf8")) | |||
| s1s3_output = [] | |||
| for item in s1s3.create_dict_iterator(): | |||
| s1s3_output.append(item["text"].item().decode("utf8")) | |||
| assert s1s1_output == [s1_correct_output[shuffled_ids[0]]] | |||
| assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]] | |||
| assert s1s3_output == [s1_correct_output[shuffled_ids[3]]] | |||
| s2_output = [] | |||
| for item in s2.create_dict_iterator(): | |||
| s2_output.append(item["text"].item().decode("utf8")) | |||
| assert s2_output == [text_file_data[3]] | |||
| # Restore configuration num_parallel_workers | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| @@ -137,6 +276,7 @@ def test_mappable_invalid_input(): | |||
| s1, s2 = d.split([4, 1]) | |||
| assert "dataset should not be sharded before split" in str(info.value) | |||
| def test_mappable_split_general(): | |||
| d = ds.ManifestDataset(manifest_file, shuffle=False) | |||
| d = d.take(5) | |||
| @@ -183,6 +323,7 @@ def test_mappable_split_general(): | |||
| assert s1_output == [0, 1] | |||
| assert s2_output == [2, 3, 4] | |||
| def test_mappable_split_optimized(): | |||
| d = ds.ManifestDataset(manifest_file, shuffle=False) | |||
| @@ -228,9 +369,9 @@ def test_mappable_split_optimized(): | |||
| assert s1_output == [0, 1] | |||
| assert s2_output == [2, 3, 4] | |||
| def test_mappable_randomize_deterministic(): | |||
| # set arbitrary seed for shard after split | |||
| # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4] | |||
| # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] | |||
| ds.config.set_seed(53) | |||
| d = ds.ManifestDataset(manifest_file, shuffle=False) | |||
| @@ -249,9 +390,9 @@ def test_mappable_randomize_deterministic(): | |||
| assert s1_output == [0, 1, 3, 4] | |||
| assert s2_output == [2] | |||
| def test_mappable_randomize_repeatable(): | |||
| # set arbitrary seed for shard after split | |||
| # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4] | |||
| # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] | |||
| ds.config.set_seed(53) | |||
| d = ds.ManifestDataset(manifest_file, shuffle=False) | |||
| @@ -273,9 +414,10 @@ def test_mappable_randomize_repeatable(): | |||
| assert s1_output == [0, 1, 3, 4] * num_epochs | |||
| assert s2_output == [2] * num_epochs | |||
| def test_mappable_sharding(): | |||
| # set arbitrary seed for repeatability for shard after split | |||
| # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4] | |||
| # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] | |||
| ds.config.set_seed(53) | |||
| num_epochs = 5 | |||
| @@ -336,12 +478,94 @@ def test_mappable_sharding(): | |||
| assert s2_output == [2] | |||
| assert d2s2_output == [2] | |||
| def test_mappable_get_dataset_size(): | |||
| d = ds.ManifestDataset(manifest_file, shuffle=False) | |||
| s1, s2 = d.split([4, 1]) | |||
| assert d.get_dataset_size() == 5 | |||
| assert s1.get_dataset_size() == 4 | |||
| assert s2.get_dataset_size() == 1 | |||
| def test_mappable_multi_split(): | |||
| # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] | |||
| ds.config.set_seed(53) | |||
| d = ds.ManifestDataset(manifest_file, shuffle=False) | |||
| s1, s2 = d.split([4, 1]) | |||
| s1_correct_output = [0, 1, 3, 4] | |||
| s1_output = [] | |||
| for item in s1.create_dict_iterator(): | |||
| s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) | |||
| assert s1_output == s1_correct_output | |||
| # no randomize in second split | |||
| s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False) | |||
| s1s1_output = [] | |||
| for item in s1s1.create_dict_iterator(): | |||
| s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) | |||
| s1s2_output = [] | |||
| for item in s1s2.create_dict_iterator(): | |||
| s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) | |||
| s1s3_output = [] | |||
| for item in s1s3.create_dict_iterator(): | |||
| s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) | |||
| assert s1s1_output == [s1_correct_output[0]] | |||
| assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]] | |||
| assert s1s3_output == [s1_correct_output[3]] | |||
| s2_output = [] | |||
| for item in s2.create_dict_iterator(): | |||
| s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) | |||
| assert s2_output == [2] | |||
| # randomize in second split | |||
| # the labels outputted by the RandomSampler for seed 53 is [3, 1, 2, 0] | |||
| random_sampler_ids = [3, 1, 2, 0] | |||
| s1s1, s1s2, s1s3 = s1.split([1, 2, 1]) | |||
| s1s1_output = [] | |||
| for item in s1s1.create_dict_iterator(): | |||
| s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) | |||
| s1s2_output = [] | |||
| for item in s1s2.create_dict_iterator(): | |||
| s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) | |||
| s1s3_output = [] | |||
| for item in s1s3.create_dict_iterator(): | |||
| s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) | |||
| assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]] | |||
| assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]] | |||
| assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]] | |||
| s2_output = [] | |||
| for item in s2.create_dict_iterator(): | |||
| s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) | |||
| assert s2_output == [2] | |||
| if __name__ == '__main__': | |||
| test_unmappable_invalid_input() | |||
| test_unmappable_split() | |||
| test_unmappable_randomize_deterministic() | |||
| test_unmappable_randomize_repeatable() | |||
| test_unmappable_get_dataset_size() | |||
| test_unmappable_multi_split() | |||
| test_mappable_invalid_input() | |||
| test_mappable_split_general() | |||
| test_mappable_split_optimized() | |||
| test_mappable_randomize_deterministic() | |||
| test_mappable_randomize_repeatable() | |||
| test_mappable_sharding() | |||
| test_mappable_get_dataset_size() | |||
| test_mappable_multi_split() | |||