test_compose.py (14 kB)
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.common.dtype as mstype
  18. import mindspore.dataset as ds
  19. import mindspore.dataset.transforms.c_transforms as c_transforms
  20. import mindspore.dataset.transforms.py_transforms as py_transforms
  21. import mindspore.dataset.vision.c_transforms as c_vision
  22. import mindspore.dataset.vision.py_transforms as py_vision
  23. from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers
# When True, save_and_check_md5 regenerates the golden md5 files instead of
# validating against them. Keep False for normal test runs.
GENERATE_GOLDEN = False
  25. def test_compose():
  26. """
  27. Test C++ and Python Compose Op
  28. """
  29. ds.config.set_seed(0)
  30. def test_config(arr, op_list):
  31. try:
  32. data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
  33. data = data.map(input_columns=["col"], operations=op_list)
  34. res = []
  35. for i in data.create_dict_iterator(output_numpy=True):
  36. res.append(i["col"].tolist())
  37. return res
  38. except (TypeError, ValueError) as e:
  39. return str(e)
  40. # Test simple compose with only 1 op, this would generate a warning
  41. assert test_config([[1, 0], [3, 4]], c_transforms.Compose([c_transforms.Fill(2)])) == [[2, 2], [2, 2]]
  42. # Test 1 column -> 2 columns -> 1 -> 2 -> 1
  43. assert test_config([[1, 0]],
  44. c_transforms.Compose(
  45. [c_transforms.Duplicate(), c_transforms.Concatenate(), c_transforms.Duplicate(),
  46. c_transforms.Concatenate()])) \
  47. == [[1, 0] * 4]
  48. # Test one Python transform followed by a C++ transform. Type after OneHot is a float (mixed use-case)
  49. assert test_config([1, 0],
  50. c_transforms.Compose([py_transforms.OneHotOp(2), c_transforms.TypeCast(mstype.int32)])) \
  51. == [[[0, 1]], [[1, 0]]]
  52. # Test exceptions.
  53. with pytest.raises(TypeError) as error_info:
  54. c_transforms.Compose([1, c_transforms.TypeCast(mstype.int32)])
  55. assert "op_list[0] is neither a c_transform op (TensorOperation) nor a callable pyfunc." in str(error_info.value)
  56. # Test empty op list
  57. with pytest.raises(ValueError) as error_info:
  58. test_config([1, 0], c_transforms.Compose([]))
  59. assert "op_list can not be empty." in str(error_info.value)
  60. # Test Python compose op
  61. assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2)])) == [[[0, 1]], [[1, 0]]]
  62. assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2), (lambda x: x + x)])) == [[[0, 2]],
  63. [[2, 0]]]
  64. # Test nested Python compose op
  65. assert test_config([1, 0],
  66. py_transforms.Compose([py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)])) \
  67. == [[[0, 2]], [[2, 0]]]
  68. # Test passing a list of Python ops without Compose wrapper
  69. assert test_config([1, 0],
  70. [py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)]) \
  71. == [[[0, 2]], [[2, 0]]]
  72. assert test_config([1, 0], [py_transforms.OneHotOp(2), (lambda x: x + x)]) == [[[0, 2]], [[2, 0]]]
  73. # Test a non callable function
  74. with pytest.raises(ValueError) as error_info:
  75. py_transforms.Compose([1])
  76. assert "transforms[0] is not callable." in str(error_info.value)
  77. # Test empty Python op list
  78. with pytest.raises(ValueError) as error_info:
  79. test_config([1, 0], py_transforms.Compose([]))
  80. assert "transforms list is empty." in str(error_info.value)
  81. # Pass in extra brackets
  82. with pytest.raises(TypeError) as error_info:
  83. py_transforms.Compose([(lambda x: x + x)])()
  84. assert "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])())." in str(
  85. error_info.value)
  86. def test_lambdas():
  87. """
  88. Test Multi Column Python Compose Op
  89. """
  90. ds.config.set_seed(0)
  91. def test_config(arr, input_columns, output_cols, op_list):
  92. data = ds.NumpySlicesDataset(arr, column_names=input_columns, shuffle=False)
  93. data = data.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
  94. column_order=output_cols)
  95. res = []
  96. for i in data.create_dict_iterator(output_numpy=True):
  97. for col_name in output_cols:
  98. res.append(i[col_name].tolist())
  99. return res
  100. arr = ([[1]], [[3]])
  101. assert test_config(arr, ["col0", "col1"], ["a"], py_transforms.Compose([(lambda x, y: x)])) == [[1]]
  102. assert test_config(arr, ["col0", "col1"], ["a"], py_transforms.Compose([lambda x, y: x, lambda x: x])) == [[1]]
  103. assert test_config(arr, ["col0", "col1"], ["a", "b"],
  104. py_transforms.Compose([lambda x, y: x, lambda x: (x, x * 2)])) == \
  105. [[1], [2]]
  106. assert test_config(arr, ["col0", "col1"], ["a", "b"],
  107. [lambda x, y: (x, x + y), lambda x, y: (x, y * 2)]) == [[1], [8]]
  108. def test_c_py_compose_transforms_module():
  109. """
  110. Test combining Python and C++ transforms
  111. """
  112. ds.config.set_seed(0)
  113. def test_config(arr, input_columns, output_cols, op_list):
  114. data = ds.NumpySlicesDataset(arr, column_names=input_columns, shuffle=False)
  115. data = data.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
  116. column_order=output_cols)
  117. res = []
  118. for i in data.create_dict_iterator(output_numpy=True):
  119. for col_name in output_cols:
  120. res.append(i[col_name].tolist())
  121. return res
  122. arr = [1, 0]
  123. assert test_config(arr, ["cols"], ["cols"],
  124. [py_transforms.OneHotOp(2), c_transforms.Mask(c_transforms.Relational.EQ, 1)]) == \
  125. [[[False, True]],
  126. [[True, False]]]
  127. assert test_config(arr, ["cols"], ["cols"],
  128. [py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1)]) \
  129. == [[[1, 1]], [[1, 1]]]
  130. assert test_config(arr, ["cols"], ["cols"],
  131. [py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1), (lambda x: x + x)]) \
  132. == [[[2, 2]], [[2, 2]]]
  133. assert test_config([[1, 3]], ["cols"], ["cols"],
  134. [c_transforms.PadEnd([3], -1), (lambda x: x + x)]) \
  135. == [[2, 6, -2]]
  136. arr = ([[1]], [[3]])
  137. assert test_config(arr, ["col0", "col1"], ["a"], [(lambda x, y: x + y), c_transforms.PadEnd([2], -1)]) == [[4, -1]]
  138. def test_c_py_compose_vision_module(plot=False, run_golden=True):
  139. """
  140. Test combining Python and C++ vision transforms
  141. """
  142. original_seed = config_get_set_seed(10)
  143. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  144. def test_config(plot, file_name, op_list):
  145. data_dir = "../data/dataset/testImageNetData/train/"
  146. data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  147. data1 = data1.map(operations=op_list, input_columns=["image"])
  148. data2 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  149. data2 = data2.map(operations=c_vision.Decode(), input_columns=["image"])
  150. original_images = []
  151. transformed_images = []
  152. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  153. transformed_images.append(item["image"])
  154. for item in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
  155. original_images.append(item["image"])
  156. if run_golden:
  157. # Compare with expected md5 from images
  158. save_and_check_md5(data1, file_name, generate_golden=GENERATE_GOLDEN)
  159. if plot:
  160. visualize_list(original_images, transformed_images)
  161. test_config(op_list=[c_vision.Decode(),
  162. py_vision.ToPIL(),
  163. py_vision.Resize((224, 224)),
  164. np.array],
  165. plot=plot, file_name="compose_c_py_1.npz")
  166. test_config(op_list=[c_vision.Decode(),
  167. c_vision.Resize((224, 244)),
  168. py_vision.ToPIL(),
  169. np.array,
  170. c_vision.Resize((24, 24))],
  171. plot=plot, file_name="compose_c_py_2.npz")
  172. test_config(op_list=[py_vision.Decode(),
  173. py_vision.Resize((224, 224)),
  174. np.array,
  175. c_vision.RandomColor()],
  176. plot=plot, file_name="compose_c_py_3.npz")
  177. # Restore configuration
  178. ds.config.set_seed(original_seed)
  179. ds.config.set_num_parallel_workers((original_num_parallel_workers))
  180. def test_py_transforms_with_c_vision():
  181. """
  182. These examples will fail, as c_transform should not be used in py_transforms.Random(Apply/Choice/Order)
  183. """
  184. ds.config.set_seed(0)
  185. def test_config(op_list):
  186. data_dir = "../data/dataset/testImageNetData/train/"
  187. data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  188. data = data.map(operations=op_list)
  189. res = []
  190. for i in data.create_dict_iterator(output_numpy=True):
  191. for col_name in output_cols:
  192. res.append(i[col_name].tolist())
  193. return res
  194. with pytest.raises(ValueError) as error_info:
  195. test_config(py_transforms.RandomApply([c_vision.RandomResizedCrop(200)]))
  196. assert "transforms[0] is not a py transforms." in str(error_info.value)
  197. with pytest.raises(ValueError) as error_info:
  198. test_config(py_transforms.RandomChoice([c_vision.RandomResizedCrop(200)]))
  199. assert "transforms[0] is not a py transforms." in str(error_info.value)
  200. with pytest.raises(ValueError) as error_info:
  201. test_config(py_transforms.RandomOrder([np.array, c_vision.RandomResizedCrop(200)]))
  202. assert "transforms[1] is not a py transforms." in str(error_info.value)
  203. with pytest.raises(RuntimeError) as error_info:
  204. test_config([py_transforms.OneHotOp(20, 0.1)])
  205. assert "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()" in str(
  206. error_info.value)
  207. def test_py_vision_with_c_transforms():
  208. """
  209. Test combining Python vision operations with C++ transforms operations
  210. """
  211. ds.config.set_seed(0)
  212. def test_config(op_list):
  213. data_dir = "../data/dataset/testImageNetData/train/"
  214. data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  215. data1 = data1.map(operations=op_list, input_columns=["image"])
  216. transformed_images = []
  217. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  218. transformed_images.append(item["image"])
  219. return transformed_images
  220. # Test with Mask Op
  221. output_arr = test_config([py_vision.Decode(),
  222. py_vision.CenterCrop((2)), np.array,
  223. c_transforms.Mask(c_transforms.Relational.GE, 100)])
  224. exp_arr = [np.array([[[True, False, False],
  225. [True, False, False]],
  226. [[True, False, False],
  227. [True, False, False]]]),
  228. np.array([[[True, False, False],
  229. [True, False, False]],
  230. [[True, False, False],
  231. [True, False, False]]])]
  232. for exp_a, output in zip(exp_arr, output_arr):
  233. np.testing.assert_array_equal(exp_a, output)
  234. # Test with Fill Op
  235. output_arr = test_config([py_vision.Decode(),
  236. py_vision.CenterCrop((4)), np.array,
  237. c_transforms.Fill(10)])
  238. exp_arr = [np.ones((4, 4, 3)) * 10] * 2
  239. for exp_a, output in zip(exp_arr, output_arr):
  240. np.testing.assert_array_equal(exp_a, output)
  241. # Test with Concatenate Op, which will raise an error since ConcatenateOp only supports rank 1 tensors.
  242. with pytest.raises(RuntimeError) as error_info:
  243. test_config([py_vision.Decode(),
  244. py_vision.CenterCrop((2)), np.array,
  245. c_transforms.Concatenate(0)])
  246. assert "only 1D input supported" in str(error_info.value)
  247. def test_compose_with_custom_function():
  248. """
  249. Test Python Compose with custom function
  250. """
  251. def custom_function(x):
  252. return (x, x * x)
  253. # First dataset
  254. op_list = [
  255. lambda x: x * 3,
  256. custom_function,
  257. # convert two column output to one
  258. lambda *images: np.stack(images)
  259. ]
  260. data = ds.NumpySlicesDataset([[1, 2]], column_names=["col0"], shuffle=False)
  261. data = data.map(input_columns=["col0"], operations=op_list)
  262. #
  263. res = []
  264. for i in data.create_dict_iterator(output_numpy=True):
  265. res.append(i["col0"].tolist())
  266. assert res == [[[3, 6], [9, 36]]]
# Allow running the whole suite directly (pytest normally discovers these).
if __name__ == "__main__":
    test_compose()
    test_lambdas()
    test_c_py_compose_transforms_module()
    # plot=True shows before/after images when run interactively
    test_c_py_compose_vision_module(plot=True)
    test_py_transforms_with_c_vision()
    test_py_vision_with_c_transforms()
    test_compose_with_custom_function()