You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import json
  16. import os
  17. import numpy as np
  18. import matplotlib.pyplot as plt
  19. import hashlib
  20. #import jsbeautifier
  21. from mindspore import log as logger
  22. COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
  23. "col_sint16", "col_sint32", "col_sint64"]
  24. def save_golden(cur_dir, golden_ref_dir, result_dict):
  25. """
  26. Save the dictionary values as the golden result in .npz file
  27. """
  28. logger.info("cur_dir is {}".format(cur_dir))
  29. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  30. np.savez(golden_ref_dir, np.array(list(result_dict.values())))
  31. def save_golden_dict(cur_dir, golden_ref_dir, result_dict):
  32. """
  33. Save the dictionary (both keys and values) as the golden result in .npz file
  34. """
  35. logger.info("cur_dir is {}".format(cur_dir))
  36. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  37. np.savez(golden_ref_dir, np.array(list(result_dict.items())))
  38. def save_golden_md5(cur_dir, golden_ref_dir, result_dict):
  39. """
  40. Save the dictionary (both keys and values) as the golden result in .npz file
  41. """
  42. logger.info("cur_dir is {}".format(cur_dir))
  43. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  44. np.savez(golden_ref_dir, np.array(list(result_dict.items())))
  45. def compare_to_golden(golden_ref_dir, result_dict):
  46. """
  47. Compare as numpy arrays the test result to the golden result
  48. """
  49. test_array = np.array(list(result_dict.values()))
  50. golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
  51. assert np.array_equal(test_array, golden_array)
  52. def compare_to_golden_dict(golden_ref_dir, result_dict):
  53. """
  54. Compare as dictionaries the test result to the golden result
  55. """
  56. golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
  57. np.testing.assert_equal (result_dict, dict(golden_array))
  58. # assert result_dict == dict(golden_array)
  59. def save_json(filename, parameters, result_dict):
  60. """
  61. Save the result dictionary in json file
  62. """
  63. fout = open(filename[:-3] + "json", "w")
  64. options = jsbeautifier.default_options()
  65. options.indent_size = 2
  66. out_dict = {**parameters, **{"columns": result_dict}}
  67. fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))
  68. def save_and_check(data, parameters, filename, generate_golden=False):
  69. """
  70. Save the dataset dictionary and compare (as numpy array) with golden file.
  71. Use create_dict_iterator to access the dataset.
  72. """
  73. num_iter = 0
  74. result_dict = {}
  75. for column_name in COLUMNS:
  76. result_dict[column_name] = []
  77. for item in data.create_dict_iterator(): # each data is a dictionary
  78. for data_key in list(item.keys()):
  79. if data_key not in result_dict:
  80. result_dict[data_key] = []
  81. result_dict[data_key].append(item[data_key].tolist())
  82. num_iter += 1
  83. logger.info("Number of data in data1: {}".format(num_iter))
  84. cur_dir = os.path.dirname(os.path.realpath(__file__))
  85. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  86. if generate_golden:
  87. # Save as the golden result
  88. save_golden(cur_dir, golden_ref_dir, result_dict)
  89. compare_to_golden(golden_ref_dir, result_dict)
  90. # Save to a json file for inspection
  91. # save_json(filename, parameters, result_dict)
  92. def save_and_check_dict(data, parameters, filename, generate_golden=False):
  93. """
  94. Save the dataset dictionary and compare (as dictionary) with golden file.
  95. Use create_dict_iterator to access the dataset.
  96. """
  97. num_iter = 0
  98. result_dict = {}
  99. for item in data.create_dict_iterator(): # each data is a dictionary
  100. for data_key in list(item.keys()):
  101. if data_key not in result_dict:
  102. result_dict[data_key] = []
  103. result_dict[data_key].append(item[data_key].tolist())
  104. num_iter += 1
  105. logger.info("Number of data in ds1: {}".format(num_iter))
  106. cur_dir = os.path.dirname(os.path.realpath(__file__))
  107. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  108. if generate_golden:
  109. # Save as the golden result
  110. save_golden_dict(cur_dir, golden_ref_dir, result_dict)
  111. compare_to_golden_dict(golden_ref_dir, result_dict)
  112. # Save to a json file for inspection
  113. # save_json(filename, parameters, result_dict)
  114. def save_and_check_md5(data, parameters, filename, generate_golden=False):
  115. """
  116. Save the dataset dictionary and compare (as dictionary) with golden file (md5).
  117. Use create_dict_iterator to access the dataset.
  118. """
  119. num_iter = 0
  120. result_dict = {}
  121. for item in data.create_dict_iterator(): # each data is a dictionary
  122. for data_key in list(item.keys()):
  123. if data_key not in result_dict:
  124. result_dict[data_key] = []
  125. # save the md5 as numpy array
  126. result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='<f4'))
  127. num_iter += 1
  128. logger.info("Number of data in ds1: {}".format(num_iter))
  129. cur_dir = os.path.dirname(os.path.realpath(__file__))
  130. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  131. if generate_golden:
  132. # Save as the golden result
  133. save_golden_dict(cur_dir, golden_ref_dir, result_dict)
  134. compare_to_golden_dict(golden_ref_dir, result_dict)
  135. def ordered_save_and_check(data, parameters, filename, generate_golden=False):
  136. """
  137. Save the dataset dictionary and compare (as numpy array) with golden file.
  138. Use create_tuple_iterator to access the dataset.
  139. """
  140. num_iter = 0
  141. result_dict = {}
  142. for item in data.create_tuple_iterator(): # each data is a dictionary
  143. for data_key in range(0, len(item)):
  144. if data_key not in result_dict:
  145. result_dict[data_key] = []
  146. result_dict[data_key].append(item[data_key].tolist())
  147. num_iter += 1
  148. logger.info("Number of data in data1: {}".format(num_iter))
  149. cur_dir = os.path.dirname(os.path.realpath(__file__))
  150. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  151. if generate_golden:
  152. # Save as the golden result
  153. save_golden(cur_dir, golden_ref_dir, result_dict)
  154. compare_to_golden(golden_ref_dir, result_dict)
  155. # Save to a json file for inspection
  156. # save_json(filename, parameters, result_dict)
  157. def diff_mse(in1, in2):
  158. mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
  159. return mse * 100
  160. def diff_me(in1, in2):
  161. mse = (np.abs(in1.astype(float) - in2.astype(float))).mean()
  162. return mse / 255 * 100
  163. def diff_ssim(in1, in2):
  164. from skimage.measure import compare_ssim as ssim
  165. val = ssim(in1, in2, multichannel=True)
  166. return (1 - val) * 100
  167. def visualize(image_original, image_transformed):
  168. """
  169. visualizes the image using DE op and Numpy op
  170. """
  171. num = len(image_cropped)
  172. for i in range(num):
  173. plt.subplot(2, num, i + 1)
  174. plt.imshow(image_original[i])
  175. plt.title("Original image")
  176. plt.subplot(2, num, i + num + 1)
  177. plt.imshow(image_cropped[i])
  178. plt.title("Transformed image")
  179. plt.show()