From f258697901a95b196091d2a96b3f9559fca250b5 Mon Sep 17 00:00:00 2001 From: jonyguo Date: Thu, 22 Oct 2020 14:38:20 +0800 Subject: [PATCH] add per_batch_map usage in api comment --- mindspore/dataset/engine/datasets.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 8bbe8a2daa..a7a8389c67 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -302,6 +302,19 @@ class Dataset: >>> # Create a dataset where every 100 rows is combined into a batch >>> # and drops the last incomplete batch if there is one. >>> data = data.batch(100, True) + >>> + >>> # Resize each image according to its batch number; e.g. for the 5th batch, resize to (5^2, 5^2) = (25, 25) + >>> def np_resize(col, batchInfo): + >>> output = col.copy() + >>> s = (batchInfo.get_batch_num() + 1) ** 2 + >>> index = 0 + >>> for c in col: + >>> img = Image.fromarray(c.astype('uint8')).convert('RGB') + >>> img = img.resize((s, s), Image.ANTIALIAS) + >>> output[index] = np.array(img) + >>> index += 1 + >>> return (output,) + >>> data = data.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize) """ return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns, output_columns, column_order, pad_info)