You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_random_crop.py 21 kB

5 years ago

  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing RandomCrop op in DE
  17. """
  18. import numpy as np
  19. import mindspore.dataset.transforms.vision.c_transforms as c_vision
  20. import mindspore.dataset.transforms.vision.py_transforms as py_vision
  21. import mindspore.dataset.transforms.vision.utils as mode
  22. import mindspore.dataset as ds
  23. from mindspore import log as logger
  24. from util import save_and_check_md5, visualize, config_get_set_seed, \
  25. config_get_set_num_parallel_workers
  26. GENERATE_GOLDEN = False
  27. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  28. SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  29. def test_random_crop_op_c(plot=False):
  30. """
  31. Test RandomCrop Op in c transforms
  32. """
  33. logger.info("test_random_crop_op_c")
  34. # First dataset
  35. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  36. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  37. decode_op = c_vision.Decode()
  38. data1 = data1.map(input_columns=["image"], operations=decode_op)
  39. data1 = data1.map(input_columns=["image"], operations=random_crop_op)
  40. # Second dataset
  41. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  42. data2 = data2.map(input_columns=["image"], operations=decode_op)
  43. image_cropped = []
  44. image = []
  45. for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
  46. image1 = item1["image"]
  47. image2 = item2["image"]
  48. image_cropped.append(image1)
  49. image.append(image2)
  50. if plot:
  51. visualize(image, image_cropped)
  52. def test_random_crop_op_py(plot=False):
  53. """
  54. Test RandomCrop op in py transforms
  55. """
  56. logger.info("test_random_crop_op_py")
  57. # First dataset
  58. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  59. transforms1 = [
  60. py_vision.Decode(),
  61. py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
  62. py_vision.ToTensor()
  63. ]
  64. transform1 = py_vision.ComposeOp(transforms1)
  65. data1 = data1.map(input_columns=["image"], operations=transform1())
  66. # Second dataset
  67. # Second dataset for comparison
  68. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  69. transforms2 = [
  70. py_vision.Decode(),
  71. py_vision.ToTensor()
  72. ]
  73. transform2 = py_vision.ComposeOp(transforms2)
  74. data2 = data2.map(input_columns=["image"], operations=transform2())
  75. crop_images = []
  76. original_images = []
  77. for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
  78. crop = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
  79. original = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
  80. crop_images.append(crop)
  81. original_images.append(original)
  82. if plot:
  83. visualize(original_images, crop_images)
  84. def test_random_crop_01_c():
  85. """
  86. Test RandomCrop op with c_transforms: size is a single integer, expected to pass
  87. """
  88. logger.info("test_random_crop_01_c")
  89. original_seed = config_get_set_seed(0)
  90. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  91. # Generate dataset
  92. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  93. # Note: If size is an int, a square crop of size (size, size) is returned.
  94. random_crop_op = c_vision.RandomCrop(512)
  95. decode_op = c_vision.Decode()
  96. data = data.map(input_columns=["image"], operations=decode_op)
  97. data = data.map(input_columns=["image"], operations=random_crop_op)
  98. filename = "random_crop_01_c_result.npz"
  99. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  100. # Restore config setting
  101. ds.config.set_seed(original_seed)
  102. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  103. def test_random_crop_01_py():
  104. """
  105. Test RandomCrop op with py_transforms: size is a single integer, expected to pass
  106. """
  107. logger.info("test_random_crop_01_py")
  108. original_seed = config_get_set_seed(0)
  109. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  110. # Generate dataset
  111. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  112. # Note: If size is an int, a square crop of size (size, size) is returned.
  113. transforms = [
  114. py_vision.Decode(),
  115. py_vision.RandomCrop(512),
  116. py_vision.ToTensor()
  117. ]
  118. transform = py_vision.ComposeOp(transforms)
  119. data = data.map(input_columns=["image"], operations=transform())
  120. filename = "random_crop_01_py_result.npz"
  121. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  122. # Restore config setting
  123. ds.config.set_seed(original_seed)
  124. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  125. def test_random_crop_02_c():
  126. """
  127. Test RandomCrop op with c_transforms: size is a list/tuple with length 2, expected to pass
  128. """
  129. logger.info("test_random_crop_02_c")
  130. original_seed = config_get_set_seed(0)
  131. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  132. # Generate dataset
  133. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  134. # Note: If size is a sequence of length 2, it should be (height, width).
  135. random_crop_op = c_vision.RandomCrop([512, 375])
  136. decode_op = c_vision.Decode()
  137. data = data.map(input_columns=["image"], operations=decode_op)
  138. data = data.map(input_columns=["image"], operations=random_crop_op)
  139. filename = "random_crop_02_c_result.npz"
  140. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  141. # Restore config setting
  142. ds.config.set_seed(original_seed)
  143. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  144. def test_random_crop_02_py():
  145. """
  146. Test RandomCrop op with py_transforms: size is a list/tuple with length 2, expected to pass
  147. """
  148. logger.info("test_random_crop_02_py")
  149. original_seed = config_get_set_seed(0)
  150. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  151. # Generate dataset
  152. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  153. # Note: If size is a sequence of length 2, it should be (height, width).
  154. transforms = [
  155. py_vision.Decode(),
  156. py_vision.RandomCrop([512, 375]),
  157. py_vision.ToTensor()
  158. ]
  159. transform = py_vision.ComposeOp(transforms)
  160. data = data.map(input_columns=["image"], operations=transform())
  161. filename = "random_crop_02_py_result.npz"
  162. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  163. # Restore config setting
  164. ds.config.set_seed(original_seed)
  165. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  166. def test_random_crop_03_c():
  167. """
  168. Test RandomCrop op with c_transforms: input image size == crop size, expected to pass
  169. """
  170. logger.info("test_random_crop_03_c")
  171. original_seed = config_get_set_seed(0)
  172. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  173. # Generate dataset
  174. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  175. # Note: The size of the image is 4032*2268
  176. random_crop_op = c_vision.RandomCrop([2268, 4032])
  177. decode_op = c_vision.Decode()
  178. data = data.map(input_columns=["image"], operations=decode_op)
  179. data = data.map(input_columns=["image"], operations=random_crop_op)
  180. filename = "random_crop_03_c_result.npz"
  181. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  182. # Restore config setting
  183. ds.config.set_seed(original_seed)
  184. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  185. def test_random_crop_03_py():
  186. """
  187. Test RandomCrop op with py_transforms: input image size == crop size, expected to pass
  188. """
  189. logger.info("test_random_crop_03_py")
  190. original_seed = config_get_set_seed(0)
  191. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  192. # Generate dataset
  193. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  194. # Note: The size of the image is 4032*2268
  195. transforms = [
  196. py_vision.Decode(),
  197. py_vision.RandomCrop([2268, 4032]),
  198. py_vision.ToTensor()
  199. ]
  200. transform = py_vision.ComposeOp(transforms)
  201. data = data.map(input_columns=["image"], operations=transform())
  202. filename = "random_crop_03_py_result.npz"
  203. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  204. # Restore config setting
  205. ds.config.set_seed(original_seed)
  206. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  207. def test_random_crop_04_c():
  208. """
  209. Test RandomCrop op with c_transforms: input image size < crop size, expected to fail
  210. """
  211. logger.info("test_random_crop_04_c")
  212. # Generate dataset
  213. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  214. # Note: The size of the image is 4032*2268
  215. random_crop_op = c_vision.RandomCrop([2268, 4033])
  216. decode_op = c_vision.Decode()
  217. data = data.map(input_columns=["image"], operations=decode_op)
  218. data = data.map(input_columns=["image"], operations=random_crop_op)
  219. try:
  220. data.create_dict_iterator().get_next()
  221. except RuntimeError as e:
  222. logger.info("Got an exception in DE: {}".format(str(e)))
  223. assert "Crop size is greater than the image dim" in str(e)
  224. def test_random_crop_04_py():
  225. """
  226. Test RandomCrop op with py_transforms:
  227. input image size < crop size, expected to fail
  228. """
  229. logger.info("test_random_crop_04_py")
  230. # Generate dataset
  231. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  232. # Note: The size of the image is 4032*2268
  233. transforms = [
  234. py_vision.Decode(),
  235. py_vision.RandomCrop([2268, 4033]),
  236. py_vision.ToTensor()
  237. ]
  238. transform = py_vision.ComposeOp(transforms)
  239. data = data.map(input_columns=["image"], operations=transform())
  240. try:
  241. data.create_dict_iterator().get_next()
  242. except RuntimeError as e:
  243. logger.info("Got an exception in DE: {}".format(str(e)))
  244. def test_random_crop_05_c():
  245. """
  246. Test RandomCrop op with c_transforms:
  247. input image size < crop size but pad_if_needed is enabled,
  248. expected to pass
  249. """
  250. logger.info("test_random_crop_05_c")
  251. original_seed = config_get_set_seed(0)
  252. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  253. # Generate dataset
  254. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  255. # Note: The size of the image is 4032*2268
  256. random_crop_op = c_vision.RandomCrop([2268, 4033], [200, 200, 200, 200], pad_if_needed=True)
  257. decode_op = c_vision.Decode()
  258. data = data.map(input_columns=["image"], operations=decode_op)
  259. data = data.map(input_columns=["image"], operations=random_crop_op)
  260. filename = "random_crop_05_c_result.npz"
  261. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  262. # Restore config setting
  263. ds.config.set_seed(original_seed)
  264. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  265. def test_random_crop_05_py():
  266. """
  267. Test RandomCrop op with py_transforms:
  268. input image size < crop size but pad_if_needed is enabled,
  269. expected to pass
  270. """
  271. logger.info("test_random_crop_05_py")
  272. original_seed = config_get_set_seed(0)
  273. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  274. # Generate dataset
  275. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  276. # Note: The size of the image is 4032*2268
  277. transforms = [
  278. py_vision.Decode(),
  279. py_vision.RandomCrop([2268, 4033], [200, 200, 200, 200], pad_if_needed=True),
  280. py_vision.ToTensor()
  281. ]
  282. transform = py_vision.ComposeOp(transforms)
  283. data = data.map(input_columns=["image"], operations=transform())
  284. filename = "random_crop_05_py_result.npz"
  285. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  286. # Restore config setting
  287. ds.config.set_seed(original_seed)
  288. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  289. def test_random_crop_06_c():
  290. """
  291. Test RandomCrop op with c_transforms:
  292. invalid size, expected to raise TypeError
  293. """
  294. logger.info("test_random_crop_06_c")
  295. # Generate dataset
  296. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  297. try:
  298. # Note: if size is neither an int nor a list of length 2, an exception will raise
  299. random_crop_op = c_vision.RandomCrop([512, 512, 375])
  300. decode_op = c_vision.Decode()
  301. data = data.map(input_columns=["image"], operations=decode_op)
  302. data = data.map(input_columns=["image"], operations=random_crop_op)
  303. except TypeError as e:
  304. logger.info("Got an exception in DE: {}".format(str(e)))
  305. assert "Size should be a single integer" in str(e)
  306. def test_random_crop_06_py():
  307. """
  308. Test RandomCrop op with py_transforms:
  309. invalid size, expected to raise TypeError
  310. """
  311. logger.info("test_random_crop_06_py")
  312. # Generate dataset
  313. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  314. try:
  315. # Note: if size is neither an int nor a list of length 2, an exception will raise
  316. transforms = [
  317. py_vision.Decode(),
  318. py_vision.RandomCrop([512, 512, 375]),
  319. py_vision.ToTensor()
  320. ]
  321. transform = py_vision.ComposeOp(transforms)
  322. data = data.map(input_columns=["image"], operations=transform())
  323. except TypeError as e:
  324. logger.info("Got an exception in DE: {}".format(str(e)))
  325. assert "Size should be a single integer" in str(e)
  326. def test_random_crop_07_c():
  327. """
  328. Test RandomCrop op with c_transforms:
  329. padding_mode is Border.CONSTANT and fill_value is 255 (White),
  330. expected to pass
  331. """
  332. logger.info("test_random_crop_07_c")
  333. original_seed = config_get_set_seed(0)
  334. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  335. # Generate dataset
  336. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  337. # Note: The padding_mode is default as Border.CONSTANT and set filling color to be white.
  338. random_crop_op = c_vision.RandomCrop(512, [200, 200, 200, 200], fill_value=(255, 255, 255))
  339. decode_op = c_vision.Decode()
  340. data = data.map(input_columns=["image"], operations=decode_op)
  341. data = data.map(input_columns=["image"], operations=random_crop_op)
  342. filename = "random_crop_07_c_result.npz"
  343. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  344. # Restore config setting
  345. ds.config.set_seed(original_seed)
  346. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  347. def test_random_crop_07_py():
  348. """
  349. Test RandomCrop op with py_transforms:
  350. padding_mode is Border.CONSTANT and fill_value is 255 (White),
  351. expected to pass
  352. """
  353. logger.info("test_random_crop_07_py")
  354. original_seed = config_get_set_seed(0)
  355. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  356. # Generate dataset
  357. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  358. # Note: The padding_mode is default as Border.CONSTANT and set filling color to be white.
  359. transforms = [
  360. py_vision.Decode(),
  361. py_vision.RandomCrop(512, [200, 200, 200, 200], fill_value=(255, 255, 255)),
  362. py_vision.ToTensor()
  363. ]
  364. transform = py_vision.ComposeOp(transforms)
  365. data = data.map(input_columns=["image"], operations=transform())
  366. filename = "random_crop_07_py_result.npz"
  367. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  368. # Restore config setting
  369. ds.config.set_seed(original_seed)
  370. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  371. def test_random_crop_08_c():
  372. """
  373. Test RandomCrop op with c_transforms: padding_mode is Border.EDGE,
  374. expected to pass
  375. """
  376. logger.info("test_random_crop_08_c")
  377. original_seed = config_get_set_seed(0)
  378. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  379. # Generate dataset
  380. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  381. # Note: The padding_mode is Border.EDGE.
  382. random_crop_op = c_vision.RandomCrop(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
  383. decode_op = c_vision.Decode()
  384. data = data.map(input_columns=["image"], operations=decode_op)
  385. data = data.map(input_columns=["image"], operations=random_crop_op)
  386. filename = "random_crop_08_c_result.npz"
  387. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  388. # Restore config setting
  389. ds.config.set_seed(original_seed)
  390. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  391. def test_random_crop_08_py():
  392. """
  393. Test RandomCrop op with py_transforms: padding_mode is Border.EDGE,
  394. expected to pass
  395. """
  396. logger.info("test_random_crop_08_py")
  397. original_seed = config_get_set_seed(0)
  398. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  399. # Generate dataset
  400. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  401. # Note: The padding_mode is Border.EDGE.
  402. transforms = [
  403. py_vision.Decode(),
  404. py_vision.RandomCrop(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE),
  405. py_vision.ToTensor()
  406. ]
  407. transform = py_vision.ComposeOp(transforms)
  408. data = data.map(input_columns=["image"], operations=transform())
  409. filename = "random_crop_08_py_result.npz"
  410. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  411. # Restore config setting
  412. ds.config.set_seed(original_seed)
  413. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  414. def test_random_crop_09():
  415. """
  416. Test RandomCrop op: invalid type of input image (not PIL), expected to raise TypeError
  417. """
  418. logger.info("test_random_crop_09")
  419. # Generate dataset
  420. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  421. transforms = [
  422. py_vision.Decode(),
  423. py_vision.ToTensor(),
  424. # Note: if input is not PIL image, TypeError will raise
  425. py_vision.RandomCrop(512)
  426. ]
  427. transform = py_vision.ComposeOp(transforms)
  428. data = data.map(input_columns=["image"], operations=transform())
  429. try:
  430. data.create_dict_iterator().get_next()
  431. except RuntimeError as e:
  432. logger.info("Got an exception in DE: {}".format(str(e)))
  433. assert "should be PIL Image" in str(e)
  434. def test_random_crop_comp(plot=False):
  435. """
  436. Test RandomCrop and compare between python and c image augmentation
  437. """
  438. logger.info("Test RandomCrop with c_transform and py_transform comparison")
  439. cropped_size = 512
  440. # First dataset
  441. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  442. random_crop_op = c_vision.RandomCrop(cropped_size)
  443. decode_op = c_vision.Decode()
  444. data1 = data1.map(input_columns=["image"], operations=decode_op)
  445. data1 = data1.map(input_columns=["image"], operations=random_crop_op)
  446. # Second dataset
  447. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  448. transforms = [
  449. py_vision.Decode(),
  450. py_vision.RandomCrop(cropped_size),
  451. py_vision.ToTensor()
  452. ]
  453. transform = py_vision.ComposeOp(transforms)
  454. data2 = data2.map(input_columns=["image"], operations=transform())
  455. image_c_cropped = []
  456. image_py_cropped = []
  457. for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
  458. c_image = item1["image"]
  459. py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
  460. image_c_cropped.append(c_image)
  461. image_py_cropped.append(py_image)
  462. if plot:
  463. visualize(image_c_cropped, image_py_cropped)
  464. if __name__ == "__main__":
  465. test_random_crop_01_c()
  466. test_random_crop_02_c()
  467. test_random_crop_03_c()
  468. test_random_crop_04_c()
  469. test_random_crop_05_c()
  470. test_random_crop_06_c()
  471. test_random_crop_07_c()
  472. test_random_crop_08_c()
  473. test_random_crop_01_py()
  474. test_random_crop_02_py()
  475. test_random_crop_03_py()
  476. test_random_crop_04_py()
  477. test_random_crop_05_py()
  478. test_random_crop_06_py()
  479. test_random_crop_07_py()
  480. test_random_crop_08_py()
  481. test_random_crop_09()
  482. test_random_crop_op_c(True)
  483. test_random_crop_op_py(True)
  484. test_random_crop_comp(True)