You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

util.py 9.9 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import hashlib
  16. import json
  17. import os
  18. import matplotlib.pyplot as plt
  19. import numpy as np
  20. # import jsbeautifier
  21. import mindspore.dataset as ds
  22. from mindspore import log as logger
# These are the column names defined in the testTFTestAllTypes dataset
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
           "col_sint16", "col_sint32", "col_sint64"]
# These are list of plot title in different visualize modes
PLOT_TITLE_DICT = {
    1: ["Original image", "Transformed image"],
    2: ["c_transform image", "py_transform image"]
}
# When True, the save_and_check* helpers also dump results to a .json file for inspection
SAVE_JSON = False
  32. def _save_golden(cur_dir, golden_ref_dir, result_dict):
  33. """
  34. Save the dictionary values as the golden result in .npz file
  35. """
  36. logger.info("cur_dir is {}".format(cur_dir))
  37. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  38. np.savez(golden_ref_dir, np.array(list(result_dict.values())))
  39. def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
  40. """
  41. Save the dictionary (both keys and values) as the golden result in .npz file
  42. """
  43. logger.info("cur_dir is {}".format(cur_dir))
  44. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  45. np.savez(golden_ref_dir, np.array(list(result_dict.items())))
  46. def _compare_to_golden(golden_ref_dir, result_dict):
  47. """
  48. Compare as numpy arrays the test result to the golden result
  49. """
  50. test_array = np.array(list(result_dict.values()))
  51. golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
  52. assert np.array_equal(test_array, golden_array)
  53. def _compare_to_golden_dict(golden_ref_dir, result_dict):
  54. """
  55. Compare as dictionaries the test result to the golden result
  56. """
  57. golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
  58. np.testing.assert_equal(result_dict, dict(golden_array))
  59. def _save_json(filename, parameters, result_dict):
  60. """
  61. Save the result dictionary in json file
  62. """
  63. fout = open(filename[:-3] + "json", "w")
  64. options = jsbeautifier.default_options()
  65. options.indent_size = 2
  66. out_dict = {**parameters, **{"columns": result_dict}}
  67. fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))
  68. def save_and_check(data, parameters, filename, generate_golden=False):
  69. """
  70. Save the dataset dictionary and compare (as numpy array) with golden file.
  71. Use create_dict_iterator to access the dataset.
  72. Note: save_and_check() is deprecated; use save_and_check_dict().
  73. """
  74. num_iter = 0
  75. result_dict = {}
  76. for column_name in COLUMNS:
  77. result_dict[column_name] = []
  78. for item in data.create_dict_iterator(): # each data is a dictionary
  79. for data_key in list(item.keys()):
  80. if data_key not in result_dict:
  81. result_dict[data_key] = []
  82. result_dict[data_key].append(item[data_key].tolist())
  83. num_iter += 1
  84. logger.info("Number of data in data1: {}".format(num_iter))
  85. cur_dir = os.path.dirname(os.path.realpath(__file__))
  86. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  87. if generate_golden:
  88. # Save as the golden result
  89. _save_golden(cur_dir, golden_ref_dir, result_dict)
  90. _compare_to_golden(golden_ref_dir, result_dict)
  91. if SAVE_JSON:
  92. # Save result to a json file for inspection
  93. _save_json(filename, parameters, result_dict)
  94. def save_and_check_dict(data, filename, generate_golden=False):
  95. """
  96. Save the dataset dictionary and compare (as dictionary) with golden file.
  97. Use create_dict_iterator to access the dataset.
  98. """
  99. num_iter = 0
  100. result_dict = {}
  101. for item in data.create_dict_iterator(): # each data is a dictionary
  102. for data_key in list(item.keys()):
  103. if data_key not in result_dict:
  104. result_dict[data_key] = []
  105. result_dict[data_key].append(item[data_key].tolist())
  106. num_iter += 1
  107. logger.info("Number of data in ds1: {}".format(num_iter))
  108. cur_dir = os.path.dirname(os.path.realpath(__file__))
  109. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  110. if generate_golden:
  111. # Save as the golden result
  112. _save_golden_dict(cur_dir, golden_ref_dir, result_dict)
  113. _compare_to_golden_dict(golden_ref_dir, result_dict)
  114. if SAVE_JSON:
  115. # Save result to a json file for inspection
  116. parameters = {"params": {}}
  117. _save_json(filename, parameters, result_dict)
  118. def save_and_check_md5(data, filename, generate_golden=False):
  119. """
  120. Save the dataset dictionary and compare (as dictionary) with golden file (md5).
  121. Use create_dict_iterator to access the dataset.
  122. """
  123. num_iter = 0
  124. result_dict = {}
  125. for item in data.create_dict_iterator(): # each data is a dictionary
  126. for data_key in list(item.keys()):
  127. if data_key not in result_dict:
  128. result_dict[data_key] = []
  129. # save the md5 as numpy array
  130. result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='<f4'))
  131. num_iter += 1
  132. logger.info("Number of data in ds1: {}".format(num_iter))
  133. cur_dir = os.path.dirname(os.path.realpath(__file__))
  134. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  135. if generate_golden:
  136. # Save as the golden result
  137. _save_golden_dict(cur_dir, golden_ref_dir, result_dict)
  138. _compare_to_golden_dict(golden_ref_dir, result_dict)
  139. def save_and_check_tuple(data, parameters, filename, generate_golden=False):
  140. """
  141. Save the dataset dictionary and compare (as numpy array) with golden file.
  142. Use create_tuple_iterator to access the dataset.
  143. """
  144. num_iter = 0
  145. result_dict = {}
  146. for item in data.create_tuple_iterator(): # each data is a dictionary
  147. for data_key, _ in enumerate(item):
  148. if data_key not in result_dict:
  149. result_dict[data_key] = []
  150. result_dict[data_key].append(item[data_key].tolist())
  151. num_iter += 1
  152. logger.info("Number of data in data1: {}".format(num_iter))
  153. cur_dir = os.path.dirname(os.path.realpath(__file__))
  154. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  155. if generate_golden:
  156. # Save as the golden result
  157. _save_golden(cur_dir, golden_ref_dir, result_dict)
  158. _compare_to_golden(golden_ref_dir, result_dict)
  159. if SAVE_JSON:
  160. # Save result to a json file for inspection
  161. _save_json(filename, parameters, result_dict)
  162. def diff_mse(in1, in2):
  163. mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
  164. return mse * 100
  165. def diff_me(in1, in2):
  166. mse = (np.abs(in1.astype(float) - in2.astype(float))).mean()
  167. return mse / 255 * 100
  168. def visualize_list(image_list_1, image_list_2, visualize_mode=1):
  169. """
  170. visualizes a list of images using DE op
  171. """
  172. plot_title = PLOT_TITLE_DICT[visualize_mode]
  173. num = len(image_list_1)
  174. for i in range(num):
  175. plt.subplot(2, num, i + 1)
  176. plt.imshow(image_list_1[i])
  177. plt.title(plot_title[0])
  178. plt.subplot(2, num, i + num + 1)
  179. plt.imshow(image_list_2[i])
  180. plt.title(plot_title[1])
  181. plt.show()
  182. def visualize_image(image_original, image_de, mse=None, image_lib=None):
  183. """
  184. visualizes one example image with optional input: mse, image using 3rd party op.
  185. If three images are passing in, different image is calculated by 2nd and 3rd images.
  186. """
  187. num = 2
  188. if image_lib is not None:
  189. num += 1
  190. if mse is not None:
  191. num += 1
  192. plt.subplot(1, num, 1)
  193. plt.imshow(image_original)
  194. plt.title("Original image")
  195. plt.subplot(1, num, 2)
  196. plt.imshow(image_de)
  197. plt.title("DE Op image")
  198. if image_lib is not None:
  199. plt.subplot(1, num, 3)
  200. plt.imshow(image_lib)
  201. plt.title("Lib Op image")
  202. if mse is not None:
  203. plt.subplot(1, num, 4)
  204. plt.imshow(image_de - image_lib)
  205. plt.title("Diff image,\n mse : {}".format(mse))
  206. elif mse is not None:
  207. plt.subplot(1, num, 3)
  208. plt.imshow(image_original - image_de)
  209. plt.title("Diff image,\n mse : {}".format(mse))
  210. plt.show()
  211. def config_get_set_seed(seed_new):
  212. """
  213. Get and return the original configuration seed value.
  214. Set the new configuration seed value.
  215. """
  216. seed_original = ds.config.get_seed()
  217. ds.config.set_seed(seed_new)
  218. logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
  219. return seed_original
  220. def config_get_set_num_parallel_workers(num_parallel_workers_new):
  221. """
  222. Get and return the original configuration num_parallel_workers value.
  223. Set the new configuration num_parallel_workers value.
  224. """
  225. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  226. ds.config.set_num_parallel_workers(num_parallel_workers_new)
  227. logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
  228. num_parallel_workers_new))
  229. return num_parallel_workers_original