
debugger_grpc_server.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger grpc server."""
from functools import wraps

from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
    Streams
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto


def debugger_wrap(func):
    """Wrapper for catching exceptions."""

    @wraps(func)
    def record_log(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as err:
            log.exception(err)
            raise err

    return record_log


class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
    """The grpc server used to interact with the grpc client."""

    def __init__(self, cache_store):
        """
        Initialize.

        Args:
            cache_store (DebuggerCache): Debugger cache store.
        """
        cache_store.initialize()
        self._cache_store = cache_store
        self._pos = None
        self._status = None
        self._continue_steps = None
        self._received_view_cmd = None
        self.init()

    def init(self):
        """Init debugger grpc server."""
        self._pos = '0'
        self._status = ServerStatus.PENDING
        self._continue_steps = 0
        self._received_view_cmd = {}
        self._cache_store.clean()

    @debugger_wrap
    def WaitCMD(self, request, context):
        """Wait for a command in DebuggerCache."""
        # check if the graph has already been received
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply
        self._send_received_tensor_tag()
        # send the graph if it has not been sent before
        self._pre_process(request)
        # deal with old command
        reply = self._deal_with_old_command()
        if reply:
            log.info("Reply to WaitCMD with old command: %s", reply)
            return reply
        # continue training for multiple steps
        if self._continue_steps:
            reply = get_ack_reply()
            reply.run_cmd.run_steps = 1
            reply.run_cmd.run_level = 'step'
            self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            log.debug("Send RunCMD. Clean watchpoint hit.")
        # wait for command
        else:
            reply = self._wait_for_next_command()
        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.info("Reply to WaitCMD: %s", reply)
        return reply

    def _send_received_tensor_tag(self):
        """Send received_finish_tag."""
        node_name = self._received_view_cmd.get('node_name')
        if not node_name or self._received_view_cmd.get('wait_for_tensor'):
            return
        metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
        ret = {'receive_tensor': {'node_name': node_name}}
        ret.update(metadata)
        self._cache_store.put_data(ret)
        self._received_view_cmd.clear()
        log.info("Send receive tensor flag for %s", node_name)

    def _pre_process(self, request):
        """Send graph and metadata when WaitCMD is called for the first time."""
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        if self._status == ServerStatus.RECEIVE_GRAPH:
            self._status = ServerStatus.WAITING
            metadata_stream.state = 'waiting'
            metadata = metadata_stream.get()
            self._cache_store.clean_command()
            res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
            res.update(metadata)
            self._cache_store.put_data(res)
            log.info("Put graph into data queue.")
        if metadata_stream.step < request.cur_step or metadata_stream.full_name != request.cur_node:
            # clean tensor cache and DataQueue at the beginning of each step
            self._update_metadata(metadata_stream, request)

    def _update_metadata(self, metadata_stream, metadata_proto):
        """Update metadata."""
        # reset view round and clean cached data
        if metadata_stream.step < metadata_proto.cur_step:
            self._cache_store.clean_data()
            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
                metadata_proto.cur_step)
        # put new metadata into cache
        metadata_stream.put(metadata_proto)
        cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name(
            metadata_proto.cur_node) if metadata_proto.cur_node else ''
        metadata_stream.node_name = cur_node
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        log.info("Put new metadata into data queue.")

    def _deal_with_old_command(self):
        """Deal with old command."""
        event = None
        while self._cache_store.has_command(self._pos) and event is None:
            event = self._get_next_command()
            log.debug("Deal with old %s-th command:\n%s.", self._pos, event)
        return event

    def _wait_for_next_command(self):
        """
        Wait for next command.

        Returns:
            EventReply, the command event.
        """
        log.info("Start to wait for command.")
        self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
        self._cache_store.put_data({'metadata': {'state': 'waiting'}})
        event = None
        while event is None and self._status == ServerStatus.WAITING:
            log.debug("Wait for %s-th command", self._pos)
            event = self._get_next_command()
        return event

    def _get_next_command(self):
        """Get next command."""
        self._pos, event = self._cache_store.get_command(self._pos)
        log.debug("Received event: %s", event)
        if event is None:
            return event
        if isinstance(event, dict):
            event = self._deal_with_view_cmd(event)
        elif event.HasField('run_cmd'):
            event = self._deal_with_run_cmd(event)
        elif event.HasField('exit'):
            self._cache_store.clean()
            log.info("Clean cache for exit cmd.")
        return event

    def _deal_with_view_cmd(self, event):
        """Deal with view cmd."""
        view_cmd = event.get('view_cmd')
        node_name = event.get('node_name')
        log.debug("Receive view cmd %s for node: %s.", view_cmd, node_name)
        if not (view_cmd and node_name):
            log.warning("Invalid view command. Ignore it.")
            return None
        self._received_view_cmd['node_name'] = node_name
        self._received_view_cmd['wait_for_tensor'] = True
        return view_cmd

    def _deal_with_run_cmd(self, event):
        """Deal with run cmd."""
        run_cmd = event.run_cmd
        # receive step command
        if run_cmd.run_level == 'step':
            # receive pause cmd
            if run_cmd.run_steps == 0:
                log.debug("Pause training and wait for next command.")
                self._continue_steps = 0
                return None
            # receive step cmd
            self._continue_steps = run_cmd.run_steps - 1
            event.run_cmd.run_steps = 1
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
        log.debug("Receive RunCMD. Clean watchpoint hit cache.")
        return event

    @debugger_wrap
    def SendMetadata(self, request, context):
        """Send metadata into DebuggerCache."""
        log.info("Received Metadata.")
        if self._status != ServerStatus.PENDING:
            log.info("Re-initialize cache store when a new session comes.")
            self.init()
        client_ip = context.peer().split(':', 1)[-1]
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        if request.training_done:
            log.info("The training from %s has finished.", client_ip)
        else:
            metadata_stream.put(request)
            metadata_stream.client_ip = client_ip
            log.info("Put new metadata from %s into cache.", client_ip)
        # put metadata into data queue
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        reply = get_ack_reply()
        log.info("Send the reply to %s.", client_ip)
        return reply

    @debugger_wrap
    def SendGraph(self, request_iterator, context):
        """Send graph into DebuggerCache."""
        log.info("Received graph.")
        serial_graph = b""
        for chunk in request_iterator:
            serial_graph += chunk.buffer
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialize the graph. Receive %s nodes", len(graph.node))
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
        self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
        self._status = ServerStatus.RECEIVE_GRAPH
        reply = get_ack_reply()
        log.info("Send the reply for graph.")
        return reply

    @debugger_wrap
    def SendTensors(self, request_iterator, context):
        """Send tensors into DebuggerCache."""
        log.info("Received tensor.")
        tensor_construct = []
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        tensor_names = []
        step = metadata_stream.step
        for tensor in request_iterator:
            tensor_construct.append(tensor)
            if tensor.finished:
                if self._received_view_cmd.get('wait_for_tensor') and tensor.tensor_content:
                    self._received_view_cmd['wait_for_tensor'] = False
                tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
                tensor_construct = []
                tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
                continue
        reply = get_ack_reply()
        return reply

    @debugger_wrap
    def SendWatchpointHits(self, request_iterator, context):
        """Send watchpoint hits info to DebuggerCache."""
        log.info("Received WatchpointHits. Left steps %d changed to 0.", self._continue_steps)
        self._continue_steps = 0
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        for watchpoint_hit_proto in request_iterator:
            ui_node_name = graph_stream.get_node_name_by_full_name(
                watchpoint_hit_proto.tensor.node_name)
            log.debug("Received watchpoint hit: %s", watchpoint_hit_proto)
            if not ui_node_name:
                log.info("Cannot show %s on the graph.", watchpoint_hit_proto.tensor.node_name)
                continue
            watchpoint_hit = {
                'tensor_proto': watchpoint_hit_proto.tensor,
                'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
                'node_name': ui_node_name
            }
            watchpoint_hit_stream.put(watchpoint_hit)
        watchpoint_hits_info = watchpoint_hit_stream.get()
        self._cache_store.put_data(watchpoint_hits_info)
        log.info("Send the watchpoint hits to DataQueue.\nSend the reply.")
        reply = get_ack_reply()
        return reply
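
For context, the file above only defines the servicer; it still has to be registered with a running gRPC server. The sketch below is a hedged, minimal example of that wiring, not part of the original file: the `add_EventListenerServicer_to_server` helper name follows the standard grpc Python codegen convention for a service named `EventListener`, and the `serve` function, host, and port are illustrative assumptions.

    # Minimal sketch (assumptions noted above): expose DebuggerGrpcServer over gRPC.
    from concurrent import futures

    import grpc

    def serve(cache_store, host='127.0.0.1', port=50051):
        """Register the servicer and block until the server stops (illustrative only)."""
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        # helper generated by grpc codegen from the EventListener service definition (assumed)
        grpc_server_base.add_EventListenerServicer_to_server(
            DebuggerGrpcServer(cache_store), server)
        server.add_insecure_port('{}:{}'.format(host, port))
        server.start()
        server.wait_for_termination()

In such a setup, the training client would call SendMetadata and SendGraph once at startup, then alternate WaitCMD with SendTensors / SendWatchpointHits as the debugger issues run and view commands.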