You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_watchpoints.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Watchpoints test script for offline debugger APIs.
  17. """
  18. import os
  19. import json
  20. import time
  21. import tempfile
  22. import numpy as np
  23. import pytest
  24. import mindspore.offline_debug.dbg_services as d
  25. from tests.security_utils import security_off_wrap
  26. from dump_test_utils import build_dump_structure, write_watchpoint_to_json
  # When True, the run_* helpers record watchpoint hits through
  # print_watchpoint_hits (writing "<test_name>_expected.json") instead of
  # comparing them against the golden files.
  GENERATE_GOLDEN = False

  # Accumulator of {"watchpoint_hit<N>": <json>} entries filled by
  # print_watchpoint_hits and dumped when it is called with is_print=True.
  watchpoint_hits_json = []
  def run_watchpoints(is_sync):
      """Run the set / hit / remove watchpoint scenarios on a generated dump.

      Three tensors are written into a temporary dump structure created under the
      current working directory, then four checks are executed in order:

      1. a MIN_LT watchpoint (watch_condition=6, value 0.0) on the Conv2D node
         must produce exactly one hit at iteration 2;
      2. after removing it, no hit may be reported;
      3. a MIN_LT watchpoint with value -1000.0 must not be hit, then is removed;
      4. a CHANGE_TOO_LARGE watchpoint (watch_condition=18) on fc2.bias must
         produce exactly one hit at iteration 3.

      Hits from checks 1 and 4 are either compared with the golden file or, when
      GENERATE_GOLDEN is set, recorded through print_watchpoint_hits.

      Args:
          is_sync: Forwarded to ``DbgServices.initialize`` as ``is_sync_mode``;
              also selects the test name ("sync_watchpoints" or
              "async_watchpoints") used for the golden file.
      """
      test_name = "sync_watchpoints" if is_sync else "async_watchpoints"

      conv_node = ("Default/network-WithLossCell/_backbone-AlexNet/"
                   "conv1-Conv2d/Conv2D-op369")
      bias_node = ("Default/network-WithLossCell/_backbone-AlexNet/fc3-Dense/"
                   "Parameter[6]_11/fc2.bias")
      conv_file_name = "Conv2D.Conv2D-op369.0.0.1"
      bias_file_name = "Parameter.fc2.bias.0.0.2"

      def node_filter(node, is_output):
          """Build the check_node_list mapping for one node on rank 0 / graph 0."""
          return {node: {"rank_id": [0], "root_graph_id": [0], "is_output": is_output}}

      def verify(hits, index, flush):
          """Record the hits (golden generation) or compare them with the golden file."""
          if GENERATE_GOLDEN:
              print_watchpoint_hits(hits, index, flush, test_name)
          else:
              compare_expect_actual_result(hits, index, test_name)

      conv_input = np.array(
          [[[-1.2808e-03, 7.7629e-03, 1.9241e-02],
            [-1.3931e-02, 8.9359e-04, -1.1520e-02],
            [-6.3248e-03, 1.8749e-03, 1.0132e-02]],
           [[-2.5520e-03, -6.0005e-03, -5.1918e-03],
            [-2.7866e-03, 2.5487e-04, 8.4782e-04],
            [-4.6310e-03, -8.9111e-03, -8.1778e-05]],
           [[1.3914e-03, 6.0844e-04, 1.0643e-03],
            [-2.0966e-02, -1.2865e-03, -1.8692e-03],
            [-1.6647e-02, 1.0233e-03, -4.1313e-03]]],
          dtype=np.float32)
      bias_iter2 = np.array(
          [-5.0167350e-06, 1.2509107e-05, -4.3148934e-06, 8.1415592e-06,
           2.1177532e-07, 2.9952851e-06],
          dtype=np.float32)
      bias_iter3 = np.array(
          [2.9060817e-07, -5.1009415e-06, -2.8662325e-06, 2.6036503e-06,
           -5.1546101e-07, 6.0798648e-06],
          dtype=np.float32)

      tensor_info = [
          d.TensorInfo(node_name=conv_node,
                       slot=1, iteration=2, rank_id=0, root_graph_id=0, is_output=False),
          d.TensorInfo(node_name=bias_node,
                       slot=0, iteration=2, rank_id=0, root_graph_id=0, is_output=True),
          d.TensorInfo(node_name=bias_node,
                       slot=0, iteration=3, rank_id=0, root_graph_id=0, is_output=True),
      ]
      tensor_name = [conv_file_name, bias_file_name, bias_file_name]
      tensor_list = [conv_input, bias_iter2, bias_iter3]

      # The dump lives in a temporary directory below the current working
      # directory; every debugger call must happen while it still exists.
      with tempfile.TemporaryDirectory(dir=os.getcwd()) as tmp_dir:
          dump_dir = build_dump_structure(tmp_dir, tensor_name, tensor_list, "Test", tensor_info)

          backend = d.DbgServices(dump_file_path=dump_dir)
          backend.initialize(net_name="Test", is_sync_mode=is_sync)

          # NOTES:
          # -> watch_condition=6 is MIN_LT
          # -> watch_condition=18 is CHANGE_TOO_LARGE

          # test 1: watchpoint set and hit (watch_condition=6)
          min_lt_zero = d.Parameter(name="param", disabled=False, value=0.0)
          backend.add_watchpoint(watchpoint_id=1, watch_condition=6,
                                 check_node_list=node_filter(conv_node, False),
                                 parameter_list=[min_lt_zero])
          hits_after_set = backend.check_watchpoints(iteration=2)
          assert len(hits_after_set) == 1
          verify(hits_after_set, 0, False)

          # test 2: watchpoint remove and ensure it's not hit
          backend.remove_watchpoint(watchpoint_id=1)
          hits_after_remove = backend.check_watchpoints(iteration=2)
          assert not hits_after_remove

          # test 3: watchpoint set and not hit, then remove
          min_lt_far_below = d.Parameter(name="param", disabled=False, value=-1000.0)
          backend.add_watchpoint(watchpoint_id=2, watch_condition=6,
                                 check_node_list=node_filter(conv_node, False),
                                 parameter_list=[min_lt_far_below])
          hits_not_expected = backend.check_watchpoints(iteration=2)
          assert not hits_not_expected
          backend.remove_watchpoint(watchpoint_id=2)

          # test 4: weight change watchpoint set and hit
          update_ratio = d.Parameter(name="abs_mean_update_ratio_gt", disabled=False, value=0.0)
          epsilon = d.Parameter(name="epsilon", disabled=True, value=0.0)
          backend.add_watchpoint(watchpoint_id=3, watch_condition=18,
                                 check_node_list=node_filter(bias_node, True),
                                 parameter_list=[update_ratio, epsilon])
          hits_weight_change = backend.check_watchpoints(iteration=3)
          assert len(hits_weight_change) == 1
          verify(hits_weight_change, 1, True)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_onecard
  @security_off_wrap
  def test_sync_watchpoints():
      """
      Feature: Offline Debugger CheckWatchpoint
      Description: Run the watchpoint set/hit/remove scenarios with is_sync=True
      Expectation: Every assertion inside run_watchpoints holds
      """
      run_watchpoints(is_sync=True)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_onecard
  @security_off_wrap
  def test_async_watchpoints():
      """
      Feature: Offline Debugger CheckWatchpoint
      Description: Run the watchpoint set/hit/remove scenarios with is_sync=False
      Expectation: Every assertion inside run_watchpoints holds
      """
      run_watchpoints(is_sync=False)
  def run_overflow_watchpoint(is_overflow):
      """Build a minimal async dump by hand and check an overflow watchpoint on it.

      Two files are written under ``<tmp>/rank_0/Add/0/0``: an ``Add`` tensor dump
      file and an ``Opdebug`` file. The Opdebug payload is 256 bytes placed at file
      offset 321, with the stream id at payload index 16 and a task id at payload
      index 24. A watchpoint with watch_condition=2 is then added on
      "Default/Add-op0" and checked at iteration 0.

      Args:
          is_overflow: When true, the Opdebug record carries the same task id as
              the Add file and exactly one hit is expected (and compared with, or
              recorded as, the golden result). Otherwise a different task id is
              written and no hit may be reported.
      """
      test_name = "overflow_watchpoint"
      tensor = np.array([65504, 65504], np.float16)
      task_id = 2
      stream_id = 7

      with tempfile.TemporaryDirectory(dir=os.getcwd()) as tmp_dir:
          path = os.path.join(tmp_dir, "rank_0", "Add", "0", "0")
          os.makedirs(path, exist_ok=True)

          # File names end with a microsecond timestamp taken at creation time.
          add_stamp = int(round(time.time() * 1000000))
          add_name = "Add.Default_Add-op0." + str(task_id) + "." + str(stream_id) + "." + str(add_stamp)
          add_file = os.path.join(path, add_name)
          with open(add_file, 'wb') as add_f:
              add_f.write(b'1')
              add_f.seek(8)
              # NOTE(review): the two opaque byte strings look like a serialized
              # dump header preceding the node name and raw tensor data - confirm.
              for chunk in (
                      b'\n\x032.0\x10\x83\xf7\xef\x9f\x99\xc8\xf3\x02\x1a\x10\x08\x02\x10\x02\x1a\x03',
                      b'\n\x01\x020\x04:\x03\n\x01\x022\x0f',
                      b'Default/Add-op0',
                      tensor,
              ):
                  add_f.write(chunk)

          overflow_stamp = int(round(time.time() * 1000000))
          overflow_name = ("Opdebug.Node_OpDebug." + str(task_id) + "." + str(stream_id) +
                           "." + str(overflow_stamp))
          overflow_file = os.path.join(path, overflow_name)
          with open(overflow_file, 'wb') as f:
              f.seek(321, 0)
              # 256 zero bytes, except the stream id and task id slots.
              record = bytearray(256)
              record[16] = stream_id
              # wrong task_id (task_id + 1) should not generate overflow watchpoint hit
              record[24] = task_id if is_overflow else task_id + 1
              f.write(bytes(record))

          backend = d.DbgServices(dump_file_path=tmp_dir)
          backend.initialize(net_name="Add", is_sync_mode=False)
          backend.add_watchpoint(watchpoint_id=1, watch_condition=2,
                                 check_node_list={"Default/Add-op0":
                                                  {"rank_id": [0], "root_graph_id": [0], "is_output": True
                                                   }}, parameter_list=[])
          hits = backend.check_watchpoints(iteration=0)

          if not is_overflow:
              assert not hits
              return

          assert len(hits) == 1
          if GENERATE_GOLDEN:
              print_watchpoint_hits(hits, 0, True, test_name)
          else:
              compare_expect_actual_result(hits, 0, test_name)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_onecard
  @security_off_wrap
  def test_async_overflow_watchpoints_hit():
      """
      Feature: Offline Debugger CheckWatchpoint
      Description: Check an overflow watchpoint on a dump whose Opdebug record matches the Add task
      Expectation: The overflow watchpoint is hit exactly once
      """
      run_overflow_watchpoint(is_overflow=True)
  def compare_expect_actual_result(watchpoint_hits_list, test_index, test_name):
      """Assert that each watchpoint hit equals its entry in the golden JSON file.

      The golden file is ``golden/<test_name>_expected.json`` relative to the
      current working directory. Hit number ``i`` (0-based) is looked up at list
      position ``test_index + i`` under the key ``"watchpoint_hit<test_index + i + 1>"``.

      Args:
          watchpoint_hits_list: Hits to verify, in order.
          test_index: Offset of the first hit inside the golden list.
          test_name: Prefix of the golden file name.

      Raises:
          AssertionError: If a converted hit differs from the golden entry.
      """
      golden_path = os.path.realpath(
          os.path.join(os.getcwd(), "golden", test_name + "_expected.json"))
      with open(golden_path) as golden_fp:
          golden_entries = json.load(golden_fp)

      for position, hit in enumerate(watchpoint_hits_list, start=test_index):
          key = "watchpoint_hit" + str(position + 1)
          expected = golden_entries[position][key]
          actual = write_watchpoint_to_json(hit)
          assert actual == expected
  def print_watchpoint_hits(watchpoint_hits_list, test_index, is_print, test_name):
      """Collect watchpoint hits and optionally write them out as a golden file.

      Each hit is converted with ``write_watchpoint_to_json`` and appended to the
      module-level ``watchpoint_hits_json`` list as
      ``{"watchpoint_hit<test_index + i + 1>": <converted hit>}``.

      Args:
          watchpoint_hits_list: Hits to record, in order.
          test_index: Offset used to number the hits.
          is_print: When true, everything accumulated so far is dumped to
              ``<test_name>_expected.json`` in the current working directory and
              the accumulator is then emptied.
          test_name: Prefix of the output file name.
      """
      for offset, hit in enumerate(watchpoint_hits_list):
          key = "watchpoint_hit" + str(test_index + offset + 1)
          watchpoint_hits_json.append({key: write_watchpoint_to_json(hit)})

      if is_print:
          with open(test_name + "_expected.json", "w") as dump_f:
              json.dump(watchpoint_hits_json, dump_f, indent=4, separators=(',', ': '))
          # Reset the shared accumulator once a golden file has been written.
          # Without this, hits recorded by one test leak into the golden file of
          # every later test run in the same session, shifting the indices that
          # compare_expect_actual_result reads back.
          watchpoint_hits_json.clear()