You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

util.py 15 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import hashlib
  16. import json
  17. import os
  18. from enum import Enum
  19. import matplotlib.pyplot as plt
  20. import matplotlib.patches as patches
  21. import numpy as np
  22. # import jsbeautifier
  23. import mindspore.dataset as ds
  24. from mindspore import log as logger
# These are the column names defined in the testTFTestAllTypes dataset
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
           "col_sint16", "col_sint32", "col_sint64"]
# These are list of plot title in different visualize modes
# (key = visualize_mode argument of visualize_list)
PLOT_TITLE_DICT = {
    1: ["Original image", "Transformed image"],
    2: ["c_transform image", "py_transform image"]
}
# When True, the save_and_check* helpers also dump results to a json file
SAVE_JSON = False
  34. def _save_golden(cur_dir, golden_ref_dir, result_dict):
  35. """
  36. Save the dictionary values as the golden result in .npz file
  37. """
  38. logger.info("cur_dir is {}".format(cur_dir))
  39. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  40. np.savez(golden_ref_dir, np.array(list(result_dict.values())))
  41. def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
  42. """
  43. Save the dictionary (both keys and values) as the golden result in .npz file
  44. """
  45. logger.info("cur_dir is {}".format(cur_dir))
  46. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  47. np.savez(golden_ref_dir, np.array(list(result_dict.items())))
  48. def _compare_to_golden(golden_ref_dir, result_dict):
  49. """
  50. Compare as numpy arrays the test result to the golden result
  51. """
  52. test_array = np.array(list(result_dict.values()))
  53. golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
  54. assert np.array_equal(test_array, golden_array)
  55. def _compare_to_golden_dict(golden_ref_dir, result_dict):
  56. """
  57. Compare as dictionaries the test result to the golden result
  58. """
  59. golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
  60. np.testing.assert_equal(result_dict, dict(golden_array))
  61. def _save_json(filename, parameters, result_dict):
  62. """
  63. Save the result dictionary in json file
  64. """
  65. fout = open(filename[:-3] + "json", "w")
  66. options = jsbeautifier.default_options()
  67. options.indent_size = 2
  68. out_dict = {**parameters, **{"columns": result_dict}}
  69. fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))
  70. def save_and_check(data, parameters, filename, generate_golden=False):
  71. """
  72. Save the dataset dictionary and compare (as numpy array) with golden file.
  73. Use create_dict_iterator to access the dataset.
  74. Note: save_and_check() is deprecated; use save_and_check_dict().
  75. """
  76. num_iter = 0
  77. result_dict = {}
  78. for column_name in COLUMNS:
  79. result_dict[column_name] = []
  80. for item in data.create_dict_iterator(): # each data is a dictionary
  81. for data_key in list(item.keys()):
  82. if data_key not in result_dict:
  83. result_dict[data_key] = []
  84. result_dict[data_key].append(item[data_key].tolist())
  85. num_iter += 1
  86. logger.info("Number of data in data1: {}".format(num_iter))
  87. cur_dir = os.path.dirname(os.path.realpath(__file__))
  88. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  89. if generate_golden:
  90. # Save as the golden result
  91. _save_golden(cur_dir, golden_ref_dir, result_dict)
  92. _compare_to_golden(golden_ref_dir, result_dict)
  93. if SAVE_JSON:
  94. # Save result to a json file for inspection
  95. _save_json(filename, parameters, result_dict)
  96. def save_and_check_dict(data, filename, generate_golden=False):
  97. """
  98. Save the dataset dictionary and compare (as dictionary) with golden file.
  99. Use create_dict_iterator to access the dataset.
  100. """
  101. num_iter = 0
  102. result_dict = {}
  103. for item in data.create_dict_iterator(): # each data is a dictionary
  104. for data_key in list(item.keys()):
  105. if data_key not in result_dict:
  106. result_dict[data_key] = []
  107. result_dict[data_key].append(item[data_key].tolist())
  108. num_iter += 1
  109. logger.info("Number of data in ds1: {}".format(num_iter))
  110. cur_dir = os.path.dirname(os.path.realpath(__file__))
  111. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  112. if generate_golden:
  113. # Save as the golden result
  114. _save_golden_dict(cur_dir, golden_ref_dir, result_dict)
  115. _compare_to_golden_dict(golden_ref_dir, result_dict)
  116. if SAVE_JSON:
  117. # Save result to a json file for inspection
  118. parameters = {"params": {}}
  119. _save_json(filename, parameters, result_dict)
  120. def save_and_check_md5(data, filename, generate_golden=False):
  121. """
  122. Save the dataset dictionary and compare (as dictionary) with golden file (md5).
  123. Use create_dict_iterator to access the dataset.
  124. """
  125. num_iter = 0
  126. result_dict = {}
  127. for item in data.create_dict_iterator(): # each data is a dictionary
  128. for data_key in list(item.keys()):
  129. if data_key not in result_dict:
  130. result_dict[data_key] = []
  131. # save the md5 as numpy array
  132. result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='<f4'))
  133. num_iter += 1
  134. logger.info("Number of data in ds1: {}".format(num_iter))
  135. cur_dir = os.path.dirname(os.path.realpath(__file__))
  136. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  137. if generate_golden:
  138. # Save as the golden result
  139. _save_golden_dict(cur_dir, golden_ref_dir, result_dict)
  140. _compare_to_golden_dict(golden_ref_dir, result_dict)
  141. def save_and_check_tuple(data, parameters, filename, generate_golden=False):
  142. """
  143. Save the dataset dictionary and compare (as numpy array) with golden file.
  144. Use create_tuple_iterator to access the dataset.
  145. """
  146. num_iter = 0
  147. result_dict = {}
  148. for item in data.create_tuple_iterator(): # each data is a dictionary
  149. for data_key, _ in enumerate(item):
  150. if data_key not in result_dict:
  151. result_dict[data_key] = []
  152. result_dict[data_key].append(item[data_key].tolist())
  153. num_iter += 1
  154. logger.info("Number of data in data1: {}".format(num_iter))
  155. cur_dir = os.path.dirname(os.path.realpath(__file__))
  156. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  157. if generate_golden:
  158. # Save as the golden result
  159. _save_golden(cur_dir, golden_ref_dir, result_dict)
  160. _compare_to_golden(golden_ref_dir, result_dict)
  161. if SAVE_JSON:
  162. # Save result to a json file for inspection
  163. _save_json(filename, parameters, result_dict)
  164. def diff_mse(in1, in2):
  165. mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
  166. return mse * 100
  167. def diff_me(in1, in2):
  168. mse = (np.abs(in1.astype(float) - in2.astype(float))).mean()
  169. return mse / 255 * 100
  170. def visualize_list(image_list_1, image_list_2, visualize_mode=1):
  171. """
  172. visualizes a list of images using DE op
  173. """
  174. plot_title = PLOT_TITLE_DICT[visualize_mode]
  175. num = len(image_list_1)
  176. for i in range(num):
  177. plt.subplot(2, num, i + 1)
  178. plt.imshow(image_list_1[i])
  179. plt.title(plot_title[0])
  180. plt.subplot(2, num, i + num + 1)
  181. plt.imshow(image_list_2[i])
  182. plt.title(plot_title[1])
  183. plt.show()
  184. def visualize_image(image_original, image_de, mse=None, image_lib=None):
  185. """
  186. visualizes one example image with optional input: mse, image using 3rd party op.
  187. If three images are passing in, different image is calculated by 2nd and 3rd images.
  188. """
  189. num = 2
  190. if image_lib is not None:
  191. num += 1
  192. if mse is not None:
  193. num += 1
  194. plt.subplot(1, num, 1)
  195. plt.imshow(image_original)
  196. plt.title("Original image")
  197. plt.subplot(1, num, 2)
  198. plt.imshow(image_de)
  199. plt.title("DE Op image")
  200. if image_lib is not None:
  201. plt.subplot(1, num, 3)
  202. plt.imshow(image_lib)
  203. plt.title("Lib Op image")
  204. if mse is not None:
  205. plt.subplot(1, num, 4)
  206. plt.imshow(image_de - image_lib)
  207. plt.title("Diff image,\n mse : {}".format(mse))
  208. elif mse is not None:
  209. plt.subplot(1, num, 3)
  210. plt.imshow(image_original - image_de)
  211. plt.title("Diff image,\n mse : {}".format(mse))
  212. plt.show()
  213. def config_get_set_seed(seed_new):
  214. """
  215. Get and return the original configuration seed value.
  216. Set the new configuration seed value.
  217. """
  218. seed_original = ds.config.get_seed()
  219. ds.config.set_seed(seed_new)
  220. logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
  221. return seed_original
  222. def config_get_set_num_parallel_workers(num_parallel_workers_new):
  223. """
  224. Get and return the original configuration num_parallel_workers value.
  225. Set the new configuration num_parallel_workers value.
  226. """
  227. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  228. ds.config.set_num_parallel_workers(num_parallel_workers_new)
  229. logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
  230. num_parallel_workers_new))
  231. return num_parallel_workers_original
def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=3):
    """
    Take a list of un-augmented and augmented images with "annotation" bounding boxes
    Plot images to compare test correct BBox augment functionality
    :param orig: list of original images and bboxes (without aug)
    :param aug: list of augmented images and bboxes
    :param annot_name: the dict key for bboxes in data, e.g "bbox" (COCO) / "annotation" (VOC)
    :param plot_rows: number of rows on plot (rows = samples on one plot)
    :return: None
    """
    def add_bounding_boxes(ax, bboxes):
        # bbox layout: bbox[0], bbox[1] are the rectangle's xy anchor and
        # bbox[2], bbox[3] its width/height (as consumed by patches.Rectangle)
        for bbox in bboxes:
            rect = patches.Rectangle((bbox[0], bbox[1]),
                                     bbox[2]*0.997, bbox[3]*0.997,
                                     linewidth=1.80, edgecolor='r', facecolor='none')
            # Add the patch to the Axes
            # Params to Rectangle slightly modified to prevent drawing overflow
            ax.add_patch(rect)
    # Quick check to confirm correct input parameters
    # (silently returns on invalid/mismatched input rather than raising)
    if not isinstance(orig, list) or not isinstance(aug, list):
        return
    if len(orig) != len(aug) or not orig:
        return
    batch_size = int(len(orig) / plot_rows)  # creates batches of images to plot together
    split_point = batch_size * plot_rows
    orig, aug = np.array(orig), np.array(aug)
    if len(orig) > plot_rows:
        # Create batches of required size and add remainder to last batch:
        # the first split_point samples split into batch_size groups of
        # plot_rows each, leftovers become one final shorter batch
        orig = np.split(orig[:split_point], batch_size) + (
            [orig[split_point:]] if (split_point < orig.shape[0]) else [])  # check to avoid empty arrays being added
        aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
    else:
        # few enough samples to fit on a single figure
        orig = [orig]
        aug = [aug]
    for ix, allData in enumerate(zip(orig, aug)):
        base_ix = ix * plot_rows  # current batch starting index
        curPlot = len(allData[0])
        # one figure per batch: left column original, right column augmented
        fig, axs = plt.subplots(curPlot, 2)
        fig.tight_layout(pad=1.5)
        for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
            cur_ix = base_ix + x
            # select plotting axes based on number of image rows on plot - else case when 1 row
            (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1])
            axA.imshow(dataA["image"])
            add_bounding_boxes(axA, dataA[annot_name])
            axA.title.set_text("Original" + str(cur_ix+1))
            axB.imshow(dataB["image"])
            add_bounding_boxes(axB, dataB[annot_name])
            axB.title.set_text("Augmented" + str(cur_ix+1))
            logger.info("Original **\n{} : {}".format(str(cur_ix+1), dataA[annot_name]))
            logger.info("Augmented **\n{} : {}\n".format(str(cur_ix+1), dataB[annot_name]))
        plt.show()
class InvalidBBoxType(Enum):
    """
    Defines Invalid Bounding Bbox types for test cases
    (consumed by check_bad_bbox to fabricate malformed boxes).
    """
    WidthOverflow = 1   # box extends past the image's right edge
    HeightOverflow = 2  # box extends past the image's bottom edge
    NegativeXY = 3      # box anchored at negative coordinates
    WrongShape = 4      # box row has too few elements
def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
    """
    Run an augmentation op over a pipeline whose annotations have been replaced
    with a deliberately malformed bounding box, and assert the expected error.

    :param data: de object detection pipeline
    :param test_op: Augmentation Op to test on image
    :param invalid_bbox_type: type of bad box (an InvalidBBoxType member)
    :param expected_error: error message substring expected due to bad box
    :return: None
    """
    def add_bad_annotation(img, bboxes, invalid_bbox_type_):
        """
        Used to generate erroneous bounding box examples on given img.
        :param img: image where the bounding boxes are.
        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
        :param invalid_bbox_type_: type of bad box
        :return: bboxes with bad examples added
        """
        height = img.shape[0]
        width = img.shape[1]
        if invalid_bbox_type_ == InvalidBBoxType.WidthOverflow:
            # use box that overflows on width
            return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.HeightOverflow:
            # use box that overflows on height
            return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.NegativeXY:
            # use box with negative xy
            return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.float32)
        if invalid_bbox_type_ == InvalidBBoxType.WrongShape:
            # use box that has incorrect shape
            return img, np.array([[0, 0, width - 1]]).astype(np.float32)
        # unknown type: leave the annotation unmodified
        return img, bboxes
    try:
        # map to use selected invalid bounding box type
        # NOTE(review): uses the legacy Dataset.map keyword signature
        # (input_columns/output_columns/columns_order as kwargs) — matches the
        # MindSpore version this file targets; verify before upgrading.
        data = data.map(input_columns=["image", "annotation"],
                        output_columns=["image", "annotation"],
                        columns_order=["image", "annotation"],
                        operations=lambda img, bboxes: add_bad_annotation(img, bboxes, invalid_bbox_type))
        # map to apply ops
        data = data.map(input_columns=["image", "annotation"],
                        output_columns=["image", "annotation"],
                        columns_order=["image", "annotation"],
                        operations=[test_op])  # Add column for "annotation"
        # pulling a single item is enough to trigger the pipeline error
        for _, _ in enumerate(data.create_dict_iterator()):
            break
    except RuntimeError as error:
        logger.info("Got an exception in DE: {}".format(str(error)))
        assert expected_error in str(error)