You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

util.py 18 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. import hashlib
  16. import json
  17. import os
  18. import itertools
  19. from enum import Enum
  20. import numpy as np
  21. import matplotlib.pyplot as plt
  22. import matplotlib.patches as patches
  23. import PIL
  24. # import jsbeautifier
  25. import mindspore.dataset as ds
  26. from mindspore import log as logger
# These are list of plot title in different visualize modes
# mode 1: compare an original image with its transformed version
# mode 2: compare c_transform output with py_transform output
PLOT_TITLE_DICT = {
    1: ["Original image", "Transformed image"],
    2: ["c_transform image", "py_transform image"]
}

# When True, the save_and_check_* helpers also dump their results to a JSON
# file for manual inspection (via _save_json)
SAVE_JSON = False
  33. def _save_golden(cur_dir, golden_ref_dir, result_dict):
  34. """
  35. Save the dictionary values as the golden result in .npz file
  36. """
  37. logger.info("cur_dir is {}".format(cur_dir))
  38. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  39. np.savez(golden_ref_dir, np.array(list(result_dict.values())))
  40. def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
  41. """
  42. Save the dictionary (both keys and values) as the golden result in .npz file
  43. """
  44. logger.info("cur_dir is {}".format(cur_dir))
  45. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  46. np.savez(golden_ref_dir, np.array(list(result_dict.items())))
  47. def _compare_to_golden(golden_ref_dir, result_dict):
  48. """
  49. Compare as numpy arrays the test result to the golden result
  50. """
  51. test_array = np.array(list(result_dict.values()))
  52. golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
  53. np.testing.assert_array_equal(test_array, golden_array)
  54. def _compare_to_golden_dict(golden_ref_dir, result_dict, check_pillow_version=False):
  55. """
  56. Compare as dictionaries the test result to the golden result
  57. """
  58. golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
  59. # Note: The version of PILLOW that is used in Jenkins CI is compared with below
  60. if (not check_pillow_version or PIL.__version__ == '7.1.2'):
  61. np.testing.assert_equal(result_dict, dict(golden_array))
  62. elif PIL.__version__ >= '9.0.0':
  63. try:
  64. np.testing.assert_equal(result_dict, dict(golden_array))
  65. except AssertionError:
  66. logger.warning("Results from Pillow >= 9.0.0 is incompatibale with Pillow < 9.0.0, need more validation.")
  67. else:
  68. # Beware: If error, PILLOW version of golden results may be incompatible with current PILLOW version
  69. np.testing.assert_equal(result_dict, dict(golden_array),
  70. 'Items are not equal and problem may be due to PILLOW version incompatibility')
  71. def _save_json(filename, parameters, result_dict):
  72. """
  73. Save the result dictionary in json file
  74. """
  75. fout = open(filename[:-3] + "json", "w")
  76. options = jsbeautifier.default_options()
  77. options.indent_size = 2
  78. out_dict = {**parameters, **{"columns": result_dict}}
  79. fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))
  80. def save_and_check_dict(data, filename, generate_golden=False):
  81. """
  82. Save the dataset dictionary and compare (as dictionary) with golden file.
  83. Use create_dict_iterator to access the dataset.
  84. """
  85. num_iter = 0
  86. result_dict = {}
  87. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  88. for data_key in list(item.keys()):
  89. if data_key not in result_dict:
  90. result_dict[data_key] = []
  91. result_dict[data_key].append(item[data_key].tolist())
  92. num_iter += 1
  93. logger.info("Number of data in ds1: {}".format(num_iter))
  94. cur_dir = os.path.dirname(os.path.realpath(__file__))
  95. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  96. if generate_golden:
  97. # Save as the golden result
  98. _save_golden_dict(cur_dir, golden_ref_dir, result_dict)
  99. _compare_to_golden_dict(golden_ref_dir, result_dict, False)
  100. if SAVE_JSON:
  101. # Save result to a json file for inspection
  102. parameters = {"params": {}}
  103. _save_json(filename, parameters, result_dict)
  104. def save_and_check_md5(data, filename, generate_golden=False):
  105. """
  106. Save the dataset dictionary and compare (as dictionary) with golden file (md5).
  107. Use create_dict_iterator to access the dataset.
  108. """
  109. num_iter = 0
  110. result_dict = {}
  111. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  112. for data_key in list(item.keys()):
  113. if data_key not in result_dict:
  114. result_dict[data_key] = []
  115. # save the md5 as numpy array
  116. result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='<f4'))
  117. num_iter += 1
  118. logger.info("Number of data in ds1: {}".format(num_iter))
  119. cur_dir = os.path.dirname(os.path.realpath(__file__))
  120. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  121. if generate_golden:
  122. # Save as the golden result
  123. _save_golden_dict(cur_dir, golden_ref_dir, result_dict)
  124. _compare_to_golden_dict(golden_ref_dir, result_dict, True)
  125. def save_and_check_tuple(data, parameters, filename, generate_golden=False):
  126. """
  127. Save the dataset dictionary and compare (as numpy array) with golden file.
  128. Use create_tuple_iterator to access the dataset.
  129. """
  130. num_iter = 0
  131. result_dict = {}
  132. for item in data.create_tuple_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  133. for data_key, _ in enumerate(item):
  134. if data_key not in result_dict:
  135. result_dict[data_key] = []
  136. result_dict[data_key].append(item[data_key].tolist())
  137. num_iter += 1
  138. logger.info("Number of data in data1: {}".format(num_iter))
  139. cur_dir = os.path.dirname(os.path.realpath(__file__))
  140. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  141. if generate_golden:
  142. # Save as the golden result
  143. _save_golden(cur_dir, golden_ref_dir, result_dict)
  144. _compare_to_golden(golden_ref_dir, result_dict)
  145. if SAVE_JSON:
  146. # Save result to a json file for inspection
  147. _save_json(filename, parameters, result_dict)
  148. def config_get_set_seed(seed_new):
  149. """
  150. Get and return the original configuration seed value.
  151. Set the new configuration seed value.
  152. """
  153. seed_original = ds.config.get_seed()
  154. ds.config.set_seed(seed_new)
  155. logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
  156. return seed_original
  157. def config_get_set_num_parallel_workers(num_parallel_workers_new):
  158. """
  159. Get and return the original configuration num_parallel_workers value.
  160. Set the new configuration num_parallel_workers value.
  161. """
  162. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  163. ds.config.set_num_parallel_workers(num_parallel_workers_new)
  164. logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
  165. num_parallel_workers_new))
  166. return num_parallel_workers_original
  167. def diff_mse(in1, in2):
  168. mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
  169. return mse * 100
  170. def diff_me(in1, in2):
  171. mse = (np.abs(in1.astype(float) - in2.astype(float))).mean()
  172. return mse / 255 * 100
  173. def visualize_audio(waveform, expect_waveform):
  174. """
  175. Visualizes audio waveform.
  176. """
  177. plt.figure(1)
  178. plt.subplot(1, 3, 1)
  179. plt.imshow(waveform)
  180. plt.title("waveform")
  181. plt.subplot(1, 3, 2)
  182. plt.imshow(expect_waveform)
  183. plt.title("expect waveform")
  184. plt.subplot(1, 3, 3)
  185. plt.imshow(waveform - expect_waveform)
  186. plt.title("difference")
  187. plt.show()
  188. def visualize_one_channel_dataset(images_original, images_transformed, labels):
  189. """
  190. Helper function to visualize one channel grayscale images
  191. """
  192. num_samples = len(images_original)
  193. for i in range(num_samples):
  194. plt.subplot(2, num_samples, i + 1)
  195. # Note: Use squeeze() to convert (H, W, 1) images to (H, W)
  196. plt.imshow(images_original[i].squeeze(), cmap=plt.cm.gray)
  197. plt.title(PLOT_TITLE_DICT[1][0] + ":" + str(labels[i]))
  198. plt.subplot(2, num_samples, i + num_samples + 1)
  199. plt.imshow(images_transformed[i].squeeze(), cmap=plt.cm.gray)
  200. plt.title(PLOT_TITLE_DICT[1][1] + ":" + str(labels[i]))
  201. plt.show()
  202. def visualize_list(image_list_1, image_list_2, visualize_mode=1):
  203. """
  204. visualizes a list of images using DE op
  205. """
  206. plot_title = PLOT_TITLE_DICT[visualize_mode]
  207. num = len(image_list_1)
  208. for i in range(num):
  209. plt.subplot(2, num, i + 1)
  210. plt.imshow(image_list_1[i])
  211. plt.title(plot_title[0])
  212. plt.subplot(2, num, i + num + 1)
  213. plt.imshow(image_list_2[i])
  214. plt.title(plot_title[1])
  215. plt.show()
def visualize_image(image_original, image_de, mse=None, image_lib=None):
    """
    visualizes one example image with optional input: mse, image using 3rd party op.
    If three images are passing in, different image is calculated by 2nd and 3rd images.

    :param image_original: reference image, drawn in the first panel
    :param image_de: image produced by the DE op, drawn in the second panel
    :param mse: optional mean-squared-error value shown on the diff panel title
    :param image_lib: optional image produced by a third-party library op
    """
    # Panel count: 2 base images, +1 for the library image, +1 for the diff panel
    num = 2
    if image_lib is not None:
        num += 1
    if mse is not None:
        num += 1
    plt.subplot(1, num, 1)
    plt.imshow(image_original)
    plt.title("Original image")
    plt.subplot(1, num, 2)
    plt.imshow(image_de)
    plt.title("DE Op image")
    if image_lib is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_lib)
        plt.title("Lib Op image")
        if mse is not None:
            # diff of DE image vs library image (the "2nd and 3rd images" case)
            plt.subplot(1, num, 4)
            plt.imshow(image_de - image_lib)
            plt.title("Diff image,\n mse : {}".format(mse))
    elif mse is not None:
        # no library image: diff of original vs DE image
        plt.subplot(1, num, 3)
        plt.imshow(image_original - image_de)
        plt.title("Diff image,\n mse : {}".format(mse))
    plt.show()
def visualize_with_bounding_boxes(orig, aug, annot_name="bbox", plot_rows=3):
    """
    Take a list of un-augmented and augmented images with "bbox" bounding boxes
    Plot images to compare test correct BBox augment functionality
    :param orig: list of original images and bboxes (without aug)
    :param aug: list of augmented images and bboxes
    :param annot_name: the dict key for bboxes in data, e.g "bbox" (COCO) / "bbox" (VOC)
    :param plot_rows: number of rows on plot (rows = samples on one plot)
    :return: None
    """

    def add_bounding_boxes(ax, bboxes):
        # Draw each bounding box as a red rectangle on the given axes.
        # bboxes rows are assumed to start with [x, y, w, h] — TODO confirm with callers
        for bbox in bboxes:
            rect = patches.Rectangle((bbox[0], bbox[1]),
                                     bbox[2] * 0.997, bbox[3] * 0.997,
                                     linewidth=1.80, edgecolor='r', facecolor='none')
            # Add the patch to the Axes
            # Params to Rectangle slightly modified to prevent drawing overflow
            ax.add_patch(rect)

    # Quick check to confirm correct input parameters
    if not isinstance(orig, list) or not isinstance(aug, list):
        return
    if len(orig) != len(aug) or not orig:
        return

    batch_size = int(len(orig) / plot_rows)  # creates batches of images to plot together
    split_point = batch_size * plot_rows
    orig, aug = np.array(orig), np.array(aug)

    if len(orig) > plot_rows:
        # Create batches of required size and add remainder to last batch
        orig = np.split(orig[:split_point], batch_size) + (
            [orig[split_point:]] if (split_point < orig.shape[0]) else [])  # check to avoid empty arrays being added
        aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
    else:
        # Fewer samples than one full plot: plot them all as a single batch
        orig = [orig]
        aug = [aug]

    for ix, all_data in enumerate(zip(orig, aug)):
        base_ix = ix * plot_rows  # current batch starting index
        cur_plot = len(all_data[0])

        fig, axs = plt.subplots(cur_plot, 2)
        fig.tight_layout(pad=1.5)

        for x, (data_a, data_b) in enumerate(zip(all_data[0], all_data[1])):
            cur_ix = base_ix + x
            # select plotting axes based on number of image rows on plot - else case when 1 row
            (ax_a, ax_b) = (axs[x, 0], axs[x, 1]) if (cur_plot > 1) else (axs[0], axs[1])

            # left column: original image + boxes
            ax_a.imshow(data_a["image"])
            add_bounding_boxes(ax_a, data_a[annot_name])
            ax_a.title.set_text("Original" + str(cur_ix + 1))

            # right column: augmented image + boxes
            ax_b.imshow(data_b["image"])
            add_bounding_boxes(ax_b, data_b[annot_name])
            ax_b.title.set_text("Augmented" + str(cur_ix + 1))

            logger.info("Original **\n{} : {}".format(str(cur_ix + 1), data_a[annot_name]))
            logger.info("Augmented **\n{} : {}\n".format(str(cur_ix + 1), data_b[annot_name]))
        plt.show()
class InvalidBBoxType(Enum):
    """
    Defines Invalid Bounding Bbox types for test cases
    """
    # box width extends past the right edge of the image
    WidthOverflow = 1
    # box height extends past the bottom edge of the image
    HeightOverflow = 2
    # box origin has negative x/y coordinates
    NegativeXY = 3
    # box row has the wrong number of fields
    WrongShape = 4
def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
    """
    Inject an invalid bounding box into the pipeline, apply test_op, and assert
    that the resulting RuntimeError message contains expected_error.

    :param data: de object detection pipeline
    :param test_op: Augmentation Op to test on image
    :param invalid_bbox_type: type of bad box
    :param expected_error: error expected to get due to bad box
    :return: None
    """

    def add_bad_bbox(img, bboxes, invalid_bbox_type_):
        """
        Used to generate erroneous bounding box examples on given img.
        :param img: image where the bounding boxes are.
        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
        :param invalid_bbox_type_: type of bad box
        :return: bboxes with bad examples added
        """
        height = img.shape[0]
        width = img.shape[1]
        if invalid_bbox_type_ == InvalidBBoxType.WidthOverflow:
            # use box that overflows on width
            return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.HeightOverflow:
            # use box that overflows on height
            return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.NegativeXY:
            # use box with negative xy
            return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.WrongShape:
            # use box that has incorrect shape
            return img, np.array([[0, 0, width - 1]]).astype(np.float32)
        # unknown type: leave the boxes untouched
        return img, bboxes

    try:
        # map to use selected invalid bounding box type
        data = data.map(operations=lambda img, bboxes: add_bad_bbox(img, bboxes, invalid_bbox_type),
                        input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"],
                        column_order=["image", "bbox"])
        # map to apply ops
        data = data.map(operations=[test_op], input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"],
                        column_order=["image", "bbox"])  # Add column for "bbox"
        # iterate one item to trigger pipeline execution
        for _, _ in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
            break
    except RuntimeError as error:
        # NOTE(review): if no RuntimeError is raised this check silently passes —
        # consider asserting that the exception actually occurred
        logger.info("Got an exception in DE: {}".format(str(error)))
        assert expected_error in str(error)
  351. # return true if datasets are equal
  352. def dataset_equal(data1, data2, mse_threshold):
  353. if data1.get_dataset_size() != data2.get_dataset_size():
  354. return False
  355. equal = True
  356. for item1, item2 in itertools.zip_longest(data1, data2):
  357. for column1, column2 in itertools.zip_longest(item1, item2):
  358. mse = diff_mse(column1.asnumpy(), column2.asnumpy())
  359. if mse > mse_threshold:
  360. equal = False
  361. break
  362. if not equal:
  363. break
  364. return equal
  365. # return true if datasets are equal after modification to target
  366. # params: data_unchanged - dataset kept unchanged
  367. # data_target - dataset to be modified by foo
  368. # mse_threshold - maximum allowable value of mse
  369. # foo - function applied to data_target columns BEFORE compare
  370. # foo_args - arguments passed into foo
  371. def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
  372. if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
  373. return False
  374. equal = True
  375. for item1, item2 in itertools.zip_longest(data_unchanged, data_target):
  376. for column1, column2 in itertools.zip_longest(item1, item2):
  377. # note the function is to be applied to the second dataset
  378. column2 = foo(column2.asnumpy(), *foo_args)
  379. mse = diff_mse(column1.asnumpy(), column2)
  380. if mse > mse_threshold:
  381. equal = False
  382. break
  383. if not equal:
  384. break
  385. return equal