You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import json
  16. import os
  17. import hashlib
  18. import numpy as np
  19. import matplotlib.pyplot as plt
  20. #import jsbeautifier
  21. from mindspore import log as logger
  22. COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
  23. "col_sint16", "col_sint32", "col_sint64"]
  24. SAVE_JSON = False
  25. def save_golden(cur_dir, golden_ref_dir, result_dict):
  26. """
  27. Save the dictionary values as the golden result in .npz file
  28. """
  29. logger.info("cur_dir is {}".format(cur_dir))
  30. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  31. np.savez(golden_ref_dir, np.array(list(result_dict.values())))
  32. def save_golden_dict(cur_dir, golden_ref_dir, result_dict):
  33. """
  34. Save the dictionary (both keys and values) as the golden result in .npz file
  35. """
  36. logger.info("cur_dir is {}".format(cur_dir))
  37. logger.info("golden_ref_dir is {}".format(golden_ref_dir))
  38. np.savez(golden_ref_dir, np.array(list(result_dict.items())))
  39. def compare_to_golden(golden_ref_dir, result_dict):
  40. """
  41. Compare as numpy arrays the test result to the golden result
  42. """
  43. test_array = np.array(list(result_dict.values()))
  44. golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
  45. assert np.array_equal(test_array, golden_array)
  46. def compare_to_golden_dict(golden_ref_dir, result_dict):
  47. """
  48. Compare as dictionaries the test result to the golden result
  49. """
  50. golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
  51. np.testing.assert_equal(result_dict, dict(golden_array))
  52. # assert result_dict == dict(golden_array)
  53. def save_json(filename, parameters, result_dict):
  54. """
  55. Save the result dictionary in json file
  56. """
  57. fout = open(filename[:-3] + "json", "w")
  58. options = jsbeautifier.default_options()
  59. options.indent_size = 2
  60. out_dict = {**parameters, **{"columns": result_dict}}
  61. fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))
  62. def save_and_check(data, parameters, filename, generate_golden=False):
  63. """
  64. Save the dataset dictionary and compare (as numpy array) with golden file.
  65. Use create_dict_iterator to access the dataset.
  66. """
  67. num_iter = 0
  68. result_dict = {}
  69. for column_name in COLUMNS:
  70. result_dict[column_name] = []
  71. for item in data.create_dict_iterator(): # each data is a dictionary
  72. for data_key in list(item.keys()):
  73. if data_key not in result_dict:
  74. result_dict[data_key] = []
  75. result_dict[data_key].append(item[data_key].tolist())
  76. num_iter += 1
  77. logger.info("Number of data in data1: {}".format(num_iter))
  78. cur_dir = os.path.dirname(os.path.realpath(__file__))
  79. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  80. if generate_golden:
  81. # Save as the golden result
  82. save_golden(cur_dir, golden_ref_dir, result_dict)
  83. compare_to_golden(golden_ref_dir, result_dict)
  84. if SAVE_JSON:
  85. # Save result to a json file for inspection
  86. save_json(filename, parameters, result_dict)
  87. def save_and_check_dict(data, filename, generate_golden=False):
  88. """
  89. Save the dataset dictionary and compare (as dictionary) with golden file.
  90. Use create_dict_iterator to access the dataset.
  91. """
  92. num_iter = 0
  93. result_dict = {}
  94. for item in data.create_dict_iterator(): # each data is a dictionary
  95. for data_key in list(item.keys()):
  96. if data_key not in result_dict:
  97. result_dict[data_key] = []
  98. result_dict[data_key].append(item[data_key].tolist())
  99. num_iter += 1
  100. logger.info("Number of data in ds1: {}".format(num_iter))
  101. cur_dir = os.path.dirname(os.path.realpath(__file__))
  102. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  103. if generate_golden:
  104. # Save as the golden result
  105. save_golden_dict(cur_dir, golden_ref_dir, result_dict)
  106. compare_to_golden_dict(golden_ref_dir, result_dict)
  107. if SAVE_JSON:
  108. # Save result to a json file for inspection
  109. parameters = {"params": {}}
  110. save_json(filename, parameters, result_dict)
  111. def save_and_check_md5(data, filename, generate_golden=False):
  112. """
  113. Save the dataset dictionary and compare (as dictionary) with golden file (md5).
  114. Use create_dict_iterator to access the dataset.
  115. """
  116. num_iter = 0
  117. result_dict = {}
  118. for item in data.create_dict_iterator(): # each data is a dictionary
  119. for data_key in list(item.keys()):
  120. if data_key not in result_dict:
  121. result_dict[data_key] = []
  122. # save the md5 as numpy array
  123. result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='<f4'))
  124. num_iter += 1
  125. logger.info("Number of data in ds1: {}".format(num_iter))
  126. cur_dir = os.path.dirname(os.path.realpath(__file__))
  127. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  128. if generate_golden:
  129. # Save as the golden result
  130. save_golden_dict(cur_dir, golden_ref_dir, result_dict)
  131. compare_to_golden_dict(golden_ref_dir, result_dict)
  132. def ordered_save_and_check(data, parameters, filename, generate_golden=False):
  133. """
  134. Save the dataset dictionary and compare (as numpy array) with golden file.
  135. Use create_tuple_iterator to access the dataset.
  136. """
  137. num_iter = 0
  138. result_dict = {}
  139. for item in data.create_tuple_iterator(): # each data is a dictionary
  140. for data_key in range(0, len(item)):
  141. if data_key not in result_dict:
  142. result_dict[data_key] = []
  143. result_dict[data_key].append(item[data_key].tolist())
  144. num_iter += 1
  145. logger.info("Number of data in data1: {}".format(num_iter))
  146. cur_dir = os.path.dirname(os.path.realpath(__file__))
  147. golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
  148. if generate_golden:
  149. # Save as the golden result
  150. save_golden(cur_dir, golden_ref_dir, result_dict)
  151. compare_to_golden(golden_ref_dir, result_dict)
  152. if SAVE_JSON:
  153. # Save result to a json file for inspection
  154. save_json(filename, parameters, result_dict)
  155. def diff_mse(in1, in2):
  156. mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
  157. return mse * 100
  158. def diff_me(in1, in2):
  159. mse = (np.abs(in1.astype(float) - in2.astype(float))).mean()
  160. return mse / 255 * 100
  161. def visualize(image_original, image_transformed):
  162. """
  163. visualizes the image using DE op and Numpy op
  164. """
  165. num = len(image_transformed)
  166. for i in range(num):
  167. plt.subplot(2, num, i + 1)
  168. plt.imshow(image_original[i])
  169. plt.title("Original image")
  170. plt.subplot(2, num, i + num + 1)
  171. plt.imshow(image_transformed[i])
  172. plt.title("Transformed image")
  173. plt.show()