
util.py
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import hashlib
import json
import os
import itertools
from enum import Enum

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import PIL

import mindspore.dataset as ds
from mindspore import log as logger

# jsbeautifier is only needed when SAVE_JSON is enabled; it is imported
# lazily inside _save_json so these test utilities do not hard-require it.

# Plot titles for the different visualize modes
PLOT_TITLE_DICT = {
    1: ["Original image", "Transformed image"],
    2: ["c_transform image", "py_transform image"]
}
SAVE_JSON = False


def _save_golden(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary values as the golden result in a .npz file.
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.values())))


def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary (both keys and values) as the golden result in a .npz file.
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.items())))


def _compare_to_golden(golden_ref_dir, result_dict):
    """
    Compare the test result to the golden result as numpy arrays.
    """
    test_array = np.array(list(result_dict.values()))
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    np.testing.assert_array_equal(test_array, golden_array)
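

# A minimal round-trip sketch for the two golden helpers above (the file name
# below is hypothetical). Both sides reduce to np.array([[1, 2]]), so the
# comparison passes without raising:
#
#     result = {"col1": [1, 2]}
#     _save_golden(".", "golden_demo.npz", result)
#     _compare_to_golden("golden_demo.npz", result)  # no AssertionError raised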


def _compare_to_golden_dict(golden_ref_dir, result_dict, check_pillow_version=False):
    """
    Compare the test result to the golden result as dictionaries.
    """
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    # Note: golden results were generated with the Pillow version used in
    # Jenkins CI, which is compared against below
    if not check_pillow_version or PIL.__version__ == '7.1.2':
        np.testing.assert_equal(result_dict, dict(golden_array))
    else:
        # Beware: on failure, the Pillow version used to generate the golden
        # results may be incompatible with the current Pillow version
        np.testing.assert_equal(result_dict, dict(golden_array),
                                'Items are not equal and problem may be due to PILLOW version incompatibility')


def _save_json(filename, parameters, result_dict):
    """
    Save the result dictionary as a json file.
    """
    import jsbeautifier  # deferred import: only needed when SAVE_JSON is enabled
    options = jsbeautifier.default_options()
    options.indent_size = 2
    out_dict = {**parameters, **{"columns": result_dict}}
    with open(filename[:-3] + "json", "w") as fout:
        fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))


def save_and_check_dict(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as a dictionary) with the golden file.
    Use create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):  # each item is a dictionary
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1
    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict, False)

    if SAVE_JSON:
        # Save result to a json file for inspection
        parameters = {"params": {}}
        _save_json(filename, parameters, result_dict)
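

# A hedged usage sketch for save_and_check_dict (the pipeline and file name
# are hypothetical; any dataset exposing create_dict_iterator works, and the
# golden .npz is (re)generated when generate_golden=True):
#
#     data = ds.NumpySlicesDataset([1, 2, 3], column_names=["col1"], shuffle=False)
#     save_and_check_dict(data, "demo_result.npz", generate_golden=True)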


def save_and_check_md5(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as a dictionary) with the golden file (md5).
    Use create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):  # each item is a dictionary
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            # save the md5 digest as a numpy array
            result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='<f4'))
        num_iter += 1
    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict, True)
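

# Note: the md5 digest is stored reinterpreted as a float32 array only so it
# can flow through the same dict/npz comparison machinery as the raw-data
# checks; the float values themselves carry no meaning. For example:
#
#     digest = hashlib.md5(np.arange(4, dtype=np.uint8)).digest()  # 16 bytes
#     np.frombuffer(digest, dtype='<f4')                           # 4 float32 values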


def save_and_check_tuple(data, parameters, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as a numpy array) with the golden file.
    Use create_tuple_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}
    for item in data.create_tuple_iterator(num_epochs=1, output_numpy=True):  # each item is a list of columns
        for data_key, _ in enumerate(item):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1
    logger.info("Number of data in data1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden(golden_ref_dir, result_dict)

    if SAVE_JSON:
        # Save result to a json file for inspection
        _save_json(filename, parameters, result_dict)


def config_get_set_seed(seed_new):
    """
    Get and return the original configuration seed value.
    Set the new configuration seed value.
    """
    seed_original = ds.config.get_seed()
    ds.config.set_seed(seed_new)
    logger.info("seed: original = {} new = {}".format(seed_original, seed_new))
    return seed_original
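

# Typical get-set-restore pattern in a test (a sketch; the pipeline in the
# middle is hypothetical):
#
#     original_seed = config_get_set_seed(55)
#     # ... build and run a pipeline that depends on the fixed seed ...
#     ds.config.set_seed(original_seed)  # restore global state for other tests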


def config_get_set_num_parallel_workers(num_parallel_workers_new):
    """
    Get and return the original configuration num_parallel_workers value.
    Set the new configuration num_parallel_workers value.
    """
    num_parallel_workers_original = ds.config.get_num_parallel_workers()
    ds.config.set_num_parallel_workers(num_parallel_workers_new)
    logger.info("num_parallel_workers: original = {} new = {}".format(num_parallel_workers_original,
                                                                      num_parallel_workers_new))
    return num_parallel_workers_original


def diff_mse(in1, in2):
    """Mean squared error between two images (scaled to [0, 1]), as a percentage."""
    mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
    return mse * 100


def diff_me(in1, in2):
    """Mean absolute error between two images, as a percentage of the 255 range."""
    mae = (np.abs(in1.astype(float) - in2.astype(float))).mean()
    return mae / 255 * 100
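

# A quick worked example of diff_mse on two tiny uint8 "images" (pure numpy,
# no dataset involved). Constant images differing by the full 255 range give
# the maximum score of 100.0:
#
#     a = np.zeros((2, 2), dtype=np.uint8)
#     b = np.full((2, 2), 255, dtype=np.uint8)
#     diff_mse(a, b)  # ((0/255 - 255/255) ** 2).mean() * 100 == 100.0
#     diff_mse(a, a)  # identical inputs -> 0.0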


def visualize_one_channel_dataset(images_original, images_transformed, labels):
    """
    Helper function to visualize one-channel (grayscale) images.
    """
    num_samples = len(images_original)
    for i in range(num_samples):
        plt.subplot(2, num_samples, i + 1)
        # Note: Use squeeze() to convert (H, W, 1) images to (H, W)
        plt.imshow(images_original[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT[1][0] + ":" + str(labels[i]))

        plt.subplot(2, num_samples, i + num_samples + 1)
        plt.imshow(images_transformed[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT[1][1] + ":" + str(labels[i]))
    plt.show()


def visualize_list(image_list_1, image_list_2, visualize_mode=1):
    """
    Visualize two lists of images side by side, with titles chosen by visualize_mode.
    """
    plot_title = PLOT_TITLE_DICT[visualize_mode]
    num = len(image_list_1)
    for i in range(num):
        plt.subplot(2, num, i + 1)
        plt.imshow(image_list_1[i])
        plt.title(plot_title[0])

        plt.subplot(2, num, i + num + 1)
        plt.imshow(image_list_2[i])
        plt.title(plot_title[1])
    plt.show()


def visualize_image(image_original, image_de, mse=None, image_lib=None):
    """
    Visualize one sample image, with optional inputs: mse and an image produced
    by a third-party op. If a third image is passed in, the diff image is
    computed from the 2nd and 3rd images.
    """
    num = 2
    if image_lib is not None:
        num += 1
    if mse is not None:
        num += 1
    plt.subplot(1, num, 1)
    plt.imshow(image_original)
    plt.title("Original image")

    plt.subplot(1, num, 2)
    plt.imshow(image_de)
    plt.title("DE Op image")

    if image_lib is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_lib)
        plt.title("Lib Op image")
        if mse is not None:
            plt.subplot(1, num, 4)
            plt.imshow(image_de - image_lib)
            plt.title("Diff image,\n mse : {}".format(mse))
    elif mse is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_original - image_de)
        plt.title("Diff image,\n mse : {}".format(mse))
    plt.show()


def visualize_with_bounding_boxes(orig, aug, annot_name="bbox", plot_rows=3):
    """
    Take a list of un-augmented and augmented images with "bbox" bounding boxes,
    and plot them side by side to verify BBox augment functionality.
    :param orig: list of original images and bboxes (without aug)
    :param aug: list of augmented images and bboxes
    :param annot_name: the dict key for bboxes in data, e.g "bbox" (COCO) / "bbox" (VOC)
    :param plot_rows: number of rows on plot (rows = samples on one plot)
    :return: None
    """
    def add_bounding_boxes(ax, bboxes):
        for bbox in bboxes:
            # Width/height passed to Rectangle are slightly shrunk to prevent drawing overflow
            rect = patches.Rectangle((bbox[0], bbox[1]),
                                     bbox[2] * 0.997, bbox[3] * 0.997,
                                     linewidth=1.80, edgecolor='r', facecolor='none')
            # Add the patch to the Axes
            ax.add_patch(rect)

    # Quick check to confirm correct input parameters
    if not isinstance(orig, list) or not isinstance(aug, list):
        return
    if len(orig) != len(aug) or not orig:
        return

    batch_size = int(len(orig) / plot_rows)  # creates batches of images to plot together
    split_point = batch_size * plot_rows
    orig, aug = np.array(orig), np.array(aug)

    if len(orig) > plot_rows:
        # Create batches of required size and add remainder to last batch
        orig = np.split(orig[:split_point], batch_size) + (
            [orig[split_point:]] if (split_point < orig.shape[0]) else [])  # check to avoid adding empty arrays
        aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
    else:
        orig = [orig]
        aug = [aug]

    for ix, all_data in enumerate(zip(orig, aug)):
        base_ix = ix * plot_rows  # current batch starting index
        cur_plot = len(all_data[0])

        fig, axs = plt.subplots(cur_plot, 2)
        fig.tight_layout(pad=1.5)

        for x, (data_a, data_b) in enumerate(zip(all_data[0], all_data[1])):
            cur_ix = base_ix + x
            # select plotting axes based on number of image rows on plot - else case when 1 row
            (ax_a, ax_b) = (axs[x, 0], axs[x, 1]) if (cur_plot > 1) else (axs[0], axs[1])

            ax_a.imshow(data_a["image"])
            add_bounding_boxes(ax_a, data_a[annot_name])
            ax_a.title.set_text("Original" + str(cur_ix + 1))

            ax_b.imshow(data_b["image"])
            add_bounding_boxes(ax_b, data_b[annot_name])
            ax_b.title.set_text("Augmented" + str(cur_ix + 1))

            logger.info("Original **\n{} : {}".format(str(cur_ix + 1), data_a[annot_name]))
            logger.info("Augmented **\n{} : {}\n".format(str(cur_ix + 1), data_b[annot_name]))
        plt.show()


class InvalidBBoxType(Enum):
    """
    Defines invalid bounding box types for test cases.
    """
    WidthOverflow = 1
    HeightOverflow = 2
    NegativeXY = 3
    WrongShape = 4


def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
    """
    Inject an invalid bounding box into the pipeline and assert that the op
    raises the expected error.
    :param data: de object detection pipeline
    :param test_op: Augmentation Op to test on image
    :param invalid_bbox_type: type of bad box
    :param expected_error: error expected to get due to bad box
    :return: None
    """
    def add_bad_bbox(img, bboxes, invalid_bbox_type_):
        """
        Used to generate erroneous bounding box examples on a given img.
        :param img: image where the bounding boxes are
        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
        :param invalid_bbox_type_: type of bad box
        :return: bboxes with bad examples added
        """
        height = img.shape[0]
        width = img.shape[1]
        if invalid_bbox_type_ == InvalidBBoxType.WidthOverflow:
            # use box that overflows on width
            return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.HeightOverflow:
            # use box that overflows on height
            return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.NegativeXY:
            # use box with negative xy
            return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.WrongShape:
            # use box that has incorrect shape
            return img, np.array([[0, 0, width - 1]]).astype(np.float32)
        return img, bboxes

    try:
        # map to inject the selected invalid bounding box type
        data = data.map(operations=lambda img, bboxes: add_bad_bbox(img, bboxes, invalid_bbox_type),
                        input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"],
                        column_order=["image", "bbox"])
        # map to apply the op under test
        data = data.map(operations=[test_op], input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"],
                        column_order=["image", "bbox"])
        for _, _ in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
            break
    except RuntimeError as error:
        logger.info("Got an exception in DE: {}".format(str(error)))
        assert expected_error in str(error)
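

# A hedged usage sketch for check_bad_bbox (DATA_DIR, the augment op, and the
# expected error text are all hypothetical and depend on the MindSpore
# version under test):
#
#     data = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True, shuffle=False)
#     check_bad_bbox(data, test_op=some_bbox_aug_op,
#                    invalid_bbox_type=InvalidBBoxType.NegativeXY,
#                    expected_error="negative")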


def dataset_equal(data1, data2, mse_threshold):
    """Return True if the two datasets are equal (column-wise MSE within mse_threshold)."""
    if data1.get_dataset_size() != data2.get_dataset_size():
        return False
    equal = True
    for item1, item2 in itertools.zip_longest(data1, data2):
        for column1, column2 in itertools.zip_longest(item1, item2):
            mse = diff_mse(column1.asnumpy(), column2.asnumpy())
            if mse > mse_threshold:
                equal = False
                break
        if not equal:
            break
    return equal
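

# A sketch of dataset_equal with two identical pipelines (hypothetical data);
# at threshold 0.0 the columns must match exactly, while a small positive
# threshold tolerates lossy transforms:
#
#     images = [np.ones((2, 2), dtype=np.uint8)]
#     data1 = ds.NumpySlicesDataset(images, column_names=["image"], shuffle=False)
#     data2 = ds.NumpySlicesDataset(images, column_names=["image"], shuffle=False)
#     assert dataset_equal(data1, data2, mse_threshold=0.0)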


def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
    """
    Return True if the datasets are equal after modification to the target.
    :param data_unchanged: dataset kept unchanged
    :param data_target: dataset to be modified by foo
    :param mse_threshold: maximum allowable value of mse
    :param foo: function applied to data_target columns BEFORE the comparison
    :param foo_args: arguments passed into foo
    """
    if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
        return False
    equal = True
    for item1, item2 in itertools.zip_longest(data_unchanged, data_target):
        for column1, column2 in itertools.zip_longest(item1, item2):
            # note: the function is applied to the second dataset only
            column2 = foo(column2.asnumpy(), *foo_args)
            mse = diff_mse(column1.asnumpy(), column2)
            if mse > mse_threshold:
                equal = False
                break
        if not equal:
            break
    return equal
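

# A hedged usage sketch for dataset_equal_with_function: undo a doubling of
# the target dataset's column with foo before the MSE check (the pipelines
# and the rescale helper are hypothetical):
#
#     def rescale(column, factor):
#         return (column / factor).astype(np.uint8)
#
#     assert dataset_equal_with_function(data1, data2_doubled, 0.01, rescale, 2)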