Browse Source

!882 add set_recommended_conditions api

From: @jiang-shuqiang
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ca39fe356c
11 changed files with 118 additions and 27 deletions
  1. +17
    -1
      mindinsight/backend/conditionmgr/conditionmgr_api.py
  2. +0
    -18
      mindinsight/debugger/debugger_grpc_server.py
  3. +28
    -0
      mindinsight/debugger/debugger_server.py
  4. +20
    -1
      mindinsight/debugger/stream_handler/metadata_handler.py
  5. +1
    -1
      tests/st/func/debugger/expect_results/restful_results/before_train_begin.json
  6. +1
    -1
      tests/st/func/debugger/expect_results/restful_results/multi_next_node.json
  7. +45
    -1
      tests/st/func/debugger/expect_results/restful_results/multi_retrieve_all.json
  8. +1
    -1
      tests/st/func/debugger/expect_results/restful_results/recommended_watchpoints_at_startup.json
  9. +2
    -1
      tests/st/func/debugger/expect_results/restful_results/retrieve_all.json
  10. +2
    -1
      tests/st/func/debugger/expect_results/restful_results/retrieve_next_node_on_gpu.json
  11. +1
    -1
      tests/ut/debugger/expected_results/debugger_server/retrieve_all.json

+ 17
- 1
mindinsight/backend/conditionmgr/conditionmgr_api.py View File

@@ -13,9 +13,12 @@
# limitations under the License.
# ============================================================================
"""Conditionmgr restful api."""
from flask import Blueprint
import json

from flask import Blueprint, request

from mindinsight.conf import settings
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER, _wrap_reply

BLUEPRINT = Blueprint("conditionmgr", __name__,
@@ -36,6 +39,19 @@ def get_condition_collections(train_id):
return reply


@BLUEPRINT.route("/conditionmgr/train-jobs/<train_id>/set-recommended-watch-points", methods=["POST"])
def set_recommended_watch_points(train_id):
    """
    Set recommended watch points for the given train job.

    Args:
        train_id (str): The id of the train job.

    Returns:
        Response, the wrapped reply from the backend server: the metadata
        state plus, when the user accepted the recommendation, the ids of
        the created watch points.

    Raises:
        ParamValueError: If the request body is not valid JSON.
    """
    # An empty request body is treated as an empty JSON object.
    body = request.stream.read()
    try:
        set_recommended = json.loads(body if body else "{}")
    except json.JSONDecodeError as err:
        # Chain the decode error so the original parse failure is not lost.
        raise ParamValueError("Json data parse failed.") from err

    reply = _wrap_reply(BACKEND_SERVER.set_recommended_watch_points, set_recommended, train_id)
    return reply


def init_module(app):
"""
Init module entry.


+ 0
- 18
mindinsight/debugger/debugger_grpc_server.py View File

@@ -15,13 +15,11 @@
"""Implement the debugger grpc server."""
from functools import wraps

import mindinsight.conditionmgr.recommender
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
Streams, RunLevel
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto
from mindinsight.conditionmgr.condition import ConditionContext


def debugger_wrap(func):
@@ -96,20 +94,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
log.debug("Reply to WaitCMD: %s", reply)
return reply

def _add_predefined_watchpoints(self, condition_context):
    """
    Create the predefined (recommended) watchpoints.

    Args:
        condition_context (ConditionContext): Context passed to the
            recommender (backend and debugger capability).
    """
    log.debug("Add predefined watchpoints.")
    # Ask the recommender for watchpoints based on the current graph.
    recommended = mindinsight.conditionmgr.recommender.recommend_watchpoints(
        self._condition_mgr,
        self._cache_store.get_stream_handler(Streams.GRAPH),
        condition_context)
    wp_handler = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
    for item in recommended:
        wp_handler.create_watchpoint(
            watch_condition=item.get_watch_condition_dict(),
            watch_nodes=item.watch_nodes,
            condition_mgr=self._condition_mgr)

def _pre_process(self, request):
"""Pre-process before dealing with command."""
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
@@ -125,8 +109,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
watchpoint_stream.clean_temp_cached_names()
# receive graph at the beginning of the training
if self._status == ServerStatus.RECEIVE_GRAPH:
condition_context = ConditionContext(backend=request.backend, debugger_capability=(1, 0))
self._add_predefined_watchpoints(condition_context)
self._send_graph_flag(metadata_stream)
# receive new metadata
if is_new_step or is_new_node:


+ 28
- 0
mindinsight/debugger/debugger_server.py View File

@@ -20,6 +20,7 @@ import grpc

from mindinsight.conditionmgr.conditionmgr import ConditionMgr
from mindinsight.conditionmgr.condition import ConditionContext, ConditionIdEnum
from mindinsight.conditionmgr.recommender import recommend_watchpoints
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.utils.tools import to_float
@@ -67,6 +68,33 @@ class DebuggerServer:
log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend)
return self.condition_mgr.get_all_collections(condition_context)

def set_recommended_watch_points(self, set_recommended, train_id):
    """
    Record the user's answer to the watch-point recommendation.

    Args:
        set_recommended (bool): Whether the user accepted the recommended
            watch points.
        train_id (str): The id of the train job.

    Returns:
        dict, the metadata state and, when the recommendation was accepted,
        the ids of the created watch points under key 'id'.
    """
    meta_handler = self.cache_store.get_stream_handler(Streams.METADATA)
    context = ConditionContext(meta_handler.backend, meta_handler.step, (1, 0))
    log.debug("Train_id: %s, backend: %s", train_id, context.backend)
    reply = meta_handler.get(['state', 'enable_recheck'])
    if set_recommended:
        # Only create the watchpoints when the user accepted the recommendation.
        reply['id'] = self._add_recommended_watchpoints(context)
    # The user has answered the question either way; mark it confirmed.
    meta_handler.recommendation_confirmed = True
    return reply

def _add_recommended_watchpoints(self, condition_context):
    """
    Add recommended watchpoints and return their ids.

    Args:
        condition_context (ConditionContext): Context used by the recommender
            (backend, step and debugger capability).

    Returns:
        list, the ids of the created watchpoints.
    """
    log.debug("Add predefined watchpoints.")
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    # Recommender derives watchpoints from the current graph and context.
    watchpoints = recommend_watchpoints(self.condition_mgr, graph_stream, condition_context)
    watch_point_stream_handler = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    watch_points_ids = []
    for watchpoint in watchpoints:
        watch_points_id = watch_point_stream_handler.create_watchpoint(
            watch_condition=watchpoint.get_watch_condition_dict(),
            watch_nodes=watchpoint.watch_nodes,
            condition_mgr=self.condition_mgr
        )
        watch_points_ids.append(watch_points_id)
    return watch_points_ids

def start(self):
"""Start server."""
grpc_port = self.grpc_port if self.grpc_port else "50051"


+ 20
- 1
mindinsight/debugger/stream_handler/metadata_handler.py View File

@@ -31,6 +31,9 @@ class MetadataHandler(StreamHandlerBase):
self._backend = ""
self._enable_recheck = False
self._cur_graph_name = ""
# If recommendation_confirmed is true, it only means the user has answered yes or no to the question,
# it does not necessarily mean that the user will use the recommended watch points.
self._recommendation_confirmed = False

@property
def device_name(self):
@@ -117,6 +120,21 @@ class MetadataHandler(StreamHandlerBase):
"""
self._enable_recheck = bool(value)

@property
def recommendation_confirmed(self):
    """The property of recommendation_confirmed."""
    return self._recommendation_confirmed

@recommendation_confirmed.setter
def recommendation_confirmed(self, value):
    """
    Set the property of recommendation_confirmed.

    Args:
        value (bool): Whether the user has answered the recommendation
            question (it does not imply the recommendation was accepted).
    """
    # Coerce to bool for consistency with the `enable_recheck` setter.
    self._recommendation_confirmed = bool(value)

def put(self, value):
"""
Put value into metadata cache. Called by grpc server.
@@ -151,7 +169,8 @@ class MetadataHandler(StreamHandlerBase):
'node_name': self.node_name,
'backend': self.backend,
'enable_recheck': self.enable_recheck,
'graph_name': self.graph_name
'graph_name': self.graph_name,
'recommendation_confirmed': self._recommendation_confirmed
}
else:
if not isinstance(filter_condition, list):


+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/before_train_begin.json View File

@@ -1 +1 @@
{"metadata": {"state": "pending", "step": 0, "device_name": "", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": ""}}
{"metadata": {"state": "pending", "step": 0, "device_name": "", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false}}

+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/multi_next_node.json View File

@@ -1 +1 @@
{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1"}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []}
{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []}

+ 45
- 1
tests/st/func/debugger/expect_results/restful_results/multi_retrieve_all.json View File

@@ -1 +1,45 @@
{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": ""}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []}
{
"metadata": {
"state": "waiting",
"step": 1,
"device_name": "0",
"node_name": "",
"backend": "Ascend",
"enable_recheck": false,
"graph_name": "",
"recommendation_confirmed": false
},
"graph": {
"graph_names": [
"graph_0",
"graph_1"
],
"nodes": [
{
"name": "graph_0",
"type": "name_scope",
"attr": {},
"input": {},
"output": {},
"output_i": 0,
"proxy_input": {},
"proxy_output": {},
"subnode_count": 2,
"independent_layout": false
},
{
"name": "graph_1",
"type": "name_scope",
"attr": {},
"input": {},
"output": {},
"output_i": 0,
"proxy_input": {},
"proxy_output": {},
"subnode_count": 2,
"independent_layout": false
}
]
},
"watch_points": []
}

+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/recommended_watchpoints_at_startup.json View File

@@ -1 +1 @@
{"watch_points": [{"id": 1, "watch_condition": {"id": "overflow", "params": [], "abbr": "OVERFLOW"}}]}
{"watch_points": []}

+ 2
- 1
tests/st/func/debugger/expect_results/restful_results/retrieve_all.json View File

@@ -6,7 +6,8 @@
"node_name": "",
"backend": "Ascend",
"enable_recheck": false,
"graph_name": "graph_0"
"graph_name": "graph_0",
"recommendation_confirmed": false
},
"graph": {
"graph_names": [


+ 2
- 1
tests/st/func/debugger/expect_results/restful_results/retrieve_next_node_on_gpu.json View File

@@ -6,7 +6,8 @@
"node_name": "Default/TransData-op99",
"backend": "GPU",
"enable_recheck": false,
"graph_name": "graph_0"
"graph_name": "graph_0",
"recommendation_confirmed": false
},
"graph": {
"graph_names": [


+ 1
- 1
tests/ut/debugger/expected_results/debugger_server/retrieve_all.json View File

@@ -1 +1 @@
{"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": ""}, "graph": {}, "watch_points": []}
{"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false}, "graph": {}, "watch_points": []}

Loading…
Cancel
Save