You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_watchpoints.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Watchpoints test script for offline debugger APIs.
  17. """
  18. import os
  19. import json
  20. import shutil
  21. import numpy as np
  22. import mindspore.offline_debug.dbg_services as d
  23. from dump_test_utils import build_dump_structure
  24. from tests.security_utils import security_off_wrap
  25. class TestOfflineWatchpoints:
  26. """Test watchpoint for offline debugger."""
  27. GENERATE_GOLDEN = False
  28. test_name = "watchpoints"
  29. watchpoint_hits_json = []
  30. temp_dir = ''
  31. @classmethod
  32. def setup_class(cls):
  33. """Init setup for offline watchpoints test"""
  34. name1 = "Conv2D.Conv2D-op369.0.0.1"
  35. tensor1 = np.array([[[-1.2808e-03, 7.7629e-03, 1.9241e-02],
  36. [-1.3931e-02, 8.9359e-04, -1.1520e-02],
  37. [-6.3248e-03, 1.8749e-03, 1.0132e-02]],
  38. [[-2.5520e-03, -6.0005e-03, -5.1918e-03],
  39. [-2.7866e-03, 2.5487e-04, 8.4782e-04],
  40. [-4.6310e-03, -8.9111e-03, -8.1778e-05]],
  41. [[1.3914e-03, 6.0844e-04, 1.0643e-03],
  42. [-2.0966e-02, -1.2865e-03, -1.8692e-03],
  43. [-1.6647e-02, 1.0233e-03, -4.1313e-03]]], np.float32)
  44. info1 = d.TensorInfo(node_name="Default/network-WithLossCell/_backbone-AlexNet/conv1-Conv2d/Conv2D-op369",
  45. slot=1, iteration=2, rank_id=0, root_graph_id=0, is_output=False)
  46. name2 = "Parameter.fc2.bias.0.0.2"
  47. tensor2 = np.array([-5.0167350e-06, 1.2509107e-05, -4.3148934e-06, 8.1415592e-06,
  48. 2.1177532e-07, 2.9952851e-06], np.float32)
  49. info2 = d.TensorInfo(node_name="Default/network-WithLossCell/_backbone-AlexNet/fc3-Dense/"
  50. "Parameter[6]_11/fc2.bias",
  51. slot=0, iteration=2, rank_id=0, root_graph_id=0, is_output=True)
  52. tensor3 = np.array([2.9060817e-07, -5.1009415e-06, -2.8662325e-06, 2.6036503e-06,
  53. -5.1546101e-07, 6.0798648e-06], np.float32)
  54. info3 = d.TensorInfo(node_name="Default/network-WithLossCell/_backbone-AlexNet/fc3-Dense/"
  55. "Parameter[6]_11/fc2.bias",
  56. slot=0, iteration=3, rank_id=0, root_graph_id=0, is_output=True)
  57. name3 = "CudnnUniformReal.CudnnUniformReal-op391.0.0.3"
  58. tensor4 = np.array([-32.0, -4096.0], np.float32)
  59. info4 = d.TensorInfo(node_name="Default/CudnnUniformReal-op391",
  60. slot=0, iteration=2, rank_id=0, root_graph_id=0, is_output=False)
  61. tensor_info = [info1, info2, info3, info4]
  62. tensor_name = [name1, name2, name2, name3]
  63. tensor_list = [tensor1, tensor2, tensor3, tensor4]
  64. cls.temp_dir = build_dump_structure(tensor_name, tensor_list, "Test", tensor_info)
  65. @classmethod
  66. def teardown_class(cls):
  67. shutil.rmtree(cls.temp_dir)
  68. @security_off_wrap
  69. def test_sync_add_remove_watchpoints_hit(self):
  70. # NOTES: watch_condition=6 is MIN_LT
  71. # watchpoint set and hit (watch_condition=6), then remove it
  72. debugger_backend = d.DbgServices(dump_file_path=self.temp_dir)
  73. _ = debugger_backend.initialize(net_name="Test", is_sync_mode=True)
  74. param = d.Parameter(name="param", disabled=False, value=0.0)
  75. _ = debugger_backend.add_watchpoint(watchpoint_id=1, watch_condition=6,
  76. check_node_list={"Default/network-WithLossCell/_backbone-AlexNet"
  77. "/conv1-Conv2d/Conv2D-op369":
  78. {"rank_id": [0], "root_graph_id": [0], "is_output": False
  79. }}, parameter_list=[param])
  80. # add second watchpoint to check the watchpoint hit in correct order
  81. param1 = d.Parameter(name="param", disabled=False, value=10.0)
  82. _ = debugger_backend.add_watchpoint(watchpoint_id=2, watch_condition=6,
  83. check_node_list={"Default/CudnnUniformReal-op391":
  84. {"rank_id": [0], "root_graph_id": [0], "is_output": False
  85. }}, parameter_list=[param1])
  86. watchpoint_hits_test = debugger_backend.check_watchpoints(iteration=2)
  87. assert len(watchpoint_hits_test) == 2
  88. if self.GENERATE_GOLDEN:
  89. self.print_watchpoint_hits(watchpoint_hits_test, 0, False)
  90. else:
  91. self.compare_expect_actual_result(watchpoint_hits_test, 0)
  92. _ = debugger_backend.remove_watchpoint(watchpoint_id=1)
  93. watchpoint_hits_test_1 = debugger_backend.check_watchpoints(iteration=2)
  94. assert len(watchpoint_hits_test_1) == 1
  95. @security_off_wrap
  96. def test_sync_add_remove_watchpoints_not_hit(self):
  97. # watchpoint set and not hit(watch_condition=6), then remove
  98. debugger_backend = d.DbgServices(dump_file_path=self.temp_dir)
  99. _ = debugger_backend.initialize(net_name="Test", is_sync_mode=True)
  100. param = d.Parameter(name="param", disabled=False, value=-1000.0)
  101. _ = debugger_backend.add_watchpoint(watchpoint_id=2, watch_condition=6,
  102. check_node_list={"Default/network-WithLossCell/_backbone-AlexNet"
  103. "/conv1-Conv2d/Conv2D-op369":
  104. {"rank_id": [0], "root_graph_id": [0], "is_output": False
  105. }}, parameter_list=[param])
  106. watchpoint_hits_test = debugger_backend.check_watchpoints(iteration=2)
  107. assert not watchpoint_hits_test
  108. _ = debugger_backend.remove_watchpoint(watchpoint_id=2)
  109. @security_off_wrap
  110. def test_sync_weight_change_watchpoints_hit(self):
  111. # NOTES: watch_condition=18 is CHANGE_TOO_LARGE
  112. # weight change watchpoint set and hit(watch_condition=18)
  113. debugger_backend = d.DbgServices(dump_file_path=self.temp_dir)
  114. _ = debugger_backend.initialize(net_name="Test", is_sync_mode=True)
  115. param_abs_mean_update_ratio_gt = d.Parameter(
  116. name="abs_mean_update_ratio_gt", disabled=False, value=0.0)
  117. param_epsilon = d.Parameter(name="epsilon", disabled=True, value=0.0)
  118. _ = debugger_backend.add_watchpoint(watchpoint_id=3, watch_condition=18,
  119. check_node_list={"Default/network-WithLossCell/_backbone-AlexNet/fc3-Dense/"
  120. "Parameter[6]_11/fc2.bias":
  121. {"rank_id": [0], "root_graph_id": [0], "is_output": True
  122. }}, parameter_list=[param_abs_mean_update_ratio_gt,
  123. param_epsilon])
  124. watchpoint_hits_test = debugger_backend.check_watchpoints(iteration=3)
  125. assert len(watchpoint_hits_test) == 1
  126. if self.GENERATE_GOLDEN:
  127. self.print_watchpoint_hits(watchpoint_hits_test, 2, True)
  128. else:
  129. self.compare_expect_actual_result(watchpoint_hits_test, 2)
  130. @security_off_wrap
  131. def test_async_add_remove_watchpoint_hit(self):
  132. # watchpoint set and hit(watch_condition=6) in async mode, then remove
  133. debugger_backend = d.DbgServices(dump_file_path=self.temp_dir)
  134. _ = debugger_backend.initialize(net_name="Test", is_sync_mode=False)
  135. param = d.Parameter(name="param", disabled=False, value=0.0)
  136. _ = debugger_backend.add_watchpoint(watchpoint_id=1, watch_condition=6,
  137. check_node_list={"Default/network-WithLossCell/_backbone-AlexNet"
  138. "/conv1-Conv2d/Conv2D-op369":
  139. {"rank_id": [0], "root_graph_id": [0], "is_output": False
  140. }}, parameter_list=[param])
  141. watchpoint_hits_test = debugger_backend.check_watchpoints(iteration=2)
  142. assert len(watchpoint_hits_test) == 1
  143. if not self.GENERATE_GOLDEN:
  144. self.compare_expect_actual_result(watchpoint_hits_test, 0)
  145. _ = debugger_backend.remove_watchpoint(watchpoint_id=1)
  146. watchpoint_hits_test_1 = debugger_backend.check_watchpoints(iteration=2)
  147. assert not watchpoint_hits_test_1
  148. @security_off_wrap
  149. def test_async_add_remove_watchpoints_not_hit(self):
  150. # watchpoint set and not hit(watch_condition=6) in async mode, then remove
  151. debugger_backend = d.DbgServices(dump_file_path=self.temp_dir)
  152. _ = debugger_backend.initialize(net_name="Test", is_sync_mode=False)
  153. param = d.Parameter(name="param", disabled=False, value=-1000.0)
  154. _ = debugger_backend.add_watchpoint(watchpoint_id=2, watch_condition=6,
  155. check_node_list={"Default/network-WithLossCell/_backbone-AlexNet"
  156. "/conv1-Conv2d/Conv2D-op369":
  157. {"rank_id": [0], "root_graph_id": [0], "is_output": False
  158. }}, parameter_list=[param])
  159. watchpoint_hits_test = debugger_backend.check_watchpoints(iteration=2)
  160. assert not watchpoint_hits_test
  161. _ = debugger_backend.remove_watchpoint(watchpoint_id=2)
  162. def compare_expect_actual_result(self, watchpoint_hits_list, test_index):
  163. """Compare actual result with golden file."""
  164. golden_file = os.path.realpath(os.path.join("../data/dump/gpu_dumps/golden/",
  165. self.test_name + "_expected.json"))
  166. with open(golden_file) as f:
  167. expected_list = json.load(f)
  168. for x, watchpoint_hits in enumerate(watchpoint_hits_list):
  169. test_id = "watchpoint_hit" + str(test_index+x+1)
  170. info = expected_list[x+test_index][test_id]
  171. assert watchpoint_hits.name == info['name']
  172. assert watchpoint_hits.slot == info['slot']
  173. assert watchpoint_hits.condition == info['condition']
  174. assert watchpoint_hits.watchpoint_id == info['watchpoint_id']
  175. assert watchpoint_hits.error_code == info['error_code']
  176. assert watchpoint_hits.rank_id == info['rank_id']
  177. assert watchpoint_hits.root_graph_id == info['root_graph_id']
  178. for p, _ in enumerate(watchpoint_hits.parameters):
  179. parameter = "parameter" + str(p)
  180. assert watchpoint_hits.parameters[p].name == info['paremeter'][p][parameter]['name']
  181. assert watchpoint_hits.parameters[p].disabled == info['paremeter'][p][parameter]['disabled']
  182. assert watchpoint_hits.parameters[p].value == info['paremeter'][p][parameter]['value']
  183. assert watchpoint_hits.parameters[p].hit == info['paremeter'][p][parameter]['hit']
  184. assert watchpoint_hits.parameters[p].actual_value == info['paremeter'][p][parameter]['actual_value']
  185. def print_watchpoint_hits(self, watchpoint_hits_list, test_index, is_print):
  186. """Print watchpoint hits."""
  187. for x, watchpoint_hits in enumerate(watchpoint_hits_list):
  188. parameter_json = []
  189. for p, _ in enumerate(watchpoint_hits.parameters):
  190. parameter = "parameter" + str(p)
  191. parameter_json.append({
  192. parameter: {
  193. 'name': watchpoint_hits.parameters[p].name,
  194. 'disabled': watchpoint_hits.parameters[p].disabled,
  195. 'value': watchpoint_hits.parameters[p].value,
  196. 'hit': watchpoint_hits.parameters[p].hit,
  197. 'actual_value': watchpoint_hits.parameters[p].actual_value
  198. }
  199. })
  200. watchpoint_hit = "watchpoint_hit" + str(test_index+x+1)
  201. self.watchpoint_hits_json.append({
  202. watchpoint_hit: {
  203. 'name': watchpoint_hits.name,
  204. 'slot': watchpoint_hits.slot,
  205. 'condition': watchpoint_hits.condition,
  206. 'watchpoint_id': watchpoint_hits.watchpoint_id,
  207. 'paremeter': parameter_json,
  208. 'error_code': watchpoint_hits.error_code,
  209. 'rank_id': watchpoint_hits.rank_id,
  210. 'root_graph_id': watchpoint_hits.root_graph_id
  211. }
  212. })
  213. if is_print:
  214. with open(self.test_name + "_expected.json", "w") as dump_f:
  215. json.dump(self.watchpoint_hits_json, dump_f, indent=4, separators=(',', ': '))