Merge pull request !1457 from Peilin/splitOp-after-testing
@@ -71,7 +71,7 @@ if __name__ == '__main__':
     model = Model(network, loss, opt, {'acc': Accuracy()})

     print("============== Starting Training ==============")
-    ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs)
+    ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
@@ -70,21 +70,26 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
 }

 Status RandomSampler::InitSampler() {
-  num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
-  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
-  samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
+  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive.");
   rnd_.seed(seed_);
   if (replacement_ == false) {
+    num_samples_ = std::min(num_samples_, num_rows_);
     shuffled_ids_.reserve(num_rows_);
     for (int64_t i = 0; i < num_rows_; i++) {
       shuffled_ids_.push_back(i);
     }
     std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
   } else {
+    num_samples_ = std::min(num_samples_, user_num_samples_);
     dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
   }
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive.");
+  samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
   return Status::OK();
 }
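For orientation, a hedged Python-side sketch of how the two branches above are reached; the dataset path and sample counts are placeholders for illustration only:

    import mindspore.dataset as ds

    # replacement=False: ids are drawn without repetition, so num_samples is capped by num_rows.
    no_replacement = ds.RandomSampler(replacement=False, num_samples=10)

    # replacement=True: ids may repeat, so only the user-requested num_samples applies.
    with_replacement = ds.RandomSampler(replacement=True, num_samples=10)

    data = ds.ImageFolderDatasetV2("/path/to/images", sampler=no_replacement)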
@@ -32,9 +32,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
     }

     // Handshake and init child first.
-    if (HasChildSampler()) {
-      RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
-    }
+    RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
   }

   CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,9 +28,9 @@ SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
     : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}

 Status SubsetSampler::InitSampler() {
-  CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
+  CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n");
   CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
-  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
+  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
   CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");

   num_samples_ = subset_size_;
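These messages are what user code sees; a small Python sketch of the bounds being enforced, using the 5-row manifest fixture from the tests added later in this change:

    import mindspore.dataset as ds

    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"

    # SubsetSampler(start_index, subset_size) reads subset_size consecutive rows starting at
    # start_index, so the whole range must fit inside the dataset.
    good = ds.ManifestDataset(manifest_file, sampler=ds.SubsetSampler(2, 3))  # rows 2, 3, 4

    # subset_size <= 0 (or a range past the last row) is rejected when the pipeline is built
    # and iterated: RuntimeError containing "subset_size <= 0".
    bad = ds.ManifestDataset(manifest_file, sampler=ds.SubsetSampler(0, 0))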
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
     GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
     Schema, Shuffle, zip, RandomDataset
 from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
-    WeightedRandomSampler, Sampler
+    WeightedRandomSampler, SubsetSampler, Sampler
 from .engine.serializer_deserializer import serialize, deserialize, show
 from .engine.graphdata import GraphData
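With the export in place, the new sampler can be imported from the package root, for example:

    from mindspore.dataset import SubsetSampler

    sampler = SubsetSampler(0, 3)  # the first three rows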
@@ -633,9 +633,9 @@ class Dataset:
             Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
             of the original dataset. If after rounding, any size equals 0, an error will occur.
             All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
-            randomize (bool): determines whether or not to split the data randomly. If true, the data
-                will be randomly split. Otherwise, each split will be created with consecutive rows
-                from the dataset.
+            randomize (bool, optional): determines whether or not to split the data randomly (default=True).
+                If true, the data will be randomly split. Otherwise, each split will be created with
+                consecutive rows from the dataset.

         Note:
             1. Dataset cannot be sharded if split is going to be called.
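A short usage sketch for the docstring above (the 80/20 ratio is only an example):

    import mindspore.dataset as ds

    data = ds.TextFileDataset("../data/dataset/testTextFileDataset/*", shuffle=False)

    # randomize=True (default): rows are shuffled the same way every epoch, then split.
    train, test = data.split([0.8, 0.2])

    # randomize=False: the first 80% of rows go to train, the remainder to test.
    train, test = data.split([0.8, 0.2], randomize=False)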
@@ -678,7 +678,8 @@ class Dataset:
         ds = copy.deepcopy(self)
         if randomize:
             # want to shuffle the same way every epoch before split
-            ds = ds.shuffle()
+            # in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
+            ds = ds.shuffle(10000)
             ds.reshuffle_each_epoch = False

         if rows_to_skip > 0:
| @@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset): | |||||
| >>> new_sampler = ds.DistributedSampler(10, 2) | >>> new_sampler = ds.DistributedSampler(10, 2) | ||||
| >>> data.use_sampler(new_sampler) | >>> data.use_sampler(new_sampler) | ||||
| """ | """ | ||||
| if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): | |||||
| raise TypeError("new_sampler is not an instance of a sampler.") | |||||
| self.sampler = self.sampler.child_sampler | self.sampler = self.sampler.child_sampler | ||||
| self.add_sampler(new_sampler) | self.add_sampler(new_sampler) | ||||
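The new guard rejects anything that is not a sampler before the existing sampler chain is touched; a quick sketch of the accepted and rejected calls:

    import mindspore.dataset as ds

    data = ds.ManifestDataset("../data/dataset/testManifestData/test5trainimgs.json")

    data.use_sampler(ds.SequentialSampler())  # accepted: a BuiltinSampler instance
    data.use_sampler("sampler")               # TypeError: new_sampler is not an instance of a sampler.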
@@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset):
     def is_sharded(self):
         raise NotImplementedError("MappableDataset must implement is_sharded.")

+    def _get_sampler_dataset_size(self):
+        if self.sampler is not None:
+            return self.sampler.get_dataset_size()
+
+        return None
+
     @check_split
     def split(self, sizes, randomize=True):
@@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset):
             Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
             of the original dataset. If after rounding, any size equals 0, an error will occur.
             All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
-            randomize (bool): determines whether or not to split the data randomly. If true, the data
-                will be randomly split. Otherwise, each split will be created with consecutive rows
-                from the dataset.
+            randomize (bool, optional): determines whether or not to split the data randomly (default=True).
+                If true, the data will be randomly split. Otherwise, each split will be created with
+                consecutive rows from the dataset.

         Note:
             1. Dataset should not be sharded if split is going to be called. Instead, create a
@@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp):
         self.iterator = TupleIterator(self)

-
 class RangeDataset(MappableDataset):
     """
     A source dataset that reads and parses datasets stored on disk in a range.
| @@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset): | |||||
| else: | else: | ||||
| num_samples = self.num_samples | num_samples = self.num_samples | ||||
| num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0] | num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0] | ||||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||||
| rows_from_sampler = self._get_sampler_dataset_size() | |||||
| return get_num_rows(num_rows, self.num_shards) | |||||
| if rows_from_sampler is None: | |||||
| return rows_per_shard | |||||
| return min(rows_from_sampler, rows_per_shard) | |||||
| def num_classes(self): | def num_classes(self): | ||||
| """ | """ | ||||
@@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset):
             num_samples = self.num_samples
         num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
-        return get_num_rows(num_rows, self.num_shards)
+        rows_per_shard = get_num_rows(num_rows, self.num_shards)
+        rows_from_sampler = self._get_sampler_dataset_size()
+
+        if rows_from_sampler is None:
+            return rows_per_shard
+
+        return min(rows_from_sampler, rows_per_shard)

     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -2926,7 +2944,12 @@ class GeneratorDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        return self._dataset_size
+        rows_from_sampler = self._get_sampler_dataset_size()
+
+        if rows_from_sampler is None:
+            return self._dataset_size
+
+        return min(rows_from_sampler, self._dataset_size)

     # manually set dataset_size as a temporary solution.
     def set_dataset_size(self, value):
@@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset):
             class_indexing = self.class_indexing
         num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
-        return get_num_rows(num_rows, self.num_shards)
+        rows_per_shard = get_num_rows(num_rows, self.num_shards)
+        rows_from_sampler = self._get_sampler_dataset_size()
+
+        if rows_from_sampler is None:
+            return rows_per_shard
+
+        return min(rows_from_sampler, rows_per_shard)

     def num_classes(self):
         """
@@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset):
             num_samples = self.num_samples
         num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
-        return get_num_rows(num_rows, self.num_shards)
+        rows_per_shard = get_num_rows(num_rows, self.num_shards)
+        rows_from_sampler = self._get_sampler_dataset_size()
+
+        if rows_from_sampler is None:
+            return rows_per_shard
+
+        return min(rows_from_sampler, rows_per_shard)

     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset):
             num_samples = self.num_samples
         num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
-        return get_num_rows(num_rows, self.num_shards)
+        rows_per_shard = get_num_rows(num_rows, self.num_shards)
+        rows_from_sampler = self._get_sampler_dataset_size()
+
+        if rows_from_sampler is None:
+            return rows_per_shard
+
+        return min(rows_from_sampler, rows_per_shard)

     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        return num_samples
+        rows_from_sampler = self._get_sampler_dataset_size()
+
+        if rows_from_sampler is None:
+            return self.num_samples
+
+        return min(rows_from_sampler, self.num_samples)

     def is_shuffled(self):
         return True
@@ -3871,7 +3914,12 @@ class VOCDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        return self.num_samples
+        rows_from_sampler = self._get_sampler_dataset_size()
+
+        if rows_from_sampler is None:
+            return self.num_samples
+
+        return min(rows_from_sampler, self.num_samples)

     def get_class_indexing(self):
         """
@@ -114,6 +114,9 @@ class Sampler:
         return self.child_sampler.is_sharded()

+    def get_dataset_size(self):
+        return self._get_indices().size
+

 class BuiltinSampler:
     """
@@ -146,6 +149,12 @@ class BuiltinSampler:
     def is_sharded(self):
         raise NotImplementedError("Sampler must implement is_sharded.")

+    def get_dataset_size(self):
+        if self.child_sampler is not None:
+            return self.child_sampler.get_dataset_size()
+
+        return None
+

 class DistributedSampler(BuiltinSampler):
     """
@@ -330,6 +339,9 @@ class RandomSampler(BuiltinSampler):
         return self.child_sampler.is_sharded()

+    def get_dataset_size(self):
+        return self.num_samples
+

 class SequentialSampler(BuiltinSampler):
     """
@@ -421,6 +433,9 @@ class SubsetSampler(BuiltinSampler):
         return self.child_sampler.is_sharded()

+    def get_dataset_size(self):
+        return self.subset_size
+

 class SubsetRandomSampler(BuiltinSampler):
     """
@@ -467,6 +482,10 @@ class SubsetRandomSampler(BuiltinSampler):
         return cde.MindrecordSubsetRandomSampler(self.indices)

+    def get_dataset_size(self):
+        return len(self.indices)
+

 class WeightedRandomSampler(BuiltinSampler):
     """
     Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
@@ -522,3 +541,6 @@ class WeightedRandomSampler(BuiltinSampler):
             return False

         return self.child_sampler.is_sharded()
+
+    def get_dataset_size(self):
+        return self.num_samples
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
+import pytest

 import mindspore.dataset as ds
 from mindspore import log as logger
@@ -164,6 +165,35 @@ def test_python_sampler():
     assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]

+
+def test_subset_sampler():
+    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
+    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+
+    def test_config(num_samples, start_index, subset_size):
+        sampler = ds.SubsetSampler(start_index, subset_size)
+        d = ds.ManifestDataset(manifest_file, sampler=sampler)
+
+        res = []
+        for item in d.create_dict_iterator():
+            res.append(map[(item["image"].shape[0], item["label"].item())])
+
+        return res
+
+    with pytest.raises(RuntimeError) as info:
+        test_config(5, 0, 0)
+    assert "subset_size <= 0" in str(info.value)
+
+    assert test_config(5, 0, 1) == [0]
+    assert test_config(5, 0, 2) == [0, 1]
+    assert test_config(5, 0, 3) == [0, 1, 2]
+    assert test_config(5, 0, 4) == [0, 1, 2, 3]
+    assert test_config(5, 0, 5) == [0, 1, 2, 3, 4]
+    assert test_config(5, 1, 1) == [1]
+    assert test_config(5, 2, 3) == [2, 3, 4]
+    assert test_config(5, 3, 2) == [3, 4]
+    assert test_config(5, 4, 1) == [4]
+
+
 def test_sampler_chain():
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
     map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

@@ -190,10 +220,26 @@ def test_sampler_chain():
     assert test_config(5, 3) == [3]
     assert test_config(5, 4) == [4]

+
+def test_add_sampler_invalid_input():
+    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
+    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+    data1 = ds.ManifestDataset(manifest_file)
+
+    with pytest.raises(TypeError) as info:
+        data1.use_sampler(1)
+    assert "not an instance of a sampler" in str(info.value)
+
+    with pytest.raises(TypeError) as info:
+        data1.use_sampler("sampler")
+    assert "not an instance of a sampler" in str(info.value)
+
+
 if __name__ == '__main__':
     test_sequential_sampler(True)
     test_random_sampler(True)
     test_random_sampler_multi_iter(True)
     test_sampler_py_api()
     test_python_sampler()
+    test_subset_sampler()
     test_sampler_chain()
+    test_add_sampler_invalid_input()
@@ -23,7 +23,11 @@ from util import config_get_set_num_parallel_workers
 manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
 manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

+text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
+text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
+                  "End of file.", "Good luck to everyone."]
+

 def split_with_invalid_inputs(d):
     with pytest.raises(ValueError) as info:
         s1, s2 = d.split([])
     assert "sizes cannot be empty" in str(info.value)
@@ -68,8 +72,8 @@ def split_with_invalid_inputs(d):
         s1, s2 = d.split([0.05, 0.95])
     assert "percentage 0.05 is too small" in str(info.value)

 def test_unmappable_invalid_input():
-    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
     d = ds.TextFileDataset(text_file_dataset_path)
     split_with_invalid_inputs(d)
@@ -78,11 +82,10 @@ def test_unmappable_invalid_input():
         s1, s2 = d.split([4, 1])
     assert "dataset should not be sharded before split" in str(info.value)

 def test_unmappable_split():
-    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
-    text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
-                      "End of file.", "Good luck to everyone."]
     original_num_parallel_workers = config_get_set_num_parallel_workers(4)
     d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
     s1, s2 = d.split([4, 1], randomize=False)
@@ -124,6 +127,142 @@ def test_unmappable_split():
     assert s1_output == text_file_data[0:2]
     assert s2_output == text_file_data[2:]

+    # Restore configuration num_parallel_workers
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
+
+def test_unmappable_randomize_deterministic():
+    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
+    # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
+    ds.config.set_seed(53)
+
+    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
+    s1, s2 = d.split([0.8, 0.2])
+
+    for _ in range(10):
+        s1_output = []
+        for item in s1.create_dict_iterator():
+            s1_output.append(item["text"].item().decode("utf8"))
+
+        s2_output = []
+        for item in s2.create_dict_iterator():
+            s2_output.append(item["text"].item().decode("utf8"))
+
+        # note no overlap
+        assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
+        assert s2_output == [text_file_data[3]]
+
+    # Restore configuration num_parallel_workers
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
+
+def test_unmappable_randomize_repeatable():
+    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
+    # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
+    ds.config.set_seed(53)
+
+    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
+    s1, s2 = d.split([0.8, 0.2])
+
+    num_epochs = 5
+    s1 = s1.repeat(num_epochs)
+    s2 = s2.repeat(num_epochs)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(item["text"].item().decode("utf8"))
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(item["text"].item().decode("utf8"))
+
+    # note no overlap
+    assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs
+    assert s2_output == [text_file_data[3]] * num_epochs
+
+    # Restore configuration num_parallel_workers
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
+
+def test_unmappable_get_dataset_size():
+    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
+    s1, s2 = d.split([0.8, 0.2])
+
+    assert d.get_dataset_size() == 5
+    assert s1.get_dataset_size() == 4
+    assert s2.get_dataset_size() == 1
+
+
+def test_unmappable_multi_split():
+    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
+    # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
+    ds.config.set_seed(53)
+
+    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
+    s1, s2 = d.split([4, 1])
+
+    s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(item["text"].item().decode("utf8"))
+    assert s1_output == s1_correct_output
+
+    # no randomize in second split
+    s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)
+
+    s1s1_output = []
+    for item in s1s1.create_dict_iterator():
+        s1s1_output.append(item["text"].item().decode("utf8"))
+    s1s2_output = []
+    for item in s1s2.create_dict_iterator():
+        s1s2_output.append(item["text"].item().decode("utf8"))
+    s1s3_output = []
+    for item in s1s3.create_dict_iterator():
+        s1s3_output.append(item["text"].item().decode("utf8"))
+
+    assert s1s1_output == [s1_correct_output[0]]
+    assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
+    assert s1s3_output == [s1_correct_output[3]]
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(item["text"].item().decode("utf8"))
+    assert s2_output == [text_file_data[3]]
+
+    # randomize in second split
+    # the labels outputted by the ShuffleOp for seed 53 is [2, 3, 1, 0]
+    shuffled_ids = [2, 3, 1, 0]
+    s1s1, s1s2, s1s3 = s1.split([1, 2, 1])
+
+    s1s1_output = []
+    for item in s1s1.create_dict_iterator():
+        s1s1_output.append(item["text"].item().decode("utf8"))
+    s1s2_output = []
+    for item in s1s2.create_dict_iterator():
+        s1s2_output.append(item["text"].item().decode("utf8"))
+    s1s3_output = []
+    for item in s1s3.create_dict_iterator():
+        s1s3_output.append(item["text"].item().decode("utf8"))
+
+    assert s1s1_output == [s1_correct_output[shuffled_ids[0]]]
+    assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]]
+    assert s1s3_output == [s1_correct_output[shuffled_ids[3]]]
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(item["text"].item().decode("utf8"))
+    assert s2_output == [text_file_data[3]]
+
     # Restore configuration num_parallel_workers
     ds.config.set_num_parallel_workers(original_num_parallel_workers)
@@ -137,6 +276,7 @@ def test_mappable_invalid_input():
         s1, s2 = d.split([4, 1])
     assert "dataset should not be sharded before split" in str(info.value)

+
 def test_mappable_split_general():
     d = ds.ManifestDataset(manifest_file, shuffle=False)
     d = d.take(5)
@@ -183,6 +323,7 @@ def test_mappable_split_general():
     assert s1_output == [0, 1]
     assert s2_output == [2, 3, 4]

+
 def test_mappable_split_optimized():
     d = ds.ManifestDataset(manifest_file, shuffle=False)
@@ -228,9 +369,9 @@ def test_mappable_split_optimized():
     assert s1_output == [0, 1]
     assert s2_output == [2, 3, 4]

+
 def test_mappable_randomize_deterministic():
-    # set arbitrary seed for shard after split
-    # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
+    # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
     ds.config.set_seed(53)

     d = ds.ManifestDataset(manifest_file, shuffle=False)
@@ -249,9 +390,9 @@ def test_mappable_randomize_deterministic():
     assert s1_output == [0, 1, 3, 4]
     assert s2_output == [2]

+
 def test_mappable_randomize_repeatable():
-    # set arbitrary seed for shard after split
-    # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
+    # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
     ds.config.set_seed(53)

     d = ds.ManifestDataset(manifest_file, shuffle=False)
@@ -273,9 +414,10 @@ def test_mappable_randomize_repeatable():
     assert s1_output == [0, 1, 3, 4] * num_epochs
     assert s2_output == [2] * num_epochs

+
 def test_mappable_sharding():
     # set arbitrary seed for repeatability for shard after split
-    # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
+    # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
     ds.config.set_seed(53)
     num_epochs = 5
@@ -336,12 +478,94 @@ def test_mappable_sharding():
     assert s2_output == [2]
     assert d2s2_output == [2]

+
+def test_mappable_get_dataset_size():
+    d = ds.ManifestDataset(manifest_file, shuffle=False)
+    s1, s2 = d.split([4, 1])
+
+    assert d.get_dataset_size() == 5
+    assert s1.get_dataset_size() == 4
+    assert s2.get_dataset_size() == 1
+
+
+def test_mappable_multi_split():
+    # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
+    ds.config.set_seed(53)
+
+    d = ds.ManifestDataset(manifest_file, shuffle=False)
+    s1, s2 = d.split([4, 1])
+
+    s1_correct_output = [0, 1, 3, 4]
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+    assert s1_output == s1_correct_output
+
+    # no randomize in second split
+    s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)
+
+    s1s1_output = []
+    for item in s1s1.create_dict_iterator():
+        s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+    s1s2_output = []
+    for item in s1s2.create_dict_iterator():
+        s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+    s1s3_output = []
+    for item in s1s3.create_dict_iterator():
+        s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s1s1_output == [s1_correct_output[0]]
+    assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
+    assert s1s3_output == [s1_correct_output[3]]
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+    assert s2_output == [2]
+
+    # randomize in second split
+    # the labels outputted by the RandomSampler for seed 53 is [3, 1, 2, 0]
+    random_sampler_ids = [3, 1, 2, 0]
+    s1s1, s1s2, s1s3 = s1.split([1, 2, 1])
+
+    s1s1_output = []
+    for item in s1s1.create_dict_iterator():
+        s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+    s1s2_output = []
+    for item in s1s2.create_dict_iterator():
+        s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+    s1s3_output = []
+    for item in s1s3.create_dict_iterator():
+        s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]]
+    assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]]
+    assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]]
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+    assert s2_output == [2]
+
+
 if __name__ == '__main__':
     test_unmappable_invalid_input()
     test_unmappable_split()
+    test_unmappable_randomize_deterministic()
+    test_unmappable_randomize_repeatable()
+    test_unmappable_get_dataset_size()
+    test_unmappable_multi_split()
     test_mappable_invalid_input()
     test_mappable_split_general()
     test_mappable_split_optimized()
     test_mappable_randomize_deterministic()
     test_mappable_randomize_repeatable()
     test_mappable_sharding()
+    test_mappable_get_dataset_size()
+    test_mappable_multi_split()