You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

mock_ms_client.py 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Mocked MindSpore debugger client."""
  16. from threading import Thread
  17. import grpc
  18. import numpy as np
  19. from mindinsight.debugger.proto import ms_graph_pb2
  20. from mindinsight.debugger.proto.debug_grpc_pb2 import Metadata, WatchpointHit, Chunk, EventReply
  21. from mindinsight.debugger.proto.debug_grpc_pb2_grpc import EventListenerStub
  22. from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType
  23. from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE
  24. class MockDebuggerClient:
  25. """Mocked Debugger client."""
  26. def __init__(self, hostname='localhost:50051', backend='Ascend'):
  27. channel = grpc.insecure_channel(hostname)
  28. self.stub = EventListenerStub(channel)
  29. self.flag = True
  30. self._step = 0
  31. self._watchpoint_id = 0
  32. self._leaf_node = []
  33. self._cur_node = ''
  34. self._backend = backend
  35. def _clean(self):
  36. """Clean cache."""
  37. self._step = 0
  38. self._watchpoint_id = 0
  39. self._leaf_node = []
  40. self._cur_node = ''
  41. def get_thread_instance(self):
  42. """Get debugger client thread."""
  43. return MockDebuggerClientThread(self)
  44. def next_node(self, name=None):
  45. """Update the current node to next node."""
  46. if not self._cur_node:
  47. self._cur_node = self._leaf_node[0]
  48. return
  49. cur_index = self._leaf_node.index(self._cur_node)
  50. # if name is not None, go to the specified node.
  51. if not name:
  52. next_index = cur_index + 1
  53. else:
  54. next_index = self._leaf_node.index(name)
  55. # update step
  56. if next_index <= cur_index or next_index == len(self._leaf_node):
  57. self._step += 1
  58. # update current node
  59. if next_index == len(self._leaf_node):
  60. self._cur_node = self._leaf_node[0]
  61. else:
  62. self._cur_node = self._leaf_node[next_index]
  63. def command_loop(self):
  64. """Wait for the command."""
  65. total_steps = 100
  66. wait_flag = True
  67. while self.flag and wait_flag:
  68. if self._step > total_steps:
  69. self.send_metadata_cmd(training_done=True)
  70. return
  71. wait_flag = self._wait_cmd()
  72. def _wait_cmd(self):
  73. """Wait for command and deal with command."""
  74. metadata = self.get_metadata_cmd()
  75. response = self.stub.WaitCMD(metadata)
  76. assert response.status == EventReply.Status.OK
  77. if response.HasField('run_cmd'):
  78. self._deal_with_run_cmd(response)
  79. elif response.HasField('view_cmd'):
  80. for tensor in response.view_cmd.tensors:
  81. self.send_tensor_cmd(in_tensor=tensor)
  82. elif response.HasField('set_cmd'):
  83. self._watchpoint_id += 1
  84. elif response.HasField('exit'):
  85. self._watchpoint_id = 0
  86. self._step = 0
  87. return False
  88. return True
  89. def _deal_with_run_cmd(self, response):
  90. self._step += response.run_cmd.run_steps
  91. if response.run_cmd.run_level == 'node':
  92. self.next_node(response.run_cmd.node_name)
  93. if self._watchpoint_id > 0:
  94. self.send_watchpoint_hit()
  95. def get_metadata_cmd(self, training_done=False):
  96. """Construct metadata message."""
  97. metadata = Metadata()
  98. metadata.device_name = '0'
  99. metadata.cur_step = self._step
  100. metadata.cur_node = self._cur_node
  101. metadata.backend = self._backend
  102. metadata.training_done = training_done
  103. return metadata
  104. def send_metadata_cmd(self, training_done=False):
  105. """Send metadata command."""
  106. self._clean()
  107. metadata = self.get_metadata_cmd(training_done)
  108. response = self.stub.SendMetadata(metadata)
  109. assert response.status == EventReply.Status.OK
  110. if training_done is False:
  111. self.send_graph_cmd()
  112. def send_graph_cmd(self):
  113. """Send graph to debugger server."""
  114. self._step = 1
  115. with open(GRAPH_PROTO_FILE, 'rb') as file_handle:
  116. content = file_handle.read()
  117. size = len(content)
  118. graph = ms_graph_pb2.GraphProto()
  119. graph.ParseFromString(content)
  120. graph.name = 'graph_name'
  121. self._leaf_node = [node.full_name for node in graph.node]
  122. # the max limit of grpc data size is 4kb
  123. # split graph into 3kb per chunk
  124. chunk_size = 1024 * 1024 * 3
  125. chunks = []
  126. for index in range(0, size, chunk_size):
  127. sub_size = min(chunk_size, size - index)
  128. sub_chunk = Chunk(buffer=content[index: index + sub_size])
  129. chunks.append(sub_chunk)
  130. response = self.stub.SendGraph(self._generate_graph(chunks))
  131. assert response.status == EventReply.Status.OK
  132. # go to command loop
  133. self.command_loop()
  134. @staticmethod
  135. def _generate_graph(chunks):
  136. """Construct graph generator."""
  137. for buffer in chunks:
  138. yield buffer
  139. def send_tensor_cmd(self, in_tensor=None):
  140. """Send tensor info with value."""
  141. response = self.stub.SendTensors(self.generate_tensor(in_tensor))
  142. assert response.status == EventReply.Status.OK
  143. @staticmethod
  144. def generate_tensor(in_tensor=None):
  145. """Generate tensor message."""
  146. tensor_content = np.asarray([1, 2, 3, 4, 5, 6]).astype(np.float32).tobytes()
  147. tensors = [TensorProto(), TensorProto()]
  148. tensors[0].CopyFrom(in_tensor)
  149. tensors[0].data_type = DataType.DT_FLOAT32
  150. tensors[0].dims.extend([2, 3])
  151. tensors[1].CopyFrom(tensors[0])
  152. tensors[0].tensor_content = tensor_content[:12]
  153. tensors[1].tensor_content = tensor_content[12:]
  154. tensors[0].finished = 0
  155. tensors[1].finished = 1
  156. for sub_tensor in tensors:
  157. yield sub_tensor
  158. def send_watchpoint_hit(self):
  159. """Send watchpoint hit value."""
  160. tensors = [TensorProto(node_name='Default/TransData-op99', slot='0'),
  161. TensorProto(node_name='Default/optimizer-Momentum/ApplyMomentum-op25', slot='0')]
  162. response = self.stub.SendWatchpointHits(self._generate_hits(tensors))
  163. assert response.status == EventReply.Status.OK
  164. @staticmethod
  165. def _generate_hits(tensors):
  166. """Construct watchpoint hits."""
  167. for tensor in tensors:
  168. hit = WatchpointHit()
  169. hit.id = 1
  170. hit.tensor.CopyFrom(tensor)
  171. yield hit
  172. class MockDebuggerClientThread:
  173. """Mocked debugger client thread."""
  174. def __init__(self, debugger_client):
  175. self._debugger_client = debugger_client
  176. self._debugger_client_thread = Thread(target=debugger_client.send_metadata_cmd)
  177. def __enter__(self, backend='Ascend'):
  178. self._debugger_client.flag = True
  179. self._debugger_client_thread.start()
  180. return self._debugger_client_thread
  181. def __exit__(self, exc_type, exc_val, exc_tb):
  182. self._debugger_client_thread.join(timeout=5)
  183. self._debugger_client.flag = False