|
|
|
@@ -302,6 +302,19 @@ class Dataset: |
|
|
|
>>> # Create a dataset where every 100 rows is combined into a batch |
|
|
|
>>> # and drops the last incomplete batch if there is one. |
|
|
|
>>> data = data.batch(100, True) |
|
|
|
>>> |
|
|
|
>>> # resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25) |
|
|
|
>>> def np_resize(col, batchInfo): |
|
|
|
>>> output = col.copy() |
|
|
|
>>> s = (batchInfo.get_batch_num() + 1) ** 2 |
|
|
|
>>> index = 0 |
|
|
|
>>> for c in col: |
|
|
|
>>> img = Image.fromarray(c.astype('uint8')).convert('RGB') |
|
|
|
>>> img = img.resize((s, s), Image.ANTIALIAS) |
|
|
|
>>> output[index] = np.array(img) |
|
|
|
>>> index += 1 |
|
|
|
>>> return (output,) |
|
|
|
>>> data = data.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize) |
|
|
|
""" |
|
|
|
return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns, |
|
|
|
output_columns, column_order, pad_info) |
|
|
|
|