You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.py 7.7 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import hashlib
  16. import json
  17. import matplotlib.pyplot as plt
  18. import numpy as np
  19. import os
  20. # import jsbeautifier
  21. from mindspore import log as logger
  # Column names of the testTFTestAllTypes dataset. save_and_check() pre-seeds
  # its result dictionary with one empty list per name, in this order.
  COLUMNS = [
      "col_1d", "col_2d", "col_3d",
      "col_binary", "col_float",
      "col_sint16", "col_sint32", "col_sint64",
  ]
  # When True, the save_and_check* helpers additionally dump their collected
  # results to a .json file for manual inspection.
  SAVE_JSON = False
  def _save_golden(cur_dir, golden_ref_dir, result_dict):
      """
      Save the dictionary values as the golden result in .npz file.

      Only the values (not the keys) of ``result_dict`` are stored, as a single
      array under the default ``arr_0`` entry.

      Args:
          cur_dir: directory of the calling test module (logged only).
          golden_ref_dir: path of the .npz file to write.
          result_dict: mapping whose values are the collected per-column lists.
      """
      logger.info("cur_dir is {}".format(cur_dir))
      logger.info("golden_ref_dir is {}".format(golden_ref_dir))
      values = list(result_dict.values())
      try:
          golden_array = np.array(values)
      except ValueError:
          # NumPy >= 1.24 refuses to build an array from ragged nested lists
          # (e.g. columns of different shapes) unless dtype=object is requested;
          # older NumPy silently produced this same object array.
          golden_array = np.array(values, dtype=object)
      np.savez(golden_ref_dir, golden_array)
  def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
      """
      Save the dictionary (both keys and values) as the golden result in .npz file.

      The (key, value) pairs are stored as an object array under the default
      ``arr_0`` entry, so ``dict(array)`` rebuilds the mapping on load.

      Args:
          cur_dir: directory of the calling test module (logged only).
          golden_ref_dir: path of the .npz file to write.
          result_dict: mapping of column key -> list of collected values.
      """
      logger.info("cur_dir is {}".format(cur_dir))
      logger.info("golden_ref_dir is {}".format(golden_ref_dir))
      # Each (key, values) pair mixes a scalar key with a list (callers in this
      # file always store lists), so the nested sequence is ragged. NumPy >= 1.24
      # raises ValueError for that unless dtype=object is given explicitly.
      np.savez(golden_ref_dir, np.array(list(result_dict.items()), dtype=object))
  def _compare_to_golden(golden_ref_dir, result_dict):
      """
      Compare as numpy arrays the test result to the golden result.

      The values of ``result_dict`` are turned into one array and checked with
      ``np.array_equal`` against the ``arr_0`` entry of the golden .npz file.

      Args:
          golden_ref_dir: path of the golden .npz file.
          result_dict: mapping whose values are the collected per-column lists.

      Raises:
          AssertionError: if the arrays differ in shape or content.
      """
      values = list(result_dict.values())
      try:
          test_array = np.array(values)
      except ValueError:
          # Ragged columns: NumPy >= 1.24 requires dtype=object explicitly
          # (older NumPy built this same object array implicitly).
          test_array = np.array(values, dtype=object)
      # Object arrays are pickled inside the archive, hence allow_pickle=True.
      with np.load(golden_ref_dir, allow_pickle=True) as golden_file:
          golden_array = golden_file['arr_0']
      assert np.array_equal(test_array, golden_array)
  def _compare_to_golden_dict(golden_ref_dir, result_dict):
      """
      Assert that ``result_dict`` equals the dictionary stored in a golden .npz file.

      The ``arr_0`` entry of the file holds (key, value) rows; it is converted back
      into a dict and compared with ``np.testing.assert_equal``, which raises
      AssertionError on any mismatch.

      Args:
          golden_ref_dir: path of the golden .npz file.
          result_dict: mapping of column key -> list of collected values.
      """
      archive = np.load(golden_ref_dir, allow_pickle=True)
      expected = dict(archive['arr_0'])
      np.testing.assert_equal(result_dict, expected)
  def _save_json(filename, parameters, result_dict):
      """
      Save the result dictionary in a json file for inspection.

      The output path is ``filename`` with its extension replaced by ``.json``
      (e.g. ``foo.npz`` -> ``foo.json``), resolved against the current working
      directory.

      Args:
          filename: golden file name, typically ending in ``.npz``.
          parameters: dict of top-level entries written alongside the results.
          result_dict: collected results; stored under the ``"columns"`` key.
      """
      out_dict = {**parameters, "columns": result_dict}
      json_path = os.path.splitext(filename)[0] + ".json"
      # Pretty-print with the standard json module: jsbeautifier is not imported
      # by this module, so the previous jsbeautifier call raised NameError.
      # The `with` block also guarantees the file is flushed and closed.
      with open(json_path, "w", encoding="utf-8") as fout:
          fout.write(json.dumps(out_dict, indent=2))
  def save_and_check(data, parameters, filename, generate_golden=False):
      """
      Collect every row of ``data`` and compare it, as a numpy array, with a golden file.

      Rows are read through ``data.create_dict_iterator()``. The result dict is
      pre-seeded with an empty list for every name in ``COLUMNS``; any other key
      seen in a row is added after them in first-seen order. Each value is stored
      via ``.tolist()``.

      Note: save_and_check() is deprecated; use save_and_check_dict().

      Args:
          data: dataset object exposing ``create_dict_iterator()``.
          parameters: dict merged into the JSON dump when ``SAVE_JSON`` is True.
          filename: golden file name under ``../../data/dataset/golden``
              (relative to this module's directory).
          generate_golden: when True, (re)write the golden file before comparing.
      """
      result_dict = {name: [] for name in COLUMNS}
      row_count = 0
      for row in data.create_dict_iterator():  # each row is a dictionary
          for key in list(row.keys()):
              result_dict.setdefault(key, []).append(row[key].tolist())
          row_count += 1
      logger.info("Number of data in data1: {}".format(row_count))

      here = os.path.dirname(os.path.realpath(__file__))
      golden_path = os.path.join(here, "../../data/dataset", 'golden', filename)
      if generate_golden:
          # Overwrite the golden reference with the current result
          _save_golden(here, golden_path, result_dict)
      _compare_to_golden(golden_path, result_dict)
      if SAVE_JSON:
          # Dump the results to a .json file for manual inspection
          _save_json(filename, parameters, result_dict)
  def save_and_check_dict(data, filename, generate_golden=False):
      """
      Collect every row of ``data`` and compare it, key by key, with a golden file.

      Rows are read through ``data.create_dict_iterator()``; keys appear in the
      result dict in first-seen order and each value is stored via ``.tolist()``.

      Args:
          data: dataset object exposing ``create_dict_iterator()``.
          filename: golden file name under ``../../data/dataset/golden``
              (relative to this module's directory).
          generate_golden: when True, (re)write the golden file before comparing.
      """
      result_dict = {}
      row_count = 0
      for row in data.create_dict_iterator():  # each row is a dictionary
          for key in list(row.keys()):
              result_dict.setdefault(key, []).append(row[key].tolist())
          row_count += 1
      logger.info("Number of data in ds1: {}".format(row_count))

      here = os.path.dirname(os.path.realpath(__file__))
      golden_path = os.path.join(here, "../../data/dataset", 'golden', filename)
      if generate_golden:
          # Overwrite the golden reference with the current result
          _save_golden_dict(here, golden_path, result_dict)
      _compare_to_golden_dict(golden_path, result_dict)
      if SAVE_JSON:
          # Dump the results to a .json file for manual inspection
          _save_json(filename, {"params": {}}, result_dict)
  def save_and_check_md5(data, filename, generate_golden=False):
      """
      Compare the MD5 digests of every value in ``data`` with a golden file.

      Instead of the raw values, each value's 16-byte MD5 digest is stored,
      reinterpreted as a little-endian float32 numpy array. The resulting
      key -> list-of-digests dict is then compared key by key with the golden file.

      Args:
          data: dataset object exposing ``create_dict_iterator()``.
          filename: golden file name under ``../../data/dataset/golden``
              (relative to this module's directory).
          generate_golden: when True, (re)write the golden file before comparing.
      """
      result_dict = {}
      row_count = 0
      for row in data.create_dict_iterator():  # each row is a dictionary
          for key in list(row.keys()):
              digest = hashlib.md5(row[key]).digest()
              result_dict.setdefault(key, []).append(np.frombuffer(digest, dtype='<f4'))
          row_count += 1
      logger.info("Number of data in ds1: {}".format(row_count))

      here = os.path.dirname(os.path.realpath(__file__))
      golden_path = os.path.join(here, "../../data/dataset", 'golden', filename)
      if generate_golden:
          # Overwrite the golden reference with the current digests
          _save_golden_dict(here, golden_path, result_dict)
      _compare_to_golden_dict(golden_path, result_dict)
  def save_and_check_tuple(data, parameters, filename, generate_golden=False):
      """
      Collect every row of ``data`` and compare it, as a numpy array, with a golden file.

      Rows are read through ``data.create_tuple_iterator()``. Each row is a
      sequence of columns, so the result dict is keyed by column position
      (0, 1, 2, ...). Each value is stored via ``.tolist()``.

      Args:
          data: dataset object exposing ``create_tuple_iterator()``.
          parameters: dict merged into the JSON dump when ``SAVE_JSON`` is True.
          filename: golden file name under ``../../data/dataset/golden``
              (relative to this module's directory).
          generate_golden: when True, (re)write the golden file before comparing.
      """
      result_dict = {}
      row_count = 0
      for row in data.create_tuple_iterator():  # each row is a sequence of columns
          for position, column in enumerate(row):
              result_dict.setdefault(position, []).append(column.tolist())
          row_count += 1
      logger.info("Number of data in data1: {}".format(row_count))

      here = os.path.dirname(os.path.realpath(__file__))
      golden_path = os.path.join(here, "../../data/dataset", 'golden', filename)
      if generate_golden:
          # Overwrite the golden reference with the current result
          _save_golden(here, golden_path, result_dict)
      _compare_to_golden(golden_path, result_dict)
      if SAVE_JSON:
          # Dump the results to a .json file for manual inspection
          _save_json(filename, parameters, result_dict)
  def diff_mse(in1, in2):
      """
      Return the mean squared error between two arrays, expressed as a percentage.

      Both inputs are cast to float and divided by 255 before differencing. For
      pixel values in 0..255 the result therefore lies in [0, 100].
      """
      scaled_a = in1.astype(float) / 255
      scaled_b = in2.astype(float) / 255
      return np.mean(np.square(scaled_a - scaled_b)) * 100
  def diff_me(in1, in2):
      """
      Return the mean absolute error between two arrays as a percentage of 255.

      For pixel values in 0..255 the result lies in [0, 100].
      """
      mean_abs_error = np.mean(np.abs(in1.astype(float) - in2.astype(float)))
      return mean_abs_error / 255 * 100
  def visualize(image_original, image_transformed):
      """
      Plot original images (top row) above their transformed counterparts (bottom row).

      Per the original description, used to view images produced by a DE op
      alongside a NumPy op.

      Args:
          image_original: indexable collection of images, read in step with
              ``image_transformed``.
          image_transformed: collection of images; its length sets the number
              of columns in the 2-row grid.
      """
      count = len(image_transformed)
      for col, transformed in enumerate(image_transformed):
          plt.subplot(2, count, col + 1)
          plt.imshow(image_original[col])
          plt.title("Original image")

          plt.subplot(2, count, count + col + 1)
          plt.imshow(transformed)
          plt.title("Transformed image")
      plt.show()