@@ -66,11 +66,10 @@ def _alter_node(node):
 class Iterator:
     """
-    General Iterator over a dataset.
-
-    Attributes:
-        dataset: Dataset to be iterated over
-
+    General Iterator over a dataset.
+
+    Attributes:
+        dataset: Dataset to be iterated over
     """

     def __init__(self, dataset):
@@ -86,6 +85,7 @@ class Iterator:
         root = self.__convert_node_postorder(self.dataset)
         self.depipeline.AssignRootNode(root)
         self.depipeline.LaunchTreeExec()
+        self._index = 0

     def __is_tree_node(self, node):
         """Check if a node is tree node."""
@@ -185,10 +185,7 @@ class Iterator:
             Iterator.__print_local(input_op, level + 1)

     def print(self):
-        """
-        Print the dataset tree
-
-        """
+        """Print the dataset tree"""
         self.__print_local(self.dataset, 0)

     def release(self):
@@ -202,7 +199,10 @@ class Iterator:
     def __next__(self):
         data = self.get_next()
         if not data:
+            if self._index == 0:
+                logger.warning("No records available.")
             raise StopIteration
+        self._index += 1
         return data

     def get_output_shapes(self):
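Taken together with the `_index` counter initialized in the earlier hunk, `__next__` can now tell an empty dataset apart from a normally exhausted one: the warning fires only when iteration stops before the first record is produced. A minimal sketch of the resulting behavior, using a hypothetical `EmptyIterator` stand-in whose `get_next()` never returns a record:

    from mindspore import log as logger

    class EmptyIterator:
        """Hypothetical stand-in mimicking the patched control flow."""
        def __init__(self):
            self._index = 0                      # records returned so far

        def get_next(self):
            return {}                            # simulate a dataset with no records

        def __iter__(self):
            return self

        def __next__(self):
            data = self.get_next()
            if not data:
                if self._index == 0:             # nothing was ever produced
                    logger.warning("No records available.")
                raise StopIteration
            self._index += 1
            return data

    for _ in EmptyIterator():                    # body never runs; warning logged once
        pass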
@@ -234,7 +234,7 @@ class DictIterator(Iterator):
     def get_next(self):
         """
-        Returns the next record in the dataset as dictionary
+        Returns the next record in the dataset as a dictionary.

         Returns:
             Dict, the next record in the dataset.
@@ -260,7 +260,7 @@ class TupleIterator(Iterator):
     def get_next(self):
         """
-        Returns the next record in the dataset as a list
+        Returns the next record in the dataset as a list.

         Returns:
             List, the next record in the dataset.
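The two subclasses differ only in the shape of each record. A short usage sketch, assuming `data_set` is any already-built MindSpore dataset with "data" and "label" columns:

    # DictIterator: each record is a dict keyed by column name
    for item in data_set.create_dict_iterator():
        image, label = item["data"], item["label"]

    # TupleIterator: each record is a list of column values, in column order
    for data, label in data_set.create_tuple_iterator():
        pass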
@@ -328,13 +328,20 @@ class FileWriter:
         self._generator.build()
         self._generator.write_to_db()

+        mindrecord_files = []
+        index_files = []
         # change the file mode to 600
         for item in self._paths:
             if os.path.exists(item):
                 os.chmod(item, stat.S_IRUSR | stat.S_IWUSR)
+                mindrecord_files.append(item)
             index_file = item + ".db"
             if os.path.exists(index_file):
                 os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR)
+                index_files.append(index_file)
+
+        logger.info("The list of mindrecord files created is: {}, and the list of index files is: {}".format(
+            mindrecord_files, index_files))

         return ret
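With this change `commit()` reports exactly which shard files and index files it produced. A minimal usage sketch; the schema and sample records are illustrative, not taken from the patch:

    from mindspore.mindrecord import FileWriter

    writer = FileWriter(file_name="demo.mindrecord", shard_num=2)
    writer.add_schema({"data": {"type": "bytes"}, "label": {"type": "int32"}}, "demo schema")
    writer.write_raw_data([{"data": b"\x00\x01", "label": 0},
                           {"data": b"\x02\x03", "label": 1}])
    # commit() chmods each .mindrecord/.db file to 600 and logs both file lists
    writer.commit()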
@@ -25,6 +25,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
 import numpy as np
 import pytest
 from mindspore._c_dataengine import InterpolationMode
+from mindspore.dataset.transforms.vision import Inter

 from mindspore import log as logger
 import mindspore.dataset as ds
@@ -151,6 +152,51 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 3

+
+def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
+    """Repeat the dataset twice and verify the records are reshuffled between epochs."""
+    columns_list = ["data", "label"]
+    num_readers = 4
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
+    decode_op = vision.Decode()
+    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
+    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
+    data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
+    data_set = data_set.batch(2)
+    data_set = data_set.repeat(2)
+    num_iter = 0
+    labels = []
+    for item in data_set.create_dict_iterator():
+        logger.info("-------------- iteration {} -----------------".format(num_iter))
+        logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
+        logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
+        num_iter += 1
+        # convert the numpy label batch to a plain list so the comparison below is unambiguous
+        labels.append(item["label"].tolist())
+    assert num_iter == 10
+    logger.info("repeat shuffle: {}".format(labels))
+    assert len(labels) == 10
+    # the second epoch holds the same records, but the reshuffle should reorder the batches
+    assert labels[0:5] != labels[5:10]
+
+def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
+    """Batch with a batch size larger than the record count and drop_remainder=True."""
+    columns_list = ["data", "label"]
+    num_readers = 4
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
+    decode_op = vision.Decode()
+    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
+    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
+    data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
+    data_set = data_set.batch(32, drop_remainder=True)
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info("-------------- iteration {} -----------------".format(num_iter))
+        logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
+        logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
+        num_iter += 1
+    # every record falls into the one partial batch, which drop_remainder discards
+    assert num_iter == 0
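The assertion relies on `drop_remainder=True` discarding the final partial batch; since the fixture provides fewer records than the batch size, nothing survives. For contrast, a sketch with the default `drop_remainder=False`, assuming the same pipeline is rebuilt up to the batch step, would keep the remainder as a single partial batch:

    data_set = data_set.batch(32, drop_remainder=False)    # keep the partial batch
    num_iter = sum(1 for _ in data_set.create_dict_iterator())
    assert num_iter == 1                                    # one batch holding all records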

 def test_cv_minddataset_issue_888(add_and_remove_cv_file):
     """issue 888 test."""
     columns_list = ["data", "label"]