diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/bindings.cc index 54844b8a2a..2596c4b61e 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/bindings.cc @@ -24,7 +24,18 @@ PYBIND_REGISTER(CBatchInfo, 0, ([](const py::module *m) { (void)py::class_(*m, "CBatchInfo") .def(py::init()) .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num) - .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num); + .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num) + .def(py::pickle( + [](const BatchOp::CBatchInfo &p) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(p.epoch_num_, p.batch_num_, p.total_batch_num_); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 3) throw std::runtime_error("Invalid state!"); + /* Create a new C++ instance */ + BatchOp::CBatchInfo p(t[0].cast(), t[1].cast(), t[2].cast()); + return p; + })); })); PYBIND_REGISTER(DatasetOp, 0, ([](const py::module *m) { diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index b5725cf709..83c1749bf5 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -158,22 +158,14 @@ class Dataset: if len(self.parent) > 1: raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)") ir_children = [d.parse_tree() for d in self.children] + # Bootstrap can only be performed on a copy of the original dataset node. + # Bootstrap on original dataset node will make all iterators share the same process pool + self.iterator_bootstrap() ir_node = self.parse(ir_children) - return self._alter_node(ir_node) + return ir_node - @staticmethod - def _alter_node(node): - """ - Internal method to add process pool to copied map node. - Returns: - DatasetNode. The altered node. - """ - if isinstance(node, MapDataset): - if node.python_multiprocessing: - # Bootstrap can only be performed on a copy of the original dataset node. - # Bootstrap on original dataset node will make all iterators share the same process pool - node.iterator_bootstrap() - return node + def iterator_bootstrap(self): + pass @staticmethod def _noop_mode(): @@ -272,7 +264,7 @@ class Dataset: @check_batch def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, - input_columns=None, output_columns=None, column_order=None, pad_info=None): + input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False): """ Combine batch_size number of consecutive rows into batches. @@ -312,6 +304,8 @@ class Dataset: same). pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)} would pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0. + python_multiprocessing (bool, optional): Parallelize Python function per_batch_map with multiple worker + processes. This option could be beneficial if the function is computational heavy (default=False). Returns: BatchDataset, dataset batched. @@ -339,7 +333,7 @@ class Dataset: >>> data = data.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize) """ return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns, - output_columns, column_order, pad_info) + output_columns, column_order, pad_info, python_multiprocessing) @check_sync_wait def sync_wait(self, condition_name, num_batch=1, callback=None): @@ -1835,7 +1829,8 @@ class BatchDataset(Dataset): """ def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, - per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None): + per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None, + python_multiprocessing=False): super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers) if BatchDataset._is_ancestor_of_repeat(input_dataset): @@ -1858,6 +1853,10 @@ class BatchDataset(Dataset): self.pad = bool(pad_info is not None) self.pad_info = replace_none(pad_info, dict()) + self.python_multiprocessing = python_multiprocessing + self.process_pool = None + self.hook = None + def parse(self, children=None): return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad, self.input_columns, self.output_columns, @@ -1923,9 +1922,32 @@ class BatchDataset(Dataset): new_op.output_columns = copy.deepcopy(self.output_columns, memodict) new_op.column_order = copy.deepcopy(self.column_order, memodict) new_op.pad = self.pad + new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) + new_op.hook = copy.deepcopy(self.hook, memodict) new_op.pad_info = copy.deepcopy(self.pad_info, memodict) return new_op + # Iterator bootstrap will be called on iterator construction. + # A deep copy of Dataset object is created prior of iterator_bootstrap. + # This method will create per iterator process pool and bind pyfunc execution to the pool. + def iterator_bootstrap(self): + """ + Per iterator bootstrap callback. + """ + if self.python_multiprocessing: + # Construct pool with the callable list + # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses + self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, + initializer=_pyfunc_worker_init, + initargs=([self.per_batch_map],)) + idx = 0 + # Wrap per_batch_map into _PythonCallable + self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) + + def __del__(self): + if hasattr(self, 'process_pool') and self.process_pool is not None: + self.process_pool.close() + class BatchInfo(cde.CBatchInfo): """ @@ -2352,7 +2374,6 @@ class MapDataset(Dataset): # CPP ops remain the same iter_specific_operations.append(op) self.operations = iter_specific_operations - self.hook = _ExceptHookHandler(self.process_pool) def __del__(self): if hasattr(self, 'process_pool') and self.process_pool is not None: diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index cf1baccacb..9ca1dcb2c6 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -538,7 +538,7 @@ def check_batch(method): @wraps(method) def new_method(self, *args, **kwargs): [batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns, output_columns, - column_order, pad_info], param_dict = parse_user_args(method, *args, **kwargs) + column_order, pad_info, python_multiprocessing], param_dict = parse_user_args(method, *args, **kwargs) if not (isinstance(batch_size, int) or (callable(batch_size))): raise TypeError("batch_size should either be an int or a callable.") @@ -577,6 +577,9 @@ def check_batch(method): if column_order is not None: check_columns(column_order, "column_order") + if python_multiprocessing is not None: + type_check(python_multiprocessing, (bool,), "python_multiprocessing") + return method(self, *args, **kwargs) return new_method diff --git a/tests/ut/python/dataset/test_var_batch_map_multi.py b/tests/ut/python/dataset/test_var_batch_map_multi.py new file mode 100644 index 0000000000..0f5e44076b --- /dev/null +++ b/tests/ut/python/dataset/test_var_batch_map_multi.py @@ -0,0 +1,477 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import os +import time +import numpy as np + +import mindspore.dataset as ds +from mindspore import log as logger + +from mindspore.dataset.transforms.py_transforms import Compose +import mindspore.dataset.vision.py_transforms as py_vision + + +def test_batch_corner_cases(): + def gen(num): + for i in range(num): + yield (np.array([i]),) + + def test_repeat_batch(gen_num, repeats, batch_size, drop, res): + data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(repeats).batch(batch_size, drop) + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + res.append(item["num"]) + + def test_batch_repeat(gen_num, repeats, batch_size, drop, res): + data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, drop).repeat(repeats) + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + res.append(item["num"]) + + tst1, tst2, tst3, tst4 = [], [], [], [] + # case 1 & 2, where batch_size is greater than the entire epoch, with drop equals to both val + test_repeat_batch(gen_num=2, repeats=4, batch_size=7, drop=False, res=tst1) + np.testing.assert_array_equal(np.array([[0], [1], [0], [1], [0], [1], [0]]), tst1[0], "\nATTENTION BATCH FAILED\n") + np.testing.assert_array_equal(np.array([[1]]), tst1[1], "\nATTENTION TEST BATCH FAILED\n") + assert len(tst1) == 2, "\nATTENTION TEST BATCH FAILED\n" + test_repeat_batch(gen_num=2, repeats=4, batch_size=5, drop=True, res=tst2) + np.testing.assert_array_equal(np.array([[0], [1], [0], [1], [0]]), tst2[0], "\nATTENTION BATCH FAILED\n") + assert len(tst2) == 1, "\nATTENTION TEST BATCH FAILED\n" + # case 3 & 4, batch before repeat with different drop + test_batch_repeat(gen_num=5, repeats=2, batch_size=4, drop=True, res=tst3) + np.testing.assert_array_equal(np.array([[0], [1], [2], [3]]), tst3[0], "\nATTENTION BATCH FAILED\n") + np.testing.assert_array_equal(tst3[0], tst3[1], "\nATTENTION BATCH FAILED\n") + assert len(tst3) == 2, "\nATTENTION BATCH FAILED\n" + test_batch_repeat(gen_num=5, repeats=2, batch_size=4, drop=False, res=tst4) + np.testing.assert_array_equal(np.array([[0], [1], [2], [3]]), tst4[0], "\nATTENTION BATCH FAILED\n") + np.testing.assert_array_equal(tst4[0], tst4[2], "\nATTENTION BATCH FAILED\n") + np.testing.assert_array_equal(tst4[1], np.array([[4]]), "\nATTENTION BATCH FAILED\n") + np.testing.assert_array_equal(tst4[1], tst4[3], "\nATTENTION BATCH FAILED\n") + assert len(tst4) == 4, "\nATTENTION BATCH FAILED\n" + + +# each sub-test in this function is tested twice with exact parameter except that the second test passes each row +# to a pyfunc which makes a deep copy of the row +def test_variable_size_batch(): + def check_res(arr1, arr2): + for ind, _ in enumerate(arr1): + if not np.array_equal(arr1[ind], np.array(arr2[ind])): + return False + return len(arr1) == len(arr2) + + def gen(num): + for i in range(num): + yield (np.array([i]),) + + def add_one_by_batch_num(batchInfo): + return batchInfo.get_batch_num() + 1 + + def add_one_by_epoch(batchInfo): + return batchInfo.get_epoch_num() + 1 + + def simple_copy(colList, batchInfo): + _ = batchInfo + return ([np.copy(arr) for arr in colList],) + + def test_repeat_batch(gen_num, r, drop, func, res): + data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(r).batch(batch_size=func, + drop_remainder=drop) + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + res.append(item["num"]) + + # same as test_repeat_batch except each row is passed through via a map which makes a copy of each element + def test_repeat_batch_with_copy_map(gen_num, r, drop, func): + res = [] + data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(r) \ + .batch(batch_size=func, drop_remainder=drop, input_columns=["num"], per_batch_map=simple_copy, + python_multiprocessing=True) + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + res.append(item["num"]) + return res + + def test_batch_repeat(gen_num, r, drop, func, res): + data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size=func, drop_remainder=drop).repeat( + r) + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + res.append(item["num"]) + + # same as test_batch_repeat except each row is passed through via a map which makes a copy of each element + def test_batch_repeat_with_copy_map(gen_num, r, drop, func): + res = [] + data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]) \ + .batch(batch_size=func, drop_remainder=drop, input_columns=["num"], per_batch_map=simple_copy, + python_multiprocessing=True).repeat(r) + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + res.append(item["num"]) + return res + + tst1, tst2, tst3, tst4, tst5, tst6, tst7 = [], [], [], [], [], [], [] + + # no repeat, simple var size, based on batch_num + test_repeat_batch(7, 1, True, add_one_by_batch_num, tst1) + assert check_res(tst1, [[[0]], [[1], [2]], [[3], [4], [5]]]), "\nATTENTION VAR BATCH FAILED\n" + assert check_res(tst1, test_repeat_batch_with_copy_map(7, 1, True, add_one_by_batch_num)), "\nMAP FAILED\n" + test_repeat_batch(9, 1, False, add_one_by_batch_num, tst2) + assert check_res(tst2, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [7], [8]]]), "\nATTENTION VAR BATCH FAILED\n" + assert check_res(tst2, test_repeat_batch_with_copy_map(9, 1, False, add_one_by_batch_num)), "\nMAP FAILED\n" + # batch after repeat, cross epoch batch + test_repeat_batch(7, 2, False, add_one_by_batch_num, tst3) + assert check_res(tst3, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [0], [1], [2]], + [[3], [4], [5], [6]]]), "\nATTENTION VAR BATCH FAILED\n" + assert check_res(tst3, test_repeat_batch_with_copy_map(7, 2, False, add_one_by_batch_num)), "\nMAP FAILED\n" + # repeat after batch, no cross epoch batch, remainder dropped + test_batch_repeat(9, 7, True, add_one_by_batch_num, tst4) + assert check_res(tst4, [[[0]], [[1], [2]], [[3], [4], [5]]] * 7), "\nATTENTION VAR BATCH FAILED\n" + assert check_res(tst4, test_batch_repeat_with_copy_map(9, 7, True, add_one_by_batch_num)), "\nAMAP FAILED\n" + # repeat after batch, no cross epoch batch, remainder kept + test_batch_repeat(9, 3, False, add_one_by_batch_num, tst5) + assert check_res(tst5, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [7], [8]]] * 3), "\nATTENTION VAR BATCH FAILED\n" + assert check_res(tst5, test_batch_repeat_with_copy_map(9, 3, False, add_one_by_batch_num)), "\nMAP FAILED\n" + # batch_size based on epoch number, drop + test_batch_repeat(4, 4, True, add_one_by_epoch, tst6) + assert check_res(tst6, [[[0]], [[1]], [[2]], [[3]], [[0], [1]], [[2], [3]], [[0], [1], [2]], + [[0], [1], [2], [3]]]), "\nATTENTION VAR BATCH FAILED\n" + assert check_res(tst6, test_batch_repeat_with_copy_map(4, 4, True, add_one_by_epoch)), "\nMAP FAILED\n" + # batch_size based on epoch number, no drop + test_batch_repeat(4, 4, False, add_one_by_epoch, tst7) + assert check_res(tst7, [[[0]], [[1]], [[2]], [[3]], [[0], [1]], [[2], [3]], [[0], [1], [2]], [[3]], + [[0], [1], [2], [3]]]), "\nATTENTION VAR BATCH FAILED\n" + str(tst7) + assert check_res(tst7, test_batch_repeat_with_copy_map(4, 4, False, add_one_by_epoch)), "\nMAP FAILED\n" + + +def test_basic_batch_map(): + def check_res(arr1, arr2): + for ind, _ in enumerate(arr1): + if not np.array_equal(arr1[ind], np.array(arr2[ind])): + return False + return len(arr1) == len(arr2) + + def gen(num): + for i in range(num): + yield (np.array([i]),) + + def invert_sign_per_epoch(colList, batchInfo): + return ([np.copy(((-1) ** batchInfo.get_epoch_num()) * arr) for arr in colList],) + + def invert_sign_per_batch(colList, batchInfo): + return ([np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in colList],) + + def batch_map_config(num, r, batch_size, func, res): + data1 = ds.GeneratorDataset((lambda: gen(num)), ["num"]) \ + .batch(batch_size=batch_size, input_columns=["num"], per_batch_map=func, + python_multiprocessing=True).repeat(r) + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + res.append(item["num"]) + + tst1, tst2, = [], [] + batch_map_config(4, 2, 2, invert_sign_per_epoch, tst1) + assert check_res(tst1, [[[0], [1]], [[2], [3]], [[0], [-1]], [[-2], [-3]]]), "\nATTENTION MAP BATCH FAILED\n" + str( + tst1) + # each batch, the sign of a row is changed, test map is corrected performed according to its batch_num + batch_map_config(4, 2, 2, invert_sign_per_batch, tst2) + assert check_res(tst2, + [[[0], [1]], [[-2], [-3]], [[0], [1]], [[-2], [-3]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst2) + + +def test_batch_multi_col_map(): + def check_res(arr1, arr2): + for ind, _ in enumerate(arr1): + if not np.array_equal(arr1[ind], np.array(arr2[ind])): + return False + return len(arr1) == len(arr2) + + def gen(num): + for i in range(num): + yield (np.array([i]), np.array([i ** 2])) + + def col1_col2_add_num(col1, col2, batchInfo): + _ = batchInfo + return ([[np.copy(arr + 100) for arr in col1], + [np.copy(arr + 300) for arr in col2]]) + + def invert_sign_per_batch(colList, batchInfo): + return ([np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in colList],) + + def invert_sign_per_batch_multi_col(col1, col2, batchInfo): + return ([np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in col1], + [np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in col2]) + + def batch_map_config(num, r, batch_size, func, col_names, res): + data1 = ds.GeneratorDataset((lambda: gen(num)), ["num", "num_square"]) \ + .batch(batch_size=batch_size, input_columns=col_names, per_batch_map=func, + python_multiprocessing=True).repeat(r) + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + res.append(np.array([item["num"], item["num_square"]])) + + tst1, tst2, tst3, tst4 = [], [], [], [] + batch_map_config(4, 2, 2, invert_sign_per_batch, ["num_square"], tst1) + assert check_res(tst1, [[[[0], [1]], [[0], [1]]], [[[2], [3]], [[-4], [-9]]], [[[0], [1]], [[0], [1]]], + [[[2], [3]], [[-4], [-9]]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst1) + + batch_map_config(4, 2, 2, invert_sign_per_batch_multi_col, ["num", "num_square"], tst2) + assert check_res(tst2, [[[[0], [1]], [[0], [1]]], [[[-2], [-3]], [[-4], [-9]]], [[[0], [1]], [[0], [1]]], + [[[-2], [-3]], [[-4], [-9]]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst2) + + # the two tests below verify the order of the map. + # num_square column adds 100, num column adds 300. + batch_map_config(4, 3, 2, col1_col2_add_num, ["num_square", "num"], tst3) + assert check_res(tst3, [[[[300], [301]], [[100], [101]]], + [[[302], [303]], [[104], [109]]]] * 3), "\nATTENTION MAP BATCH FAILED\n" + str(tst3) + # num column adds 100, num_square column adds 300. + batch_map_config(4, 3, 2, col1_col2_add_num, ["num", "num_square"], tst4) + assert check_res(tst4, [[[[100], [101]], [[300], [301]]], + [[[102], [103]], [[304], [309]]]] * 3), "\nATTENTION MAP BATCH FAILED\n" + str(tst4) + + +def test_var_batch_multi_col_map(): + def check_res(arr1, arr2): + for ind, _ in enumerate(arr1): + if not np.array_equal(arr1[ind], np.array(arr2[ind])): + return False + return len(arr1) == len(arr2) + + # gen 3 columns + # first column: 0, 3, 6, 9 ... ... + # second column:1, 4, 7, 10 ... ... + # third column: 2, 5, 8, 11 ... ... + def gen_3_cols(num): + for i in range(num): + yield (np.array([i * 3]), np.array([i * 3 + 1]), np.array([i * 3 + 2])) + + # first epoch batch_size per batch: 1, 2 ,3 ... ... + # second epoch batch_size per batch: 2, 4, 6 ... ... + # third epoch batch_size per batch: 3, 6 ,9 ... ... + def batch_func(batchInfo): + return (batchInfo.get_batch_num() + 1) * (batchInfo.get_epoch_num() + 1) + + # multiply first col by batch_num, multiply second col by -batch_num + def map_func(col1, col2, batchInfo): + return ([np.copy((1 + batchInfo.get_batch_num()) * arr) for arr in col1], + [np.copy(-(1 + batchInfo.get_batch_num()) * arr) for arr in col2]) + + def batch_map_config(num, r, fbatch, fmap, col_names, res): + data1 = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]) \ + .batch(batch_size=fbatch, input_columns=col_names, per_batch_map=fmap, python_multiprocessing=True) \ + .repeat(r) + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + res.append(np.array([item["col1"], item["col2"], item["col3"]])) + + tst1 = [] + tst1_res = [[[[0]], [[-1]], [[2]]], [[[6], [12]], [[-8], [-14]], [[5], [8]]], + [[[27], [36], [45]], [[-30], [-39], [-48]], [[11], [14], [17]]], + [[[72], [84], [96], [108]], [[-76], [-88], [-100], [-112]], [[20], [23], [26], [29]]]] + batch_map_config(10, 1, batch_func, map_func, ["col1", "col2"], tst1) + assert check_res(tst1, tst1_res), "test_var_batch_multi_col_map FAILED" + + +def test_var_batch_var_resize(): + # fake resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25) + def np_psedo_resize(col, batchInfo): + s = (batchInfo.get_batch_num() + 1) ** 2 + return ([np.copy(c[0:s, 0:s, :]) for c in col],) + + def add_one(batchInfo): + return batchInfo.get_batch_num() + 1 + + data1 = ds.ImageFolderDataset("../data/dataset/testPK/data/", num_parallel_workers=4, decode=True) + data1 = data1.batch(batch_size=add_one, drop_remainder=True, input_columns=["image"], per_batch_map=np_psedo_resize, + python_multiprocessing=True) + # i-th batch has shape [i, i^2, i^2, 3] + i = 1 + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + assert item["image"].shape == (i, i ** 2, i ** 2, 3), "\ntest_var_batch_var_resize FAILED\n" + i += 1 + + +def test_exception(): + def gen(num): + for i in range(num): + yield (np.array([i]),) + + def bad_batch_size(batchInfo): + raise StopIteration + # return batchInfo.get_batch_num() + + def bad_map_func(col, batchInfo): + raise StopIteration + # return (col,) + + data1 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(bad_batch_size) + try: + for _ in data1.create_dict_iterator(num_epochs=1): + pass + assert False + except RuntimeError: + pass + + data2 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(4, input_columns=["num"], per_batch_map=bad_map_func, + python_multiprocessing=True) + try: + for _ in data2.create_dict_iterator(num_epochs=1): + pass + assert False + except RuntimeError: + pass + + +def test_multi_col_map(): + def gen_2_cols(num): + for i in range(1, 1 + num): + yield (np.array([i]), np.array([i ** 2])) + + def split_col(col, batchInfo): + return ([np.copy(arr) for arr in col], [np.copy(-arr) for arr in col]) + + def merge_col(col1, col2, batchInfo): + merged = [] + for k, v in enumerate(col1): + merged.append(np.array(v + col2[k])) + return (merged,) + + def swap_col(col1, col2, batchInfo): + return ([np.copy(a) for a in col2], [np.copy(b) for b in col1]) + + def batch_map_config(num, s, f, in_nms, out_nms, col_order=None): + try: + dst = ds.GeneratorDataset((lambda: gen_2_cols(num)), ["col1", "col2"]) + dst = dst.batch(batch_size=s, input_columns=in_nms, output_columns=out_nms, per_batch_map=f, + column_order=col_order, python_multiprocessing=True) + res = [] + for row in dst.create_dict_iterator(num_epochs=1, output_numpy=True): + res.append(row) + return res + except (ValueError, RuntimeError, TypeError) as e: + return str(e) + + # split 1 col into 2 cols + res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"])[0] + assert np.array_equal(res["col1"], [[1], [2]]) + assert np.array_equal(res["col_x"], [[1], [4]]) and np.array_equal(res["col_y"], [[-1], [-4]]) + + # merge 2 cols into 1 col + res = batch_map_config(4, 4, merge_col, ["col1", "col2"], ["merged"])[0] + assert np.array_equal(res["merged"], [[2], [6], [12], [20]]) + + # swap once + res = batch_map_config(3, 3, swap_col, ["col1", "col2"], ["col1", "col2"])[0] + assert np.array_equal(res["col1"], [[1], [4], [9]]) and np.array_equal(res["col2"], [[1], [2], [3]]) + + # swap twice + res = batch_map_config(3, 3, swap_col, ["col1", "col2"], ["col2", "col1"])[0] + assert np.array_equal(res["col2"], [[1], [4], [9]]) and np.array_equal(res["col1"], [[1], [2], [3]]) + + # test project after map + res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"], ["col_x", "col_y", "col1"])[0] + assert list(res.keys()) == ["col_x", "col_y", "col1"] + + # test the insertion order is maintained + res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"], ["col1", "col_x", "col_y"])[0] + assert list(res.keys()) == ["col1", "col_x", "col_y"] + + # test exceptions + assert "output_columns with value 233 is not of type" in batch_map_config(2, 2, split_col, ["col2"], 233) + assert "column_order with value 233 is not of type" in batch_map_config(2, 2, split_col, ["col2"], ["col1"], 233) + assert "output_columns is NOT set correctly" in batch_map_config(2, 2, split_col, ["col2"], ["col1"]) + assert "Incorrect number of columns" in batch_map_config(2, 2, split_col, ["col2"], ["col3", "col4", "col5"]) + assert "col-1 doesn't exist" in batch_map_config(2, 2, split_col, ["col-1"], ["col_x", "col_y"]) + + +def test_exceptions_2(): + def gen(num): + for i in range(num): + yield (np.array([i]),) + + def simple_copy(colList, batchInfo): + return ([np.copy(arr) for arr in colList],) + + def test_wrong_col_name(gen_num, batch_size): + data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=["num1"], + per_batch_map=simple_copy, + python_multiprocessing=True) + try: + for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + pass + return "success" + except RuntimeError as e: + return str(e) + + # test exception where column name is incorrect + assert "error. col:num1 doesn't exist" in test_wrong_col_name(4, 2) + + +IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train" + + +def skip_test_performance(): + def trans(images, batchInfo): + start_time = time.time() + print(os.getppid(), batchInfo.get_batch_num(), time.strftime("%H:%M:%S", time.localtime())) + for _ in range(50): + op = Compose([py_vision.Decode(), py_vision.Resize(20), py_vision.ToTensor()]) + images2 = [op(img) for img in images] + end_time = time.time() + print(os.getppid(), time.strftime("%H:%M:%S", time.localtime()), end_time - start_time) + return (images2,) + + def trans2(img): + start_time = time.time() + img2 = None + print(os.getppid(), time.strftime("%H:%M:%S", time.localtime())) + for _ in range(50): + op = Compose([py_vision.Decode(), py_vision.Resize(20), py_vision.ToTensor()]) + img2 = op(img) + end_time = time.time() + print(os.getppid(), time.strftime("%H:%M:%S", time.localtime()), end_time - start_time) + return img2 + + print(os.getppid()) + data = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, shuffle=False).repeat(10) + print(data.get_dataset_size()) + data = data.batch(1, per_batch_map=trans, input_columns=["image"], num_parallel_workers=12, + python_multiprocessing=True) + data = data.map(operations=trans2, num_parallel_workers=8, python_multiprocessing=False) + start = time.time() + for _ in data: + pass + end = time.time() + + print("Taken= ", end - start) + + +if __name__ == '__main__': + logger.info("Running test_var_batch_map.py test_batch_corner_cases() function") + test_batch_corner_cases() + + logger.info("Running test_var_batch_map.py test_variable_size_batch() function") + test_variable_size_batch() + + logger.info("Running test_var_batch_map.py test_basic_batch_map() function") + test_basic_batch_map() + + logger.info("Running test_var_batch_map.py test_batch_multi_col_map() function") + test_batch_multi_col_map() + + logger.info("Running test_var_batch_map.py tesgit t_var_batch_multi_col_map() function") + test_var_batch_multi_col_map() + + logger.info("Running test_var_batch_map.py test_var_batch_var_resize() function") + test_var_batch_var_resize() + + logger.info("Running test_var_batch_map.py test_exception() function") + test_exception() + + logger.info("Running test_var_batch_map.py test_multi_col_map() function") + test_multi_col_map() + + logger.info("Running test_var_batch_map.py test_exceptions_2() function") + test_exceptions_2()