Merge pull request !1061 from jiangzhiwen/dataset/flat_maptags/v0.3.0-alpha
| @@ -268,6 +268,50 @@ class Dataset: | |||||
| """ | """ | ||||
| return ShuffleDataset(self, buffer_size) | return ShuffleDataset(self, buffer_size) | ||||
| def flat_map(self, func): | |||||
| """ | |||||
| Maps `func` to each row in dataset and flatten the result. | |||||
| The specified `func` is a function that must take one 'Ndarray' as input | |||||
| and return a 'Dataset'. | |||||
| Args: | |||||
| func (function): A function that must take one 'Ndarray' as an argument and | |||||
| return a 'Dataset'. | |||||
| Returns: | |||||
| Dataset, applied by the function. | |||||
| Examples: | |||||
| >>> import mindspore.dataset as ds | |||||
| >>> import mindspore.dataset.transforms.nlp.utils as nlp | |||||
| >>> # declare a function which returns a Dataset object | |||||
| >>> def flat_map_func(x): | |||||
| >>> data_dir = nlp.as_text(x[0]) | |||||
| >>> d = ds.ImageFolderDatasetV2(data_dir) | |||||
| >>> return d | |||||
| >>> # data is a Dataset object | |||||
| >>> data = ds.TextFileDataset(DATA_FILE) | |||||
| >>> data = data.flat_map(flat_map_func) | |||||
| Raises: | |||||
| TypeError: If `func` is not a function. | |||||
| TypeError: If `func` doesn't return a Dataset. | |||||
| """ | |||||
| dataset = None | |||||
| if not hasattr(func, '__call__'): | |||||
| raise TypeError("func must be a function.") | |||||
| for row_data in self: | |||||
| if dataset is None: | |||||
| dataset = func(row_data) | |||||
| else: | |||||
| dataset += func(row_data) | |||||
| if not isinstance(dataset, Dataset): | |||||
| raise TypeError("flat_map must return a Dataset object.") | |||||
| return dataset | |||||
| @check_map | @check_map | ||||
| def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, | def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, | ||||
| num_parallel_workers=None, python_multiprocessing=False): | num_parallel_workers=None, python_multiprocessing=False): | ||||
| @@ -0,0 +1,2 @@ | |||||
| ../data/dataset/test_flat_map/images1.txt | |||||
| ../data/dataset/test_flat_map/images2.txt | |||||
| @@ -0,0 +1,3 @@ | |||||
| ../data/dataset/testPK/data | |||||
| ../data/dataset/testImageNetData/train | |||||
| ../data/dataset/testImageNetData2/train | |||||
| @@ -0,0 +1,3 @@ | |||||
| ../data/dataset/testPK/data | |||||
| ../data/dataset/testImageNetData/train | |||||
| ../data/dataset/testImageNetData2/train | |||||
| @@ -0,0 +1,3 @@ | |||||
| ../data/dataset/testPK/data | |||||
| ../data/dataset/testImageNetData/train | |||||
| ../data/dataset/testImageNetData2/train | |||||
| @@ -0,0 +1,72 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import numpy as np | |||||
| import mindspore.dataset as ds | |||||
| DATA_FILE = "../data/dataset/test_flat_map/images1.txt" | |||||
| INDEX_FILE = "../data/dataset/test_flat_map/image_index.txt" | |||||
| def test_flat_map_1(): | |||||
| ''' | |||||
| DATA_FILE records the path of image folders, load the images from them. | |||||
| ''' | |||||
| import mindspore.dataset.transforms.nlp.utils as nlp | |||||
| def flat_map_func(x): | |||||
| data_dir = nlp.as_text(x[0]) | |||||
| d = ds.ImageFolderDatasetV2(data_dir) | |||||
| return d | |||||
| data = ds.TextFileDataset(DATA_FILE) | |||||
| data = data.flat_map(flat_map_func) | |||||
| count = 0 | |||||
| for d in data: | |||||
| assert isinstance(d[0], np.ndarray) | |||||
| count += 1 | |||||
| assert count == 52 | |||||
| def test_flat_map_2(): | |||||
| ''' | |||||
| Flatten 3D structure data | |||||
| ''' | |||||
| import mindspore.dataset.transforms.nlp.utils as nlp | |||||
| def flat_map_func_1(x): | |||||
| data_dir = nlp.as_text(x[0]) | |||||
| d = ds.ImageFolderDatasetV2(data_dir) | |||||
| return d | |||||
| def flat_map_func_2(x): | |||||
| text_file = nlp.as_text(x[0]) | |||||
| d = ds.TextFileDataset(text_file) | |||||
| d = d.flat_map(flat_map_func_1) | |||||
| return d | |||||
| data = ds.TextFileDataset(INDEX_FILE) | |||||
| data = data.flat_map(flat_map_func_2) | |||||
| count = 0 | |||||
| for d in data: | |||||
| assert isinstance(d[0], np.ndarray) | |||||
| count += 1 | |||||
| assert count == 104 | |||||
| if __name__ == "__main__": | |||||
| test_flat_map_1() | |||||
| test_flat_map_2() | |||||