
util.py 8.6 kB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import hashlib
import json
import os

import jsbeautifier
import matplotlib.pyplot as plt
import numpy as np

import mindspore.dataset as ds
from mindspore import log as logger

# These are the column names defined in the testTFTestAllTypes dataset
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
           "col_sint16", "col_sint32", "col_sint64"]

SAVE_JSON = False


def _save_golden(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary values as the golden result in .npz file
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.values())))


def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary (both keys and values) as the golden result in .npz file
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.items())))


def _compare_to_golden(golden_ref_dir, result_dict):
    """
    Compare as numpy arrays the test result to the golden result
    """
    test_array = np.array(list(result_dict.values()))
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    assert np.array_equal(test_array, golden_array)


def _compare_to_golden_dict(golden_ref_dir, result_dict):
    """
    Compare as dictionaries the test result to the golden result
    """
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    np.testing.assert_equal(result_dict, dict(golden_array))
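
# A minimal sketch (not part of the original file) of the .npz round-trip that
# _save_golden_dict and _compare_to_golden_dict rely on; "ref.npz" is a
# hypothetical path used only for illustration:
#
#   d = {"a": [1, 2], "b": [3, 4]}
#   np.savez("ref.npz", np.array(list(d.items())))
#   loaded = np.load("ref.npz", allow_pickle=True)['arr_0']
#   np.testing.assert_equal(d, dict(loaded))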


def _save_json(filename, parameters, result_dict):
    """
    Save the result dictionary in json file
    """
    options = jsbeautifier.default_options()
    options.indent_size = 2
    out_dict = {**parameters, **{"columns": result_dict}}
    with open(filename[:-3] + "json", "w") as fout:
        fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))


def save_and_check(data, parameters, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as numpy array) with golden file.
    Use create_dict_iterator to access the dataset.
    Note: save_and_check() is deprecated; use save_and_check_dict().
    """
    num_iter = 0
    result_dict = {}
    for column_name in COLUMNS:
        result_dict[column_name] = []

    for item in data.create_dict_iterator():  # each item is a dictionary
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden(golden_ref_dir, result_dict)

    if SAVE_JSON:
        # Save result to a json file for inspection
        _save_json(filename, parameters, result_dict)


def save_and_check_dict(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as dictionary) with golden file.
    Use create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    for item in data.create_dict_iterator():  # each item is a dictionary
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1

    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict)

    if SAVE_JSON:
        # Save result to a json file for inspection
        parameters = {"params": {}}
        _save_json(filename, parameters, result_dict)
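
# Hedged usage sketch (not part of the original file): how a test might call
# save_and_check_dict. DATA_FILES, SCHEMA_FILE and the golden filename below
# are assumed placeholders, not names defined in this module.
#
#   data = ds.TFRecordDataset(DATA_FILES, SCHEMA_FILE, columns_list=COLUMNS)
#   save_and_check_dict(data, "my_op_result.npz", generate_golden=False)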


def save_and_check_md5(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as dictionary) with golden file (md5).
    Use create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    for item in data.create_dict_iterator():  # each item is a dictionary
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            # save the md5 as numpy array
            result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='<f4'))
        num_iter += 1

    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict)
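
# Why dtype='<f4' works above (explanatory sketch, not part of the original
# file): an MD5 digest is exactly 16 bytes, so reinterpreting it as
# little-endian float32 always yields a fixed-size array of four values,
# which stores compactly in the golden .npz file and compares cheaply.
#
#   digest = hashlib.md5(b"example").digest()   # 16 raw bytes
#   np.frombuffer(digest, dtype='<f4')          # -> array of 4 float32 values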


def save_and_check_tuple(data, parameters, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as numpy array) with golden file.
    Use create_tuple_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    for item in data.create_tuple_iterator():  # each item is a list of columns
        for data_key, _ in enumerate(item):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden(golden_ref_dir, result_dict)

    if SAVE_JSON:
        # Save result to a json file for inspection
        _save_json(filename, parameters, result_dict)


def diff_mse(in1, in2):
    """Mean squared error between two images, as a percentage."""
    mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
    return mse * 100


def diff_me(in1, in2):
    """Mean absolute error between two images, as a percentage."""
    mae = (np.abs(in1.astype(float) - in2.astype(float))).mean()
    return mae / 255 * 100
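
# Quick sanity check (illustrative only): two maximally different uint8
# images give the maximum diff_mse score of 100.0.
#
#   a = np.zeros((2, 2), dtype=np.uint8)
#   b = np.full((2, 2), 255, dtype=np.uint8)
#   diff_mse(a, b)   # -> 100.0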


def visualize(image_original, image_transformed):
    """
    Visualize the original images alongside their transformed counterparts.
    """
    num = len(image_transformed)
    for i in range(num):
        plt.subplot(2, num, i + 1)
        plt.imshow(image_original[i])
        plt.title("Original image")

        plt.subplot(2, num, i + num + 1)
        plt.imshow(image_transformed[i])
        plt.title("Transformed image")

    plt.show()
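
# Illustrative call (not part of the original file): both arguments are
# equal-length lists of HWC image arrays; orig_img and transformed_img are
# assumed placeholders.
#
#   visualize([orig_img], [transformed_img])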


def config_get_set_seed(seed_new):
    """
    Get and return the original configuration seed value.
    Set the new configuration seed value.
    """
    seed_original = ds.config.get_seed()
    ds.config.set_seed(seed_new)
    logger.info("seed: original = {} new = {}".format(seed_original, seed_new))
    return seed_original


def config_get_set_num_parallel_workers(num_parallel_workers_new):
    """
    Get and return the original configuration num_parallel_workers value.
    Set the new configuration num_parallel_workers value.
    """
    num_parallel_workers_original = ds.config.get_num_parallel_workers()
    ds.config.set_num_parallel_workers(num_parallel_workers_new)
    logger.info("num_parallel_workers: original = {} new = {}".format(num_parallel_workers_original,
                                                                      num_parallel_workers_new))
    return num_parallel_workers_original
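
# Typical get/set/restore pattern in a test (sketch with assumed values, not
# part of the original file): save the original settings, run the pipeline
# deterministically, then restore so later tests are unaffected.
#
#   original_seed = config_get_set_seed(55)
#   original_workers = config_get_set_num_parallel_workers(1)
#   ...  # build and iterate the pipeline under test
#   ds.config.set_seed(original_seed)
#   ds.config.set_num_parallel_workers(original_workers)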