You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_random_crop.py 19 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing RandomCrop op in DE
  17. """
  18. import numpy as np
  19. import mindspore.dataset.transforms.vision.c_transforms as c_vision
  20. import mindspore.dataset.transforms.vision.py_transforms as py_vision
  21. import mindspore.dataset.transforms.vision.utils as mode
  22. import mindspore.dataset as ds
  23. from mindspore import log as logger
  24. from util import save_and_check_md5, visualize
  25. GENERATE_GOLDEN = False
  26. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  27. SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  28. def test_random_crop_op(plot=False):
  29. """
  30. Test RandomCrop Op
  31. """
  32. logger.info("test_random_crop_op")
  33. # First dataset
  34. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  35. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  36. decode_op = c_vision.Decode()
  37. data1 = data1.map(input_columns=["image"], operations=decode_op)
  38. data1 = data1.map(input_columns=["image"], operations=random_crop_op)
  39. # Second dataset
  40. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  41. data2 = data2.map(input_columns=["image"], operations=decode_op)
  42. image_cropped = []
  43. image = []
  44. for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
  45. image1 = item1["image"]
  46. image2 = item2["image"]
  47. image_cropped.append(image1)
  48. image.append(image2)
  49. if plot:
  50. visualize(image, image_cropped)
  51. def test_random_crop_01_c():
  52. """
  53. Test RandomCrop op with c_transforms: size is a single integer, expected to pass
  54. """
  55. logger.info("test_random_crop_01_c")
  56. ds.config.set_seed(0)
  57. ds.config.set_num_parallel_workers(1)
  58. # Generate dataset
  59. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  60. # Note: If size is an int, a square crop of size (size, size) is returned.
  61. random_crop_op = c_vision.RandomCrop(512)
  62. decode_op = c_vision.Decode()
  63. data = data.map(input_columns=["image"], operations=decode_op)
  64. data = data.map(input_columns=["image"], operations=random_crop_op)
  65. filename = "random_crop_01_c_result.npz"
  66. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  67. def test_random_crop_01_py():
  68. """
  69. Test RandomCrop op with py_transforms: size is a single integer, expected to pass
  70. """
  71. logger.info("test_random_crop_01_py")
  72. ds.config.set_seed(0)
  73. ds.config.set_num_parallel_workers(1)
  74. # Generate dataset
  75. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  76. # Note: If size is an int, a square crop of size (size, size) is returned.
  77. transforms = [
  78. py_vision.Decode(),
  79. py_vision.RandomCrop(512),
  80. py_vision.ToTensor()
  81. ]
  82. transform = py_vision.ComposeOp(transforms)
  83. data = data.map(input_columns=["image"], operations=transform())
  84. filename = "random_crop_01_py_result.npz"
  85. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  86. def test_random_crop_02_c():
  87. """
  88. Test RandomCrop op with c_transforms: size is a list/tuple with length 2, expected to pass
  89. """
  90. logger.info("test_random_crop_02_c")
  91. ds.config.set_seed(0)
  92. ds.config.set_num_parallel_workers(1)
  93. # Generate dataset
  94. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  95. # Note: If size is a sequence of length 2, it should be (height, width).
  96. random_crop_op = c_vision.RandomCrop([512, 375])
  97. decode_op = c_vision.Decode()
  98. data = data.map(input_columns=["image"], operations=decode_op)
  99. data = data.map(input_columns=["image"], operations=random_crop_op)
  100. filename = "random_crop_02_c_result.npz"
  101. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  102. def test_random_crop_02_py():
  103. """
  104. Test RandomCrop op with py_transforms: size is a list/tuple with length 2, expected to pass
  105. """
  106. logger.info("test_random_crop_02_py")
  107. ds.config.set_seed(0)
  108. ds.config.set_num_parallel_workers(1)
  109. # Generate dataset
  110. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  111. # Note: If size is a sequence of length 2, it should be (height, width).
  112. transforms = [
  113. py_vision.Decode(),
  114. py_vision.RandomCrop([512, 375]),
  115. py_vision.ToTensor()
  116. ]
  117. transform = py_vision.ComposeOp(transforms)
  118. data = data.map(input_columns=["image"], operations=transform())
  119. filename = "random_crop_02_py_result.npz"
  120. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  121. def test_random_crop_03_c():
  122. """
  123. Test RandomCrop op with c_transforms: input image size == crop size, expected to pass
  124. """
  125. logger.info("test_random_crop_03_c")
  126. ds.config.set_seed(0)
  127. ds.config.set_num_parallel_workers(1)
  128. # Generate dataset
  129. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  130. # Note: The size of the image is 4032*2268
  131. random_crop_op = c_vision.RandomCrop([2268, 4032])
  132. decode_op = c_vision.Decode()
  133. data = data.map(input_columns=["image"], operations=decode_op)
  134. data = data.map(input_columns=["image"], operations=random_crop_op)
  135. filename = "random_crop_03_c_result.npz"
  136. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  137. def test_random_crop_03_py():
  138. """
  139. Test RandomCrop op with py_transforms: input image size == crop size, expected to pass
  140. """
  141. logger.info("test_random_crop_03_py")
  142. ds.config.set_seed(0)
  143. ds.config.set_num_parallel_workers(1)
  144. # Generate dataset
  145. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  146. # Note: The size of the image is 4032*2268
  147. transforms = [
  148. py_vision.Decode(),
  149. py_vision.RandomCrop([2268, 4032]),
  150. py_vision.ToTensor()
  151. ]
  152. transform = py_vision.ComposeOp(transforms)
  153. data = data.map(input_columns=["image"], operations=transform())
  154. filename = "random_crop_03_py_result.npz"
  155. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  156. def test_random_crop_04_c():
  157. """
  158. Test RandomCrop op with c_transforms: input image size < crop size, expected to fail
  159. """
  160. logger.info("test_random_crop_04_c")
  161. ds.config.set_seed(0)
  162. ds.config.set_num_parallel_workers(1)
  163. try:
  164. # Generate dataset
  165. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  166. # Note: The size of the image is 4032*2268
  167. random_crop_op = c_vision.RandomCrop([2268, 4033])
  168. decode_op = c_vision.Decode()
  169. data = data.map(input_columns=["image"], operations=decode_op)
  170. data = data.map(input_columns=["image"], operations=random_crop_op)
  171. image_list = []
  172. for item in data.create_dict_iterator():
  173. image = item["image"]
  174. image_list.append(image.shape)
  175. except Exception as e:
  176. logger.info("Got an exception in DE: {}".format(str(e)))
  177. def test_random_crop_04_py():
  178. """
  179. Test RandomCrop op with py_transforms:
  180. input image size < crop size, expected to fail
  181. """
  182. logger.info("test_random_crop_04_py")
  183. ds.config.set_seed(0)
  184. ds.config.set_num_parallel_workers(1)
  185. try:
  186. # Generate dataset
  187. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  188. # Note: The size of the image is 4032*2268
  189. transforms = [
  190. py_vision.Decode(),
  191. py_vision.RandomCrop([2268, 4033]),
  192. py_vision.ToTensor()
  193. ]
  194. transform = py_vision.ComposeOp(transforms)
  195. data = data.map(input_columns=["image"], operations=transform())
  196. image_list = []
  197. for item in data.create_dict_iterator():
  198. image = (item["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
  199. image_list.append(image.shape)
  200. except Exception as e:
  201. logger.info("Got an exception in DE: {}".format(str(e)))
  202. def test_random_crop_05_c():
  203. """
  204. Test RandomCrop op with c_transforms:
  205. input image size < crop size but pad_if_needed is enabled,
  206. expected to pass
  207. """
  208. logger.info("test_random_crop_05_c")
  209. ds.config.set_seed(0)
  210. ds.config.set_num_parallel_workers(1)
  211. # Generate dataset
  212. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  213. # Note: The size of the image is 4032*2268
  214. random_crop_op = c_vision.RandomCrop([2268, 4033], [200, 200, 200, 200], pad_if_needed=True)
  215. decode_op = c_vision.Decode()
  216. data = data.map(input_columns=["image"], operations=decode_op)
  217. data = data.map(input_columns=["image"], operations=random_crop_op)
  218. filename = "random_crop_05_c_result.npz"
  219. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  220. def test_random_crop_05_py():
  221. """
  222. Test RandomCrop op with py_transforms:
  223. input image size < crop size but pad_if_needed is enabled,
  224. expected to pass
  225. """
  226. logger.info("test_random_crop_05_py")
  227. ds.config.set_seed(0)
  228. ds.config.set_num_parallel_workers(1)
  229. # Generate dataset
  230. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  231. # Note: The size of the image is 4032*2268
  232. transforms = [
  233. py_vision.Decode(),
  234. py_vision.RandomCrop([2268, 4033], [200, 200, 200, 200], pad_if_needed=True),
  235. py_vision.ToTensor()
  236. ]
  237. transform = py_vision.ComposeOp(transforms)
  238. data = data.map(input_columns=["image"], operations=transform())
  239. filename = "random_crop_05_py_result.npz"
  240. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  241. def test_random_crop_06_c():
  242. """
  243. Test RandomCrop op with c_transforms:
  244. invalid size, expected to raise TypeError
  245. """
  246. logger.info("test_random_crop_06_c")
  247. ds.config.set_seed(0)
  248. ds.config.set_num_parallel_workers(1)
  249. try:
  250. # Generate dataset
  251. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  252. # Note: if size is neither an int nor a list of length 2, an exception will raise
  253. random_crop_op = c_vision.RandomCrop([512, 512, 375])
  254. decode_op = c_vision.Decode()
  255. data = data.map(input_columns=["image"], operations=decode_op)
  256. data = data.map(input_columns=["image"], operations=random_crop_op)
  257. image_list = []
  258. for item in data.create_dict_iterator():
  259. image = item["image"]
  260. image_list.append(image.shape)
  261. except TypeError as e:
  262. logger.info("Got an exception in DE: {}".format(str(e)))
  263. assert "Size" in str(e)
  264. def test_random_crop_06_py():
  265. """
  266. Test RandomCrop op with py_transforms:
  267. invalid size, expected to raise TypeError
  268. """
  269. logger.info("test_random_crop_06_py")
  270. ds.config.set_seed(0)
  271. ds.config.set_num_parallel_workers(1)
  272. try:
  273. # Generate dataset
  274. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  275. # Note: if size is neither an int nor a list of length 2, an exception will raise
  276. transforms = [
  277. py_vision.Decode(),
  278. py_vision.RandomCrop([512, 512, 375]),
  279. py_vision.ToTensor()
  280. ]
  281. transform = py_vision.ComposeOp(transforms)
  282. data = data.map(input_columns=["image"], operations=transform())
  283. image_list = []
  284. for item in data.create_dict_iterator():
  285. image = (item["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
  286. image_list.append(image.shape)
  287. except TypeError as e:
  288. logger.info("Got an exception in DE: {}".format(str(e)))
  289. assert "Size" in str(e)
  290. def test_random_crop_07_c():
  291. """
  292. Test RandomCrop op with c_transforms:
  293. padding_mode is Border.CONSTANT and fill_value is 255 (White),
  294. expected to pass
  295. """
  296. logger.info("test_random_crop_07_c")
  297. ds.config.set_seed(0)
  298. ds.config.set_num_parallel_workers(1)
  299. # Generate dataset
  300. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  301. # Note: The padding_mode is default as Border.CONSTANT and set filling color to be white.
  302. random_crop_op = c_vision.RandomCrop(512, [200, 200, 200, 200], fill_value=(255, 255, 255))
  303. decode_op = c_vision.Decode()
  304. data = data.map(input_columns=["image"], operations=decode_op)
  305. data = data.map(input_columns=["image"], operations=random_crop_op)
  306. filename = "random_crop_07_c_result.npz"
  307. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  308. def test_random_crop_07_py():
  309. """
  310. Test RandomCrop op with py_transforms:
  311. padding_mode is Border.CONSTANT and fill_value is 255 (White),
  312. expected to pass
  313. """
  314. logger.info("test_random_crop_07_py")
  315. ds.config.set_seed(0)
  316. ds.config.set_num_parallel_workers(1)
  317. # Generate dataset
  318. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  319. # Note: The padding_mode is default as Border.CONSTANT and set filling color to be white.
  320. transforms = [
  321. py_vision.Decode(),
  322. py_vision.RandomCrop(512, [200, 200, 200, 200], fill_value=(255, 255, 255)),
  323. py_vision.ToTensor()
  324. ]
  325. transform = py_vision.ComposeOp(transforms)
  326. data = data.map(input_columns=["image"], operations=transform())
  327. filename = "random_crop_07_py_result.npz"
  328. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  329. def test_random_crop_08_c():
  330. """
  331. Test RandomCrop op with c_transforms: padding_mode is Border.EDGE,
  332. expected to pass
  333. """
  334. logger.info("test_random_crop_08_c")
  335. ds.config.set_seed(0)
  336. ds.config.set_num_parallel_workers(1)
  337. # Generate dataset
  338. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  339. # Note: The padding_mode is Border.EDGE.
  340. random_crop_op = c_vision.RandomCrop(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
  341. decode_op = c_vision.Decode()
  342. data = data.map(input_columns=["image"], operations=decode_op)
  343. data = data.map(input_columns=["image"], operations=random_crop_op)
  344. filename = "random_crop_08_c_result.npz"
  345. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  346. def test_random_crop_08_py():
  347. """
  348. Test RandomCrop op with py_transforms: padding_mode is Border.EDGE,
  349. expected to pass
  350. """
  351. logger.info("test_random_crop_08_py")
  352. ds.config.set_seed(0)
  353. ds.config.set_num_parallel_workers(1)
  354. # Generate dataset
  355. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  356. # Note: The padding_mode is Border.EDGE.
  357. transforms = [
  358. py_vision.Decode(),
  359. py_vision.RandomCrop(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE),
  360. py_vision.ToTensor()
  361. ]
  362. transform = py_vision.ComposeOp(transforms)
  363. data = data.map(input_columns=["image"], operations=transform())
  364. filename = "random_crop_08_py_result.npz"
  365. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  366. def test_random_crop_09():
  367. """
  368. Test RandomCrop op: invalid type of input image (not PIL), expected to raise TypeError
  369. """
  370. logger.info("test_random_crop_09")
  371. ds.config.set_seed(0)
  372. ds.config.set_num_parallel_workers(1)
  373. # Generate dataset
  374. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  375. transforms = [
  376. py_vision.Decode(),
  377. py_vision.ToTensor(),
  378. # Note: if input is not PIL image, TypeError will raise
  379. py_vision.RandomCrop(512)
  380. ]
  381. transform = py_vision.ComposeOp(transforms)
  382. try:
  383. data = data.map(input_columns=["image"], operations=transform())
  384. image_list = []
  385. for item in data.create_dict_iterator():
  386. image = item["image"]
  387. image_list.append(image.shape)
  388. except Exception as e:
  389. logger.info("Got an exception in DE: {}".format(str(e)))
  390. assert "should be PIL Image" in str(e)
  391. def test_random_crop_comp(plot=False):
  392. """
  393. Test RandomCrop and compare between python and c image augmentation
  394. """
  395. logger.info("Test RandomCrop with c_transform and py_transform comparison")
  396. ds.config.set_seed(0)
  397. ds.config.set_num_parallel_workers(1)
  398. cropped_size = 512
  399. # First dataset
  400. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  401. random_crop_op = c_vision.RandomCrop(cropped_size)
  402. decode_op = c_vision.Decode()
  403. data1 = data1.map(input_columns=["image"], operations=decode_op)
  404. data1 = data1.map(input_columns=["image"], operations=random_crop_op)
  405. # Second dataset
  406. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  407. transforms = [
  408. py_vision.Decode(),
  409. py_vision.RandomCrop(cropped_size),
  410. py_vision.ToTensor()
  411. ]
  412. transform = py_vision.ComposeOp(transforms)
  413. data2 = data2.map(input_columns=["image"], operations=transform())
  414. image_c_cropped = []
  415. image_py_cropped = []
  416. for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
  417. c_image = item1["image"]
  418. py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
  419. image_c_cropped.append(c_image)
  420. image_py_cropped.append(py_image)
  421. if plot:
  422. visualize(image_c_cropped, image_py_cropped)
  423. if __name__ == "__main__":
  424. test_random_crop_01_c()
  425. test_random_crop_02_c()
  426. test_random_crop_03_c()
  427. test_random_crop_04_c()
  428. test_random_crop_05_c()
  429. test_random_crop_06_c()
  430. test_random_crop_07_c()
  431. test_random_crop_08_c()
  432. test_random_crop_01_py()
  433. test_random_crop_02_py()
  434. test_random_crop_03_py()
  435. test_random_crop_04_py()
  436. test_random_crop_05_py()
  437. test_random_crop_06_py()
  438. test_random_crop_07_py()
  439. test_random_crop_08_py()
  440. test_random_crop_09()
  441. test_random_crop_op(True)
  442. test_random_crop_comp(True)