You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_slice_patches.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing SlicePatches Python API
  17. """
  18. import functools
  19. import numpy as np
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.vision.c_transforms as c_vision
  22. import mindspore.dataset.vision.utils as mode
  23. from mindspore import log as logger
  24. from util import diff_mse, visualize_list
  25. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  26. SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  27. def test_slice_patches_01(plot=False):
  28. """
  29. slice rgb image(100, 200) to 4 patches
  30. """
  31. slice_to_patches([100, 200], 2, 2, True, plot=plot)
  32. def test_slice_patches_02(plot=False):
  33. """
  34. no op
  35. """
  36. slice_to_patches([100, 200], 1, 1, True, plot=plot)
  37. def test_slice_patches_03(plot=False):
  38. """
  39. slice rgb image(99, 199) to 4 patches in pad mode
  40. """
  41. slice_to_patches([99, 199], 2, 2, True, plot=plot)
  42. def test_slice_patches_04(plot=False):
  43. """
  44. slice rgb image(99, 199) to 4 patches in drop mode
  45. """
  46. slice_to_patches([99, 199], 2, 2, False, plot=plot)
  47. def test_slice_patches_05(plot=False):
  48. """
  49. slice rgb image(99, 199) to 4 patches in pad mode
  50. """
  51. slice_to_patches([99, 199], 2, 2, True, 255, plot=plot)
  52. def slice_to_patches(ori_size, num_h, num_w, pad_or_drop, fill_value=0, plot=False):
  53. """
  54. Tool function for slice patches
  55. """
  56. logger.info("test_slice_patches_pipeline")
  57. cols = ['img' + str(x) for x in range(num_h*num_w)]
  58. # First dataset
  59. dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  60. decode_op = c_vision.Decode()
  61. resize_op = c_vision.Resize(ori_size) # H, W
  62. slice_patches_op = c_vision.SlicePatches(
  63. num_h, num_w, mode.SliceMode.PAD, fill_value)
  64. if not pad_or_drop:
  65. slice_patches_op = c_vision.SlicePatches(
  66. num_h, num_w, mode.SliceMode.DROP)
  67. dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
  68. dataset1 = dataset1.map(operations=resize_op, input_columns=["image"])
  69. dataset1 = dataset1.map(operations=slice_patches_op,
  70. input_columns=["image"], output_columns=cols, column_order=cols)
  71. # Second dataset
  72. dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  73. dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
  74. dataset2 = dataset2.map(operations=resize_op, input_columns=["image"])
  75. func_slice_patches = functools.partial(
  76. slice_patches, num_h=num_h, num_w=num_w, pad_or_drop=pad_or_drop, fill_value=fill_value)
  77. dataset2 = dataset2.map(operations=func_slice_patches,
  78. input_columns=["image"], output_columns=cols, column_order=cols)
  79. num_iter = 0
  80. patches_c = []
  81. patches_py = []
  82. for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
  83. dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  84. for x in range(num_h*num_w):
  85. col = "img" + str(x)
  86. mse = diff_mse(data1[col], data2[col])
  87. logger.info("slice_patches_{}, mse: {}".format(num_iter + 1, mse))
  88. assert mse == 0
  89. patches_c.append(data1[col])
  90. patches_py.append(data2[col])
  91. num_iter += 1
  92. if plot:
  93. visualize_list(patches_py, patches_c)
  94. def test_slice_patches_exception_01():
  95. """
  96. Test SlicePatches with invalid parameters
  97. """
  98. logger.info("test_Slice_Patches_exception")
  99. try:
  100. _ = c_vision.SlicePatches(0, 2)
  101. except ValueError as e:
  102. logger.info("Got an exception in SlicePatches: {}".format(str(e)))
  103. assert "Input num_height is not within" in str(e)
  104. try:
  105. _ = c_vision.SlicePatches(2, 0)
  106. except ValueError as e:
  107. logger.info("Got an exception in SlicePatches: {}".format(str(e)))
  108. assert "Input num_width is not within" in str(e)
  109. try:
  110. _ = c_vision.SlicePatches(2, 2, 1)
  111. except TypeError as e:
  112. logger.info("Got an exception in SlicePatches: {}".format(str(e)))
  113. assert "Argument slice_mode with value" in str(e)
  114. try:
  115. _ = c_vision.SlicePatches(2, 2, mode.SliceMode.PAD, -1)
  116. except ValueError as e:
  117. logger.info("Got an exception in SlicePatches: {}".format(str(e)))
  118. assert "Input fill_value is not within" in str(e)
  119. def test_slice_patches_06():
  120. image = np.random.randint(0, 255, (158, 126, 1)).astype(np.int32)
  121. slice_patches_op = c_vision.SlicePatches(2, 8)
  122. patches = slice_patches_op(image)
  123. assert len(patches) == 16
  124. assert patches[0].shape == (79, 16, 1)
  125. def test_slice_patches_07():
  126. image = np.random.randint(0, 255, (158, 126)).astype(np.int32)
  127. slice_patches_op = c_vision.SlicePatches(2, 8)
  128. patches = slice_patches_op(image)
  129. assert len(patches) == 16
  130. assert patches[0].shape == (79, 16)
  131. def test_slice_patches_08():
  132. np_data = np.random.randint(0, 255, (1, 56, 82, 256)).astype(np.uint8)
  133. dataset = ds.NumpySlicesDataset(np_data, column_names=["image"])
  134. slice_patches_op = c_vision.SlicePatches(2, 2)
  135. dataset = dataset.map(input_columns=["image"], output_columns=["img0", "img1", "img2", "img3"],
  136. column_order=["img0", "img1", "img2", "img3"],
  137. operations=slice_patches_op)
  138. for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  139. patch_shape = item['img0'].shape
  140. assert patch_shape == (28, 41, 256)
  141. def test_slice_patches_09():
  142. image = np.random.randint(0, 255, (56, 82, 256)).astype(np.uint8)
  143. slice_patches_op = c_vision.SlicePatches(4, 3, mode.SliceMode.PAD)
  144. patches = slice_patches_op(image)
  145. assert len(patches) == 12
  146. assert patches[0].shape == (14, 28, 256)
  147. def skip_test_slice_patches_10():
  148. image = np.random.randint(0, 255, (7000, 7000, 255)).astype(np.uint8)
  149. slice_patches_op = c_vision.SlicePatches(10, 13, mode.SliceMode.DROP)
  150. patches = slice_patches_op(image)
  151. assert patches[0].shape == (700, 538, 255)
  152. def skip_test_slice_patches_11():
  153. np_data = np.random.randint(0, 255, (1, 7000, 7000, 256)).astype(np.uint8)
  154. dataset = ds.NumpySlicesDataset(np_data, column_names=["image"])
  155. slice_patches_op = c_vision.SlicePatches(10, 13, mode.SliceMode.DROP)
  156. cols = ['img' + str(x) for x in range(10*13)]
  157. dataset = dataset.map(input_columns=["image"], output_columns=cols,
  158. column_order=cols, operations=slice_patches_op)
  159. for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  160. patch_shape = item['img0'].shape
  161. assert patch_shape == (700, 538, 256)
  162. def slice_patches(image, num_h, num_w, pad_or_drop, fill_value):
  163. """ help function which slice patches with numpy """
  164. if num_h == 1 and num_w == 1:
  165. return image
  166. # (H, W, C)
  167. H, W, C = image.shape
  168. patch_h = H // num_h
  169. patch_w = W // num_w
  170. if H % num_h != 0:
  171. if pad_or_drop:
  172. patch_h += 1
  173. if W % num_w != 0:
  174. if pad_or_drop:
  175. patch_w += 1
  176. img = image[:, :, :]
  177. if pad_or_drop:
  178. img = np.full([patch_h*num_h, patch_w*num_w, C], fill_value, dtype=np.uint8)
  179. img[:H, :W] = image[:, :, :]
  180. patches = []
  181. for top in range(num_h):
  182. for left in range(num_w):
  183. patches.append(img[top*patch_h:(top+1)*patch_h,
  184. left*patch_w:(left+1)*patch_w, :])
  185. return (*patches,)
  186. if __name__ == "__main__":
  187. test_slice_patches_01(plot=True)
  188. test_slice_patches_02(plot=True)
  189. test_slice_patches_03(plot=True)
  190. test_slice_patches_04(plot=True)
  191. test_slice_patches_05(plot=True)
  192. test_slice_patches_06()
  193. test_slice_patches_07()
  194. test_slice_patches_08()
  195. test_slice_patches_09()
  196. test_slice_patches_exception_01()