|
|
|
@@ -20,6 +20,7 @@ import grpc |
|
|
|
|
|
|
|
from mindinsight.conditionmgr.conditionmgr import ConditionMgr |
|
|
|
from mindinsight.conditionmgr.condition import ConditionContext, ConditionIdEnum |
|
|
|
from mindinsight.conditionmgr.recommender import recommend_watchpoints |
|
|
|
from mindinsight.conf import settings |
|
|
|
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum |
|
|
|
from mindinsight.datavisual.utils.tools import to_float |
|
|
|
@@ -67,6 +68,33 @@ class DebuggerServer: |
|
|
|
log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) |
|
|
|
return self.condition_mgr.get_all_collections(condition_context) |
|
|
|
|
|
|
|
def set_recommended_watch_points(self, set_recommended, train_id): |
|
|
|
"""set recommended watch points.""" |
|
|
|
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) |
|
|
|
condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step, (1, 0)) |
|
|
|
log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) |
|
|
|
res = metadata_stream.get(['state', 'enable_recheck']) |
|
|
|
if set_recommended: |
|
|
|
res['id'] = self._add_recommended_watchpoints(condition_context) |
|
|
|
metadata_stream.recommendation_confirmed = True |
|
|
|
return res |
|
|
|
|
|
|
|
def _add_recommended_watchpoints(self, condition_context): |
|
|
|
"""Add predefined watchpoints.""" |
|
|
|
log.debug("Add predefined watchpoints.") |
|
|
|
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) |
|
|
|
watchpoints = recommend_watchpoints(self.condition_mgr, graph_stream, condition_context) |
|
|
|
watch_point_stream_handler = self.cache_store.get_stream_handler(Streams.WATCHPOINT) |
|
|
|
watch_points_ids = [] |
|
|
|
for watchpoint in watchpoints: |
|
|
|
watch_points_id = watch_point_stream_handler.create_watchpoint( |
|
|
|
watch_condition=watchpoint.get_watch_condition_dict(), |
|
|
|
watch_nodes=watchpoint.watch_nodes, |
|
|
|
condition_mgr=self.condition_mgr |
|
|
|
) |
|
|
|
watch_points_ids.append(watch_points_id) |
|
|
|
return watch_points_ids |
|
|
|
|
|
|
|
def start(self): |
|
|
|
"""Start server.""" |
|
|
|
grpc_port = self.grpc_port if self.grpc_port else "50051" |
|
|
|
|