util.py
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import hashlib
import json
import os
import itertools
from enum import Enum

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import mindspore.dataset as ds
from mindspore import log as logger

# Plot titles for the different visualize modes
PLOT_TITLE_DICT = {
    1: ["Original image", "Transformed image"],
    2: ["c_transform image", "py_transform image"]
}

SAVE_JSON = False


def _save_golden(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary values as the golden result in .npz file.
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.values())))


def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary (both keys and values) as the golden result in .npz file.
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.items())))


def _compare_to_golden(golden_ref_dir, result_dict):
    """
    Compare the test result to the golden result as numpy arrays.
    """
    test_array = np.array(list(result_dict.values()))
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    np.testing.assert_array_equal(test_array, golden_array)


def _compare_to_golden_dict(golden_ref_dir, result_dict):
    """
    Compare the test result to the golden result as dictionaries.
    """
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    np.testing.assert_equal(result_dict, dict(golden_array))
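
# Illustrative round trip (a sketch; `result_dict` stands for data collected from a
# pipeline by the save_and_check_* helpers below):
#
#     _save_golden_dict(cur_dir, golden_ref_dir, result_dict)   # run once to (re)generate
#     _compare_to_golden_dict(golden_ref_dir, result_dict)      # raises AssertionError on mismatch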


def _save_json(filename, parameters, result_dict):
    """
    Save the result dictionary in json file.
    """
    # jsbeautifier is an optional dependency, only needed when SAVE_JSON is True,
    # so it is imported lazily here rather than at module level.
    import jsbeautifier
    options = jsbeautifier.default_options()
    options.indent_size = 2
    out_dict = {**parameters, **{"columns": result_dict}}
    with open(filename[:-3] + "json", "w") as fout:
        fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))


def save_and_check_dict(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as dictionary) with golden file.
    Use create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    # each item is a dictionary mapping column names to numpy arrays
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1

    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict)

    if SAVE_JSON:
        # Save result to a json file for inspection
        parameters = {"params": {}}
        _save_json(filename, parameters, result_dict)
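
# Illustrative use in a test (a sketch; the dataset, op, and filename are hypothetical):
#
#     data = ds.ImageFolderDataset(DATA_DIR, shuffle=False, decode=True)
#     data = data.map(operations=my_op, input_columns=["image"])
#     save_and_check_dict(data, "my_op_result.npz", generate_golden=False)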


def save_and_check_md5(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as dictionary) with golden file (md5).
    Use create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    # each item is a dictionary mapping column names to numpy arrays
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            # save the md5 digest as a numpy array instead of the raw tensor
            result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='<f4'))
        num_iter += 1

    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict)
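
# save_and_check_md5 differs from save_and_check_dict only in reducing each tensor to
# its md5 digest (stored as a small float32 array), so golden files stay tiny even for
# image pipelines. Illustrative call (hypothetical filename):
#
#     save_and_check_md5(data, "random_crop_01_result.npz", generate_golden=False)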


def save_and_check_tuple(data, parameters, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as numpy array) with golden file.
    Use create_tuple_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    # each item is a list of columns; the column index is used as the key
    for item in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
        for data_key, _ in enumerate(item):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden(golden_ref_dir, result_dict)

    if SAVE_JSON:
        # Save result to a json file for inspection
        _save_json(filename, parameters, result_dict)


def config_get_set_seed(seed_new):
    """
    Get and return the original configuration seed value.
    Set the new configuration seed value.
    """
    seed_original = ds.config.get_seed()
    ds.config.set_seed(seed_new)
    logger.info("seed: original = {} new = {}".format(seed_original, seed_new))
    return seed_original


def config_get_set_num_parallel_workers(num_parallel_workers_new):
    """
    Get and return the original configuration num_parallel_workers value.
    Set the new configuration num_parallel_workers value.
    """
    num_parallel_workers_original = ds.config.get_num_parallel_workers()
    ds.config.set_num_parallel_workers(num_parallel_workers_new)
    logger.info("num_parallel_workers: original = {} new = {}".format(num_parallel_workers_original,
                                                                      num_parallel_workers_new))
    return num_parallel_workers_original
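
# Typical save/restore pattern in a deterministic test (a sketch; the local names are
# hypothetical):
#
#     original_seed = config_get_set_seed(5)
#     original_workers = config_get_set_num_parallel_workers(1)
#     ...  # run the pipeline under test
#     ds.config.set_seed(original_seed)
#     ds.config.set_num_parallel_workers(original_workers)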


def diff_mse(in1, in2):
    """Mean squared error between two images normalized to [0, 1], scaled by 100."""
    mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
    return mse * 100


def diff_me(in1, in2):
    """Mean absolute error between two images, as a percentage of the 255 range."""
    me = (np.abs(in1.astype(float) - in2.astype(float))).mean()
    return me / 255 * 100
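
# Worked example: identical inputs give diff_mse == 0.0, while an all-0 vs all-255
# uint8 pair gives ((0/255 - 255/255)**2).mean() * 100 == 100.0:
#
#     a = np.zeros((2, 2), dtype=np.uint8)
#     b = np.full((2, 2), 255, dtype=np.uint8)
#     assert diff_mse(a, a) == 0.0 and diff_mse(a, b) == 100.0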


def visualize_one_channel_dataset(images_original, images_transformed, labels):
    """
    Helper function to visualize one-channel (grayscale) images.
    """
    num_samples = len(images_original)
    for i in range(num_samples):
        plt.subplot(2, num_samples, i + 1)
        # Note: Use squeeze() to convert (H, W, 1) images to (H, W)
        plt.imshow(images_original[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT[1][0] + ":" + str(labels[i]))

        plt.subplot(2, num_samples, i + num_samples + 1)
        plt.imshow(images_transformed[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT[1][1] + ":" + str(labels[i]))
    plt.show()


def visualize_list(image_list_1, image_list_2, visualize_mode=1):
    """
    Visualize two lists of images side by side; plot titles are chosen by visualize_mode.
    """
    plot_title = PLOT_TITLE_DICT[visualize_mode]
    num = len(image_list_1)
    for i in range(num):
        plt.subplot(2, num, i + 1)
        plt.imshow(image_list_1[i])
        plt.title(plot_title[0])

        plt.subplot(2, num, i + num + 1)
        plt.imshow(image_list_2[i])
        plt.title(plot_title[1])
    plt.show()


def visualize_image(image_original, image_de, mse=None, image_lib=None):
    """
    Visualize one example image, with optional inputs: mse and an image produced by a
    third-party library op. If three images are passed in, the difference image is
    computed from the 2nd and 3rd images.
    """
    num = 2
    if image_lib is not None:
        num += 1
    if mse is not None:
        num += 1

    plt.subplot(1, num, 1)
    plt.imshow(image_original)
    plt.title("Original image")

    plt.subplot(1, num, 2)
    plt.imshow(image_de)
    plt.title("DE Op image")

    if image_lib is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_lib)
        plt.title("Lib Op image")
        if mse is not None:
            plt.subplot(1, num, 4)
            plt.imshow(image_de - image_lib)
            plt.title("Diff image,\n mse : {}".format(mse))
    elif mse is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_original - image_de)
        plt.title("Diff image,\n mse : {}".format(mse))
    plt.show()
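
# Illustrative call (hypothetical arrays): compare a DE-transformed image against a
# cv2-produced reference and show their difference with the mse annotation:
#
#     mse = diff_mse(image_de, image_cv)
#     visualize_image(image_original, image_de, mse, image_cv)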


def visualize_with_bounding_boxes(orig, aug, annot_name="bbox", plot_rows=3):
    """
    Take a list of un-augmented and augmented images with bounding boxes and plot the
    pairs to visually verify BBox augment functionality.
    :param orig: list of original images and bboxes (without aug)
    :param aug: list of augmented images and bboxes
    :param annot_name: the dict key for bboxes in data, e.g. "bbox" (COCO) / "annotation" (VOC)
    :param plot_rows: number of rows on plot (rows = samples on one plot)
    :return: None
    """

    def add_bounding_boxes(ax, bboxes):
        for bbox in bboxes:
            # Params to Rectangle slightly modified to prevent drawing overflow
            rect = patches.Rectangle((bbox[0], bbox[1]),
                                     bbox[2] * 0.997, bbox[3] * 0.997,
                                     linewidth=1.80, edgecolor='r', facecolor='none')
            # Add the patch to the Axes
            ax.add_patch(rect)

    # Quick check to confirm correct input parameters
    if not isinstance(orig, list) or not isinstance(aug, list):
        return
    if len(orig) != len(aug) or not orig:
        return

    batch_size = int(len(orig) / plot_rows)  # creates batches of images to plot together
    split_point = batch_size * plot_rows
    orig, aug = np.array(orig), np.array(aug)

    if len(orig) > plot_rows:
        # Create batches of required size and add remainder to last batch
        orig = np.split(orig[:split_point], batch_size) + (
            [orig[split_point:]] if (split_point < orig.shape[0]) else [])  # avoid adding empty arrays
        aug = np.split(aug[:split_point], batch_size) + (
            [aug[split_point:]] if (split_point < aug.shape[0]) else [])
    else:
        orig = [orig]
        aug = [aug]

    for ix, all_data in enumerate(zip(orig, aug)):
        base_ix = ix * plot_rows  # current batch starting index
        cur_plot = len(all_data[0])

        fig, axs = plt.subplots(cur_plot, 2)
        fig.tight_layout(pad=1.5)

        for x, (data_a, data_b) in enumerate(zip(all_data[0], all_data[1])):
            cur_ix = base_ix + x
            # select plotting axes based on number of image rows on plot - else case when 1 row
            (ax_a, ax_b) = (axs[x, 0], axs[x, 1]) if (cur_plot > 1) else (axs[0], axs[1])

            ax_a.imshow(data_a["image"])
            add_bounding_boxes(ax_a, data_a[annot_name])
            ax_a.title.set_text("Original" + str(cur_ix + 1))

            ax_b.imshow(data_b["image"])
            add_bounding_boxes(ax_b, data_b[annot_name])
            ax_b.title.set_text("Augmented" + str(cur_ix + 1))

            logger.info("Original **\n{} : {}".format(str(cur_ix + 1), data_a[annot_name]))
            logger.info("Augmented **\n{} : {}\n".format(str(cur_ix + 1), data_b[annot_name]))
        plt.show()


class InvalidBBoxType(Enum):
    """
    Defines invalid bounding box types for test cases.
    """
    WidthOverflow = 1
    HeightOverflow = 2
    NegativeXY = 3
    WrongShape = 4


def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
    """
    Check that applying test_op to data with an injected invalid bounding box raises
    an error containing expected_error.
    :param data: de object detection pipeline
    :param test_op: Augmentation Op to test on image
    :param invalid_bbox_type: type of bad box
    :param expected_error: error expected to get due to bad box
    :return: None
    """

    def add_bad_bbox(img, bboxes, invalid_bbox_type_):
        """
        Used to generate erroneous bounding box examples on given img.
        :param img: image where the bounding boxes are.
        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
        :param invalid_bbox_type_: type of bad box
        :return: bboxes with bad examples added
        """
        height = img.shape[0]
        width = img.shape[1]
        if invalid_bbox_type_ == InvalidBBoxType.WidthOverflow:
            # use box that overflows on width
            return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.HeightOverflow:
            # use box that overflows on height
            return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.NegativeXY:
            # use box with negative xy
            return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.WrongShape:
            # use box that has incorrect shape
            return img, np.array([[0, 0, width - 1]]).astype(np.float32)
        return img, bboxes

    try:
        # map to inject the selected invalid bounding box type
        data = data.map(operations=lambda img, bboxes: add_bad_bbox(img, bboxes, invalid_bbox_type),
                        input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"],
                        column_order=["image", "bbox"])
        # map to apply the op under test
        data = data.map(operations=[test_op], input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"],
                        column_order=["image", "bbox"])
        for _, _ in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
            break
    except RuntimeError as error:
        logger.info("Got an exception in DE: {}".format(str(error)))
        assert expected_error in str(error)
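
# Illustrative use (a sketch; the dataset, op, and error substring are hypothetical):
#
#     data = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
#     check_bad_bbox(data, test_op=vision.RandomHorizontalFlipWithBBox(1),
#                    invalid_bbox_type=InvalidBBoxType.WidthOverflow,
#                    expected_error="out of bounds of the image")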


# return True if datasets are equal
def dataset_equal(data1, data2, mse_threshold):
    if data1.get_dataset_size() != data2.get_dataset_size():
        return False

    equal = True
    for item1, item2 in itertools.zip_longest(data1, data2):
        for column1, column2 in itertools.zip_longest(item1, item2):
            mse = diff_mse(column1.asnumpy(), column2.asnumpy())
            if mse > mse_threshold:
                equal = False
                break
        if not equal:
            break
    return equal
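
# Illustrative check (a sketch): two identically-configured, identically-seeded
# pipelines should match exactly:
#
#     assert dataset_equal(data1, data2, mse_threshold=0)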


# return True if datasets are equal after modification to target
# params: data_unchanged - dataset kept unchanged
#         data_target - dataset to be modified by foo
#         mse_threshold - maximum allowable value of mse
#         foo - function applied to data_target columns BEFORE compare
#         foo_args - arguments passed into foo
def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
    if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
        return False

    equal = True
    for item1, item2 in itertools.zip_longest(data_unchanged, data_target):
        for column1, column2 in itertools.zip_longest(item1, item2):
            # note: the function is applied to the columns of the second dataset
            column2 = foo(column2.asnumpy(), *foo_args)
            mse = diff_mse(column1.asnumpy(), column2)
            if mse > mse_threshold:
                equal = False
                break
        if not equal:
            break
    return equal
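
# Illustrative use (a sketch; the pipelines and the cv2 call are hypothetical): verify
# that a pipeline's Resize op matches applying an equivalent cv2 resize to the columns
# of an un-resized copy of the same dataset:
#
#     import cv2
#     assert dataset_equal_with_function(data_resized, data_raw, 0.1,
#                                        cv2.resize, (300, 300))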