|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """
- Testing SlicePatches Python API
- """
- import functools
- import numpy as np
-
- import mindspore.dataset as ds
- import mindspore.dataset.vision.c_transforms as c_vision
- import mindspore.dataset.vision.utils as mode
-
- from mindspore import log as logger
- from util import diff_mse, visualize_list
-
- DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
- SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
-
-
- def test_slice_patches_01(plot=False):
- """
- slice rgb image(100, 200) to 4 patches
- """
- slice_to_patches([100, 200], 2, 2, True, plot=plot)
-
-
- def test_slice_patches_02(plot=False):
- """
- no op
- """
- slice_to_patches([100, 200], 1, 1, True, plot=plot)
-
-
- def test_slice_patches_03(plot=False):
- """
- slice rgb image(99, 199) to 4 patches in pad mode
- """
- slice_to_patches([99, 199], 2, 2, True, plot=plot)
-
-
- def test_slice_patches_04(plot=False):
- """
- slice rgb image(99, 199) to 4 patches in drop mode
- """
- slice_to_patches([99, 199], 2, 2, False, plot=plot)
-
-
- def test_slice_patches_05(plot=False):
- """
- slice rgb image(99, 199) to 4 patches in pad mode
- """
- slice_to_patches([99, 199], 2, 2, True, 255, plot=plot)
-
-
- def slice_to_patches(ori_size, num_h, num_w, pad_or_drop, fill_value=0, plot=False):
- """
- Tool function for slice patches
- """
- logger.info("test_slice_patches_pipeline")
-
- cols = ['img' + str(x) for x in range(num_h*num_w)]
- # First dataset
- dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
- decode_op = c_vision.Decode()
- resize_op = c_vision.Resize(ori_size) # H, W
- slice_patches_op = c_vision.SlicePatches(
- num_h, num_w, mode.SliceMode.PAD, fill_value)
- if not pad_or_drop:
- slice_patches_op = c_vision.SlicePatches(
- num_h, num_w, mode.SliceMode.DROP)
- dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
- dataset1 = dataset1.map(operations=resize_op, input_columns=["image"])
- dataset1 = dataset1.map(operations=slice_patches_op,
- input_columns=["image"], output_columns=cols, column_order=cols)
- # Second dataset
- dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
- dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
- dataset2 = dataset2.map(operations=resize_op, input_columns=["image"])
- func_slice_patches = functools.partial(
- slice_patches, num_h=num_h, num_w=num_w, pad_or_drop=pad_or_drop, fill_value=fill_value)
- dataset2 = dataset2.map(operations=func_slice_patches,
- input_columns=["image"], output_columns=cols, column_order=cols)
-
- num_iter = 0
- patches_c = []
- patches_py = []
- for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
- dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
-
- for x in range(num_h*num_w):
- col = "img" + str(x)
- mse = diff_mse(data1[col], data2[col])
- logger.info("slice_patches_{}, mse: {}".format(num_iter + 1, mse))
- assert mse == 0
- patches_c.append(data1[col])
- patches_py.append(data2[col])
- num_iter += 1
- if plot:
- visualize_list(patches_py, patches_c)
-
-
- def test_slice_patches_exception_01():
- """
- Test SlicePatches with invalid parameters
- """
- logger.info("test_Slice_Patches_exception")
- try:
- _ = c_vision.SlicePatches(0, 2)
- except ValueError as e:
- logger.info("Got an exception in SlicePatches: {}".format(str(e)))
- assert "Input num_height is not within" in str(e)
-
- try:
- _ = c_vision.SlicePatches(2, 0)
- except ValueError as e:
- logger.info("Got an exception in SlicePatches: {}".format(str(e)))
- assert "Input num_width is not within" in str(e)
-
- try:
- _ = c_vision.SlicePatches(2, 2, 1)
- except TypeError as e:
- logger.info("Got an exception in SlicePatches: {}".format(str(e)))
- assert "Argument slice_mode with value" in str(e)
-
- try:
- _ = c_vision.SlicePatches(2, 2, mode.SliceMode.PAD, -1)
- except ValueError as e:
- logger.info("Got an exception in SlicePatches: {}".format(str(e)))
- assert "Input fill_value is not within" in str(e)
-
-
- def slice_patches(image, num_h, num_w, pad_or_drop, fill_value):
- """ help function which slice patches with numpy """
- if num_h == 1 and num_w == 1:
- return image
- # (H, W, C)
- H, W, C = image.shape
- patch_h = H // num_h
- patch_w = W // num_w
- if H % num_h != 0:
- if pad_or_drop:
- patch_h += 1
- if W % num_w != 0:
- if pad_or_drop:
- patch_w += 1
- img = image[:, :, :]
- if pad_or_drop:
- img = np.full([patch_h*num_h, patch_w*num_w, C], fill_value, dtype=np.uint8)
- img[:H, :W] = image[:, :, :]
- patches = []
- for top in range(num_h):
- for left in range(num_w):
- patches.append(img[top*patch_h:(top+1)*patch_h,
- left*patch_w:(left+1)*patch_w, :])
-
- return (*patches,)
-
-
- if __name__ == "__main__":
- test_slice_patches_01(plot=True)
- test_slice_patches_02(plot=True)
- test_slice_patches_03(plot=True)
- test_slice_patches_04(plot=True)
- test_slice_patches_05(plot=True)
- test_slice_patches_exception_01()
|