Browse Source

add param check

tags/v1.1.0
jiangshuqiang 5 years ago
parent
commit
7bc895776f
3 changed files with 128 additions and 46 deletions
  1. +34
    -6
      mindinsight/conditionmgr/condition.py
  2. +87
    -38
      mindinsight/conditionmgr/condition_list.py
  3. +7
    -2
      mindinsight/debugger/stream_handler/watchpoint_handler.py

+ 34
- 6
mindinsight/conditionmgr/condition.py View File

@@ -122,13 +122,17 @@ class ConditionParameter:
Args:
name (str): parameter name.
value_type (ValueTypeEnum): the type of value.
valid_test_func (func): the function used to test whether the param is valid.
support_disable (bool): whether the param support no assignment.
default_value (float): default value.
visible_on_ui (bool): whether the param visible on ui.
"""
def __init__(self, name, value_type: ValueTypeEnum, support_disable=True, default_value=None, visible_on_ui=True):

def __init__(self, name, value_type: ValueTypeEnum, valid_test_func=None, support_disable=True, default_value=None,
visible_on_ui=True):
self._name = name
self._type = value_type
self._valid_test_func = valid_test_func
self._support_disable = support_disable
self._default_value = default_value
self._visible_on_ui = visible_on_ui
@@ -158,6 +162,12 @@ class ConditionParameter:
"""Get visible_on_ui of parameter."""
return self._visible_on_ui

def is_valid(self, value):
"""Check is the parameter valid."""
if self._valid_test_func is None:
return True
return self._valid_test_func(value)


class Condition:
"""
@@ -171,10 +181,10 @@ class Condition:
supported_target_type (TargetTypeEnum): the supported target type.
supported_platforms (tuple[PlatformEnum, PlatformEnum]): the supported platforms.
minimum_debugger_capability (tuple): the minimum debugger capability required.
available_test_func (func): the function used to test whether the condition is available
availability_test_func (func): the function used to test whether the condition is available.
"""
def __init__(self, condition_id, abbr, optimize_phase, parameters, supported_target_type, supported_platforms,
minimum_debugger_capability, available_test_func=None):
minimum_debugger_capability, availability_test_func=None):
self.id = condition_id
self._abbr = abbr
self.optimize_phase = optimize_phase
@@ -184,7 +194,7 @@ class Condition:
self._supported_target_type = supported_target_type
self.supported_platforms = supported_platforms
self.minimum_debugger_capability = minimum_debugger_capability
self.available_test_func = available_test_func
self.availability_test_func = availability_test_func

def get_parameter_definition(self, name):
"""Return parameter definition by the name"""
@@ -200,9 +210,9 @@ class Condition:
if backend not in [platform.value for platform in self.supported_platforms]:
logger.debug("The condition %s is not supported on the platform.", self.id)
return False
if self.available_test_func is None:
if self.availability_test_func is None:
return True
return self.available_test_func(condition_context)
return self.availability_test_func(condition_context)

@property
def abbr(self):
@@ -230,3 +240,21 @@ def check_initialization_available(condition_context):
if condition_context.step == 0:
return True
return False


def check_percentage_param_range(value):
if 0 <= value <= 100:
return True
return False


def check_normal_param_range(value):
if float("-inf") < value < float("inf"):
return True
return False


def check_abs_param_range(value):
if 0 <= value < float("inf"):
return True
return False

+ 87
- 38
mindinsight/conditionmgr/condition_list.py View File

@@ -24,6 +24,10 @@ from mindinsight.conditionmgr.condition import ValueTypeEnum
from mindinsight.conditionmgr.condition import TargetTypeEnum
from mindinsight.conditionmgr.condition import PlatformEnum
from mindinsight.conditionmgr.condition import check_initialization_available
from mindinsight.conditionmgr.condition import check_normal_param_range
from mindinsight.conditionmgr.condition import check_percentage_param_range
from mindinsight.conditionmgr.condition import check_abs_param_range


CONDITION_LIST = [
Condition(
@@ -35,21 +39,24 @@ CONDITION_LIST = [
ConditionParameter(
name="zero_percentage_ge",
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_percentage_param_range,
default_value=100
),
ConditionParameter(
name="max_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="min_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.WEIGHT,
supported_platforms=(PlatformEnum.ASCEND, PlatformEnum.GPU),
minimum_debugger_capability=(1, 1),
available_test_func=check_initialization_available
availability_test_func=check_initialization_available
),
Condition(
condition_id="weight_overflow",
@@ -69,19 +76,23 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="abs_mean_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range
),
ConditionParameter(
name="max_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="min_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="mean_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.WEIGHT,
@@ -96,19 +107,23 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="abs_mean_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range
),
ConditionParameter(
name="max_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="min_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="mean_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.WEIGHT,
@@ -123,19 +138,23 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="abs_mean_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range
),
ConditionParameter(
name="max_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="min_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="mean_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.GRADIENT,
@@ -150,19 +169,23 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="abs_mean_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range
),
ConditionParameter(
name="max_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="min_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="mean_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.GRADIENT,
@@ -237,7 +260,8 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="param",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
@@ -252,7 +276,8 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="param",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
@@ -267,7 +292,8 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="param",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
@@ -282,7 +308,8 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="param",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
@@ -297,7 +324,8 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="param",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
@@ -312,7 +340,8 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="param",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
@@ -327,7 +356,8 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="param",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
@@ -342,7 +372,8 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="param",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
@@ -358,21 +389,24 @@ CONDITION_LIST = [
ConditionParameter(
name="zero_percentage_ge",
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_percentage_param_range,
default_value=100
),
ConditionParameter(
name="max_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="min_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
supported_platforms=(PlatformEnum.ASCEND, PlatformEnum.GPU),
minimum_debugger_capability=(1, 1),
available_test_func=check_initialization_available
availability_test_func=check_initialization_available
),
Condition(
condition_id="tensor_too_large",
@@ -382,19 +416,23 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="abs_mean_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range
),
ConditionParameter(
name="max_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="min_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="mean_gt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
@@ -409,19 +447,23 @@ CONDITION_LIST = [
parameters=[
ConditionParameter(
name="abs_mean_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range
),
ConditionParameter(
name="max_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="min_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
),
ConditionParameter(
name="mean_lt",
value_type=ValueTypeEnum.FLOAT64
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_normal_param_range
)
],
supported_target_type=TargetTypeEnum.TENSOR,
@@ -437,6 +479,7 @@ CONDITION_LIST = [
ConditionParameter(
name="zero_percentage_ge",
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_percentage_param_range,
default_value=100
)
],
@@ -453,6 +496,7 @@ CONDITION_LIST = [
ConditionParameter(
name="rtol",
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range,
default_value=1e-5
),
ConditionParameter(
@@ -483,6 +527,7 @@ CONDITION_LIST = [
ConditionParameter(
name="abs_update_ratio_mean_gt",
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range,
default_value=1e-1
),
ConditionParameter(
@@ -506,6 +551,7 @@ CONDITION_LIST = [
ConditionParameter(
name="abs_update_ratio_mean_lt",
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range,
default_value=1e-4
),
ConditionParameter(
@@ -529,6 +575,7 @@ CONDITION_LIST = [
ConditionParameter(
name="abs_update_ratio_mean_gt",
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range,
default_value=1e-1
),
ConditionParameter(
@@ -552,6 +599,7 @@ CONDITION_LIST = [
ConditionParameter(
name="abs_update_ratio_mean_lt",
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range,
default_value=1e-4
),
ConditionParameter(
@@ -575,6 +623,7 @@ CONDITION_LIST = [
ConditionParameter(
name="rtol",
value_type=ValueTypeEnum.FLOAT64,
valid_test_func=check_abs_param_range,
default_value=1e-5
),
ConditionParameter(


+ 7
- 2
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -575,12 +575,13 @@ def validate_watch_condition_params(condition_mgr, watch_condition):
return

for param in params:
if param.get("name") not in condition.names:
condition_param_name = param.get("name")
if condition_param_name not in condition.names:
log.error("Invalid name of parameter for condition: %s, available values: %s",
condition_id, condition.names)
raise DebuggerParamValueError("Invalid name of parameter.")

condition_param = condition.get_parameter_definition(param.get("name"))
condition_param = condition.get_parameter_definition(condition_param_name)
if condition_param.type.name in (ValueTypeEnum.FLOAT64.name, ValueTypeEnum.INT64.name) \
and not isinstance(param.get("value"), (float, int)):
log.error("Number param should be given for condition: %s", condition_id)
@@ -591,6 +592,10 @@ def validate_watch_condition_params(condition_mgr, watch_condition):
log.error("Bool param should be given for condition: %s", condition_id)
raise DebuggerParamValueError("Bool param should be given.")

if not condition_param.is_valid(param.get("value")):
log.error("Param %s out of range for condition: %s", condition_param_name, condition_id)
raise DebuggerParamValueError("Parameter out of range.")


def set_default_param(condition_mgr, watch_condition):
"""


Loading…
Cancel
Save