Browse Source

extract tensor history

tags/v1.0.0
yelihua 5 years ago
parent
commit
362aaa79b4
7 changed files with 95 additions and 116 deletions
  1. +0
    -2
      mindinsight/debugger/common/exceptions/error_code.py
  2. +14
    -7
      mindinsight/debugger/common/exceptions/exceptions.py
  3. +46
    -54
      mindinsight/debugger/debugger_grpc_server.py
  4. +16
    -31
      mindinsight/debugger/debugger_server.py
  5. +1
    -1
      mindinsight/debugger/stream_handler/event_handler.py
  6. +13
    -16
      mindinsight/debugger/stream_handler/tensor_handler.py
  7. +5
    -5
      mindinsight/debugger/stream_handler/watchpoint_handler.py

+ 0
- 2
mindinsight/debugger/common/exceptions/error_code.py View File

@@ -44,8 +44,6 @@ class DebuggerErrorMsg(Enum):
"""Debugger error messages."""
PARAM_TYPE_ERROR = "TypeError. {}"
PARAM_VALUE_ERROR = "ValueError. {}"
PARAM_MISSING_ERROR = "MissingError. {}"
UNEXPECTED_EXCEPTION_ERROR = "Unexpected exception. {}"

GRAPH_NOT_EXIST_ERROR = "The graph does not exist."



+ 14
- 7
mindinsight/debugger/common/exceptions/exceptions.py View File

@@ -23,7 +23,8 @@ class DebuggerParamTypeError(MindInsightException):
def __init__(self, msg):
super(DebuggerParamTypeError, self).__init__(
error=DebuggerErrors.PARAM_TYPE_ERROR,
message=DebuggerErrorMsg.PARAM_TYPE_ERROR.value.format(msg)
message=DebuggerErrorMsg.PARAM_TYPE_ERROR.value.format(msg),
http_code=400
)


@@ -33,7 +34,8 @@ class DebuggerParamValueError(MindInsightException):
def __init__(self, msg):
super(DebuggerParamValueError, self).__init__(
error=DebuggerErrors.PARAM_VALUE_ERROR,
message=DebuggerErrorMsg.PARAM_VALUE_ERROR.value.format(msg)
message=DebuggerErrorMsg.PARAM_VALUE_ERROR.value.format(msg),
http_code=400
)


@@ -43,7 +45,8 @@ class DebuggerCreateWatchPointError(MindInsightException):
def __init__(self, msg):
super(DebuggerCreateWatchPointError, self).__init__(
error=DebuggerErrors.CREATE_WATCHPOINT_ERROR,
message=DebuggerErrorMsg.CREATE_WATCHPOINT_ERROR.value.format(msg)
message=DebuggerErrorMsg.CREATE_WATCHPOINT_ERROR.value.format(msg),
http_code=400
)


@@ -53,7 +56,8 @@ class DebuggerUpdateWatchPointError(MindInsightException):
def __init__(self, msg):
super(DebuggerUpdateWatchPointError, self).__init__(
error=DebuggerErrors.UPDATE_WATCHPOINT_ERROR,
message=DebuggerErrorMsg.UPDATE_WATCHPOINT_ERROR.value.format(msg)
message=DebuggerErrorMsg.UPDATE_WATCHPOINT_ERROR.value.format(msg),
http_code=400
)


@@ -63,7 +67,8 @@ class DebuggerDeleteWatchPointError(MindInsightException):
def __init__(self, msg):
super(DebuggerDeleteWatchPointError, self).__init__(
error=DebuggerErrors.DELETE_WATCHPOINT_ERROR,
message=DebuggerErrorMsg.DELETE_WATCHPOINT_ERROR.value.format(msg)
message=DebuggerErrorMsg.DELETE_WATCHPOINT_ERROR.value.format(msg),
http_code=400
)


@@ -82,7 +87,8 @@ class DebuggerContinueError(MindInsightException):
def __init__(self, msg):
super(DebuggerContinueError, self).__init__(
error=DebuggerErrors.CONTINUE_ERROR,
message=DebuggerErrorMsg.CONTINUE_ERROR.value.format(msg)
message=DebuggerErrorMsg.CONTINUE_ERROR.value.format(msg),
http_code=400
)


@@ -91,7 +97,8 @@ class DebuggerPauseError(MindInsightException):
def __init__(self, msg):
super(DebuggerPauseError, self).__init__(
error=DebuggerErrors.PAUSE_ERROR,
message=DebuggerErrorMsg.PAUSE_ERROR.value.format(msg)
message=DebuggerErrorMsg.PAUSE_ERROR.value.format(msg),
http_code=400
)




+ 46
- 54
mindinsight/debugger/debugger_grpc_server.py View File

@@ -17,7 +17,7 @@ from functools import wraps

from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
create_view_event_from_tensor_history, Streams
Streams
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto

@@ -50,18 +50,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._cache_store = cache_store
self._pos = None
self._status = None
self._view_event = None
self._view_round = None
self._continue_steps = None
self._received_view_cmd = None
self.init()

def init(self):
"""Init debugger grpc server."""
self._pos = '0'
self._status = ServerStatus.PENDING
self._view_event = None
self._view_round = True
self._continue_steps = 0
self._received_view_cmd = {}
self._cache_store.clean()

@debugger_wrap
@@ -73,6 +71,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
log.warning("No graph received before WaitCMD.")
reply = get_ack_reply(1)
return reply
self._send_received_tensor_tag()
# send graph if has not been sent before
self._pre_process(request)
# deal with old command
@@ -80,13 +79,8 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
if reply:
log.info("Reply to WaitCMD with old command: %s", reply)
return reply
# send view cmd
if self._view_round and self._view_event:
self._view_round = False
reply = self._view_event
log.debug("Send ViewCMD.")
# continue multiple steps training
elif self._continue_steps != 0:
if self._continue_steps:
reply = get_ack_reply()
reply.run_cmd.run_steps = 1
reply.run_cmd.run_level = 'step'
@@ -104,6 +98,18 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
log.info("Reply to WaitCMD: %s", reply)
return reply

def _send_received_tensor_tag(self):
    """
    Notify the UI that a queried tensor value has been received.

    Puts a ``receive_tensor`` event, merged with the current metadata, onto
    the data queue once the pending view command's tensor has arrived, then
    clears the pending record. No-op when nothing is pending or the tensor
    value is still awaited.
    """
    node_name = self._received_view_cmd.get('node_name')
    # No pending view command, or the tensor value has not arrived yet.
    if not node_name or self._received_view_cmd.get('wait_for_tensor'):
        return
    metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
    ret = {'receive_tensor': {'node_name': node_name}}
    ret.update(metadata)
    self._cache_store.put_data(ret)
    # Clear so the tag is sent at most once per view command.
    self._received_view_cmd.clear()
    log.info("Send receive tensor flag for %s", node_name)

def _pre_process(self, request):
"""Send graph and metadata when WaitCMD first called."""
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
@@ -124,7 +130,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
def _update_metadata(self, metadata_stream, metadata_proto):
"""Update metadata."""
# reset view round and clean cache data
self._view_round = True
if metadata_stream.step < metadata_proto.cur_step:
self._cache_store.clean_data()
self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
@@ -169,19 +174,28 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
log.debug("Received event :%s", event)
if event is None:
return event
if isinstance(event, dict) and event.get('reset'):
self._set_view_event(event)
event = None
if isinstance(event, dict):
event = self._deal_with_view_cmd(event)
elif event.HasField('run_cmd'):
event = self._deal_with_run_cmd(event)
elif event.HasField('view_cmd'):
self._view_round = False
elif event.HasField('exit'):
self._cache_store.clean()
log.info("Clean cache for exit cmd.")

return event

def _deal_with_view_cmd(self, event):
    """
    Deal with a view command from the UI.

    Args:
        event (dict): The view command event. Expected keys:

            - view_cmd: The ViewCMD proto to forward to the training client.
            - node_name (str): The UI name of the queried node.

    Returns:
        Union[object, None], the view command to send, or None when the
        event is missing either required key.
    """
    view_cmd = event.get('view_cmd')
    node_name = event.get('node_name')
    log.debug("Receive view cmd %s for node: %s.", view_cmd, node_name)
    if not (view_cmd and node_name):
        # Defect fixed: log message was misspelled as "Invaid".
        log.warning("Invalid view command. Ignore it.")
        return None
    # Record the pending query so the server can report back to the UI
    # once the tensor value arrives.
    self._received_view_cmd['node_name'] = node_name
    self._received_view_cmd['wait_for_tensor'] = True
    return view_cmd

def _deal_with_run_cmd(self, event):
"""Deal with run cmd."""
run_cmd = event.run_cmd
@@ -200,19 +214,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

return event

def _set_view_event(self, event):
    """
    Create and cache the view event for a view command.

    Args:
        event (dict): Reset event. Expected keys:

            - node_name (str): Name of the queried node.
            - tensor_history (list): Tensor info used to build the view event.

    If either key is missing or empty, the cached view event is cleared.
    """
    # the first tensor in view cmd is always the output
    node_name = event.get('node_name')
    tensor_history = event.get('tensor_history')
    if not node_name or not tensor_history:
        # Incomplete event: drop any previously cached view command.
        self._view_event = None
        log.info("Reset view command to None.")
    else:
        # create view event and set
        self._view_event = create_view_event_from_tensor_history(tensor_history)
        log.info("Reset view command to %s.", node_name)

@debugger_wrap
def SendMetadata(self, request, context):
"""Send metadata into DebuggerCache."""
@@ -223,12 +224,15 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

client_ip = context.peer().split(':', 1)[-1]
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
metadata_stream.put(request)
metadata_stream.client_ip = client_ip
metadata = metadata_stream.get()
if request.training_done:
log.info("The training from %s has finished.", client_ip)
else:
metadata_stream.put(request)
metadata_stream.client_ip = client_ip
log.info("Put new metadata from %s into cache.", client_ip)
# put metadata into data queue
metadata = metadata_stream.get()
self._cache_store.put_data(metadata)
log.info("Put new metadata to DataQueue.")
reply = get_ack_reply()
log.info("Send the reply to %s.", client_ip)
return reply
@@ -253,6 +257,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
def SendTensors(self, request_iterator, context):
"""Send tensors into DebuggerCache."""
log.info("Received tensor.")
self._received_view_cmd['wait_for_tensor'] = False
tensor_construct = []
tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
@@ -265,41 +270,28 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
tensor_construct = []
tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
continue
# send back tensor finished flag when all waiting tensor has value.
tensor_history = tensor_stream.get_tensor_history(tensor_names)
self._add_node_name_for_tensor_history(tensor_history)
metadata = metadata_stream.get()
tensor_history.update(metadata)
self._cache_store.put_data({}) # reply to the listening request
self._cache_store.put_data(tensor_history)
log.info("Send updated tensor history to data queue.")
reply = get_ack_reply()
return reply

def _add_node_name_for_tensor_history(self, tensor_history):
    """
    Fill in the UI node name for each entry of a tensor history.

    Args:
        tensor_history (dict): Tensor history holding a 'tensor_history'
            list of tensor info dicts, each with a 'full_name' of the
            form '<full_node_name>:<slot>'.
    """
    graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
    for tensor_info in tensor_history.get('tensor_history'):
        if tensor_info:
            # Map the graph full name back to the UI node name, keeping the slot.
            # NOTE(review): assumes 'full_name' exists and contains ':' —
            # rsplit on a missing or colon-free value would raise. TODO confirm.
            full_name, slot = tensor_info.get('full_name', '').rsplit(':', 1)
            node_name = graph_stream.get_node_name_by_full_name(full_name)
            tensor_info['name'] = node_name + ':' + slot

@debugger_wrap
def SendWatchpointHits(self, request_iterator, context):
"""Send watchpoint hits info DebuggerCache."""
log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps)
self._continue_steps = 0
self._view_event = None
watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
for watchpoint_hit_proto in request_iterator:
ui_node_name = graph_stream.get_node_name_by_full_name(
watchpoint_hit_proto.tensor.node_name)
log.debug("Receive watch point hit: %s", watchpoint_hit_proto)
if not ui_node_name:
log.info("Not support to show %s on graph.", watchpoint_hit_proto.tensor.node_name)
continue
watchpoint_hit = {
'tensor_proto': watchpoint_hit_proto.tensor,
'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
'node_name': graph_stream.get_node_name_by_full_name(
watchpoint_hit_proto.tensor.node_name)
'node_name': ui_node_name
}
watchpoint_hit_stream.put(watchpoint_hit)
watchpoint_hits_info = watchpoint_hit_stream.get()


+ 16
- 31
mindinsight/debugger/debugger_server.py View File

@@ -168,7 +168,6 @@ class DebuggerServer:
"'watchpoint_hit', 'tensor'], but got %s.", mode_mapping)
raise DebuggerParamTypeError("Invalid mode.")
filter_condition = {} if filter_condition is None else filter_condition
self._watch_point_id = filter_condition.get('watch_point_id', 0)
reply = mode_mapping[mode](filter_condition)

return reply
@@ -179,9 +178,9 @@ class DebuggerServer:
log.error("No filter condition required for retrieve all request.")
raise DebuggerParamTypeError("filter_condition should be empty.")
result = {}
self._watch_point_id = 0
self.cache_store.clean_data()
log.info("Clean data queue cache when retrieve all request.")
self.cache_store.put_command({'reset': True})
for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]:
sub_res = self.cache_store.get_stream_handler(stream).get()
result.update(sub_res)
@@ -197,8 +196,6 @@ class DebuggerServer:

- name (str): The name of single node.

- watch_point_id (int): The id of watchpoint.

- single_node (bool): If False, return the sub-layer of single node. If True, return
the node list from root node to single node.

@@ -206,24 +203,12 @@ class DebuggerServer:
dict, the node info.
"""
log.info("Retrieve node %s.", filter_condition)
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
# validate parameters
node_name = filter_condition.get('name')
if not node_name:
node_type = NodeTypeEnum.NAME_SCOPE.value
else:
node_type = graph_stream.get_node_type(node_name)
filter_condition['node_type'] = node_type
if node_name:
# validate node name
self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
filter_condition['single_node'] = bool(filter_condition.get('single_node'))
# get graph for scope node
if is_scope_type(node_type):
reply = self._get_nodes_info(filter_condition)
# get tensor history for leaf node
else:
reply = self._get_tensor_history(node_name)
if filter_condition.get('single_node'):
graph = self._get_nodes_info(filter_condition)
reply.update(graph)
reply = self._get_nodes_info(filter_condition)
return reply

def _get_nodes_info(self, filter_condition):
@@ -238,8 +223,6 @@ class DebuggerServer:
- single_node (bool): If False, return the sub-layer of single node. If True, return
the node list from root node to single node.

- watch_point_id (int): The id of watchpoint.

Returns:
dict, reply with graph.
"""
@@ -288,13 +271,8 @@ class DebuggerServer:
# get basic tensor history
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
tensor_history = graph_stream.get_tensor_history(node_name)
# set the view event
self.cache_store.put_command(
{'reset': True,
'node_name': node_name,
'tensor_history': tensor_history.get('tensor_history')})
# add tensor value for tensor history
self._add_tensor_value_for_tensor_history(tensor_history)
self._add_tensor_value_for_tensor_history(tensor_history, node_name)
# add hit label for tensor history
watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
watchpoint_hit_stream.update_tensor_history(tensor_history)
@@ -303,12 +281,13 @@ class DebuggerServer:
tensor_history.update(metadata)
return tensor_history

def _add_tensor_value_for_tensor_history(self, tensor_history):
def _add_tensor_value_for_tensor_history(self, tensor_history, node_name):
"""
Add tensor value for_tensor_history and send ViewCMD if tensor value missed.

Args:
tensor_history (list[dict]): A list of tensor info, including name and type.
node_name (str): The UI node name.

Returns:
dict, the tensor info.
@@ -317,7 +296,7 @@ class DebuggerServer:
missed_tensors = tensor_stream.update_tensor_history(tensor_history)
if missed_tensors:
view_cmd = create_view_event_from_tensor_history(missed_tensors)
self.cache_store.put_command(view_cmd)
self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name})
log.debug("Send view cmd.")

def retrieve_tensor_value(self, name, detail, shape):
@@ -400,7 +379,10 @@ class DebuggerServer:
dict, watch point list or relative graph.
"""
watchpoint_id = filter_condition.get('watch_point_id')
if watchpoint_id is None:
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watchpoint_stream.validate_watchpoint_id(watchpoint_id)
self._watch_point_id = watchpoint_id if watchpoint_id else 0
if not watchpoint_id:
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
log.debug("Get condition of watchpoints.")
else:
@@ -472,6 +454,7 @@ class DebuggerServer:
watch_nodes = self._get_node_basic_infos(watch_nodes)
watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint(
watch_condition, watch_nodes, watch_point_id)
self._watch_point_id = 0
log.info("Create watchpoint %d", watch_point_id)
return {'id': watch_point_id}

@@ -507,6 +490,7 @@ class DebuggerServer:

self.cache_store.get_stream_handler(Streams.WATCHPOINT).update_watchpoint(
watch_point_id, watch_nodes, mode)
self._watch_point_id = watch_point_id
log.info("Update watchpoint with id: %d", watch_point_id)
return {}

@@ -544,6 +528,7 @@ class DebuggerServer:
)
self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(
watch_point_id)
self._watch_point_id = 0
log.info("Delete watchpoint with id: %d", watch_point_id)
return {}



+ 1
- 1
mindinsight/debugger/stream_handler/event_handler.py View File

@@ -60,7 +60,7 @@ class EventHandler(StreamHandlerBase):
self._event_cache = [None] * self.max_limit
value = {'metadata': {'pos': '0'}}
self.clean_pending_requests(value)
log.debug("Clean event cache.")
log.debug("Clean event cache. %d request is waiting.", len(self._pending_requests))

def put(self, value):
"""


+ 13
- 16
mindinsight/debugger/stream_handler/tensor_handler.py View File

@@ -51,7 +51,7 @@ class TensorHandler(StreamHandlerBase):
step -= 1
tensor = OpTensor(merged_tensor, step)
self._put_tensor_into_cache(tensor, step)
log.debug("Put tensor %s of step: %d, into cache", tensor.name, step)
log.info("Put tensor %s of step: %d, into cache", tensor.name, step)

@staticmethod
def _get_merged_tensor(tensor_protos):
@@ -164,16 +164,6 @@ class TensorHandler(StreamHandlerBase):

return None

def get_tensor_history(self, tensor_names):
    """
    Build the tensor history for the given tensor names.

    Args:
        tensor_names (list[str]): Tensor names to look up.

    Returns:
        dict, with key 'tensor_history' mapped to a list of basic tensor
        info entries, one per input name, in order.
    """
    # NOTE: only used by the grpc server; could be removed later.
    return {
        'tensor_history': [
            self._get_basic_info(name) for name in tensor_names
        ]
    }

def update_tensor_history(self, tensor_history):
"""
Add tensor basic info in tensor_history.
@@ -208,20 +198,22 @@ class TensorHandler(StreamHandlerBase):
"""Update has_prev_step field in tensor info."""
flag = None
if node_type == NodeTypeEnum.PARAMETER.value:
flag = self._has_prev_tensor_value(tensor_name)
flag = self._get_prev_tensor_value_status(tensor_name)
if flag and tensor_info:
tensor_info['has_prev_step'] = True
return flag

def _has_prev_tensor_value(self, tensor_name):
def _get_prev_tensor_value_status(self, tensor_name):
"""
Check if the tensor has valid value of previous step.
Get the status of tensor value of previous step.

Args:
tensor_name (str): Tensor name.

Returns:
bool, whether the tensor has valid tensor value.
Union[None, bool], the status of previous tensor value. If True, there is valid previous
tensor value. If False, the tensor value should be queried from client.
If None, ignore.
"""
flag = None
# check if the tensor has previous step value.
@@ -229,7 +221,12 @@ class TensorHandler(StreamHandlerBase):
if prev_step < 0:
return flag
tensor = self._get_tensor(tensor_name, step=prev_step)
flag = bool(tensor and tensor.value)
if not tensor:
# the tensor need to be queried from client
flag = False
elif tensor.value:
flag = True

return flag

def get_tensor_value_by_name(self, tensor_name, prev=False):


+ 5
- 5
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -98,7 +98,7 @@ class WatchpointHandler(StreamHandlerBase):
"""
if not (watch_point_id and graph):
return
self._validate_watchpoint_id(watch_point_id)
self.validate_watchpoint_id(watch_point_id)
log.debug("add watch flags")
watchpoint = self._watchpoints.get(watch_point_id)
self._set_watch_status_recursively(graph, watchpoint)
@@ -144,7 +144,7 @@ class WatchpointHandler(StreamHandlerBase):
if watch_nodes:
watchpoint.add_nodes(watch_nodes)
elif watch_point_id:
self._validate_watchpoint_id(watch_point_id)
self.validate_watchpoint_id(watch_point_id)
watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
self.put(watchpoint)

@@ -163,7 +163,7 @@ class WatchpointHandler(StreamHandlerBase):
Returns:
dict, empty response.
"""
self._validate_watchpoint_id(watch_point_id)
self.validate_watchpoint_id(watch_point_id)
watchpoint = self._watchpoints.get(watch_point_id)
if watched:
watchpoint.add_nodes(watch_nodes)
@@ -182,7 +182,7 @@ class WatchpointHandler(StreamHandlerBase):
Returns:
dict, empty response.
"""
self._validate_watchpoint_id(watch_point_id)
self.validate_watchpoint_id(watch_point_id)
self._watchpoints.pop(watch_point_id)
set_cmd = SetCMD()
set_cmd.id = watch_point_id
@@ -190,7 +190,7 @@ class WatchpointHandler(StreamHandlerBase):
self._deleted_watchpoints.append(set_cmd)
log.debug("Delete watchpoint %d in cache.", watch_point_id)

def _validate_watchpoint_id(self, watch_point_id):
def validate_watchpoint_id(self, watch_point_id):
"""Validate watchpoint id."""
if watch_point_id and watch_point_id not in self._watchpoints:
log.error("Invalid watchpoint id: %d.", watch_point_id)


Loading…
Cancel
Save