Browse Source

!869 fix hint about optimizer config

From: @luopengting
Reviewed-by: @ouwenchang,@wenkai_dist
Signed-off-by: @wenkai_dist
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
2d2f9863bc
4 changed files with 60 additions and 52 deletions
  1. +35
    -23
      mindinsight/optimizer/common/validator/optimizer_config.py
  2. +3
    -3
      mindinsight/optimizer/utils/utils.py
  3. +19
    -23
      tests/ut/optimizer/common/validator/test_optimizer_config.py
  4. +3
    -3
      tests/utils/tools.py

+ 35
- 23
mindinsight/optimizer/common/validator/optimizer_config.py View File

@@ -24,8 +24,8 @@ from mindinsight.optimizer.utils.utils import is_param_name_valid

_BOUND_LEN = 2
_NUMBER_ERR_MSG = "Value(s) should be integer or float."
_TYPE_ERR_MSG = "Value type should be %r."
_VALUE_ERR_MSG = "Value should be in %s. Current value is %s."
_TYPE_ERR_MSG = "Value should be a %s."
_VALUE_ERR_MSG = "Value should be in %s. Current value is %r."


def _generate_schema_err_msg(err_msg, *args):
@@ -87,7 +87,7 @@ class ParameterSchema(Schema):
"""Schema for parameter."""
number_err_msg = _generate_schema_err_msg(_NUMBER_ERR_MSG)
list_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "list")
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "string")

bounds = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
choice = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
@@ -146,7 +146,7 @@ class ParameterSchema(Schema):

class TargetSchema(Schema):
"""Schema for target."""
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "string")

group = fields.Str(error_messages=str_err_msg)
name = fields.Str(required=True, error_messages=str_err_msg)
@@ -174,13 +174,13 @@ class TargetSchema(Schema):
group = data.get(TargetKey.GROUP.value)
if group == TargetGroup.SYSTEM_DEFINED.value and name not in SystemDefinedTargets.list_members():
raise ValidationError({
TargetKey.GROUP.value: "This target is not system defined. Current group is: %s." % group})
TargetKey.GROUP.value: "This target is not system defined. Current group is %r." % group})


class OptimizerConfig(Schema):
"""Define the search model condition parameter schema."""
dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "string")

summary_base_dir = fields.Str(required=True, error_messages=str_err_msg)
command = fields.Str(required=True, error_messages=str_err_msg)
@@ -192,7 +192,6 @@ class OptimizerConfig(Schema):
"""Check tunable system parameters."""
bound = value.get(HyperParamKey.BOUND.value)
choice = value.get(HyperParamKey.CHOICE.value)
param_type = value.get(HyperParamKey.TYPE.value)

err_msg = "The value(s) should be positive number."
if bound is not None and bound[0] <= 0:
@@ -207,18 +206,25 @@ class OptimizerConfig(Schema):
if choice is not None and max(choice) >= 1:
err_msg = "The values should be float number less than to 1."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
if param_type == HyperParamType.INT.value:
err_msg = "The value(s) should be float number, please config it as %s." % HyperParamType.FLOAT.value
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
else:
if choice is not None and list(filter(lambda x: not isinstance(x, int), choice)):
# if the choice contains value(s) which is not integer
err_msg = "The value(s) should be integer."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
if bound is not None and param_type != HyperParamType.INT.value:
# if bound is configured, need to config its type as int.
err_msg = "The value(s) should be integer, please config its type as %r." % HyperParamType.INT.value
elif choice is not None and list(filter(lambda x: not isinstance(x, int), choice)):
# if the choice contains value(s) which is not integer
err_msg = "The value(s) should be integer."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))

def _pre_check_tunable_system_parameters(self, name, value):
"""Check tunable system parameters."""
param_type = value.get(HyperParamKey.TYPE.value)
if param_type is None:
return

if name == TunableSystemDefinedParams.LEARNING_RATE.value:
if param_type != HyperParamType.FLOAT.value:
err_msg = "The value(s) should be float number, " \
"please config its type as %r." % HyperParamType.FLOAT.value
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
elif param_type != HyperParamType.INT.value:
err_msg = "The value(s) should be integer, please config its type as %r." % HyperParamType.INT.value
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))

@validates("tuner")
def check_tuner(self, data):
@@ -234,15 +240,21 @@ class OptimizerConfig(Schema):
if not is_param_name_valid(name):
raise ValidationError("Parameter name %r is not a valid name, only number(0-9), alphabet(a-z, A-Z) "
"and underscore(_) characters are allowed in name." % name)
is_system_param = False

source = value.get(HyperParamKey.SOURCE.value)
if source in [None, HyperParamSource.SYSTEM_DEFINED.value] and \
name in TunableSystemDefinedParams.list_members():
is_system_param = True

if is_system_param:
self._pre_check_tunable_system_parameters(name, value)

err = ParameterSchema().validate(value)
if err:
raise ValidationError({name: err})

source = value.get(HyperParamKey.SOURCE.value)

if source in [None, HyperParamSource.SYSTEM_DEFINED.value] and \
name in TunableSystemDefinedParams.list_members():
if is_system_param:
self._check_tunable_system_parameters(name, value)

if source is None:
@@ -252,7 +264,7 @@ class OptimizerConfig(Schema):
if source == HyperParamSource.SYSTEM_DEFINED.value and \
name not in TunableSystemDefinedParams.list_members():
raise ValidationError({
name: {"source": "This param is not system defined. Current source is: %s." % source}})
name: {"source": "This param is not system defined. Current source is %r." % source}})

@validates("target")
def check_target(self, target):


+ 3
- 3
mindinsight/optimizer/utils/utils.py View File

@@ -95,21 +95,21 @@ def is_number(uchar):


def is_alphabet(uchar):
"""If it is a alphabet, return True."""
"""If it is an alphabet, return True."""
if uchar in string.ascii_letters:
return True
return False


def is_allowed_symbols(uchar):
"""If it is a allowed symbol, return True."""
"""If it is an allowed symbol, return True."""
if uchar in ['_']:
return True
return False


def is_param_name_valid(param_name: str):
"""If parameter name only contains number or alphabet."""
"""If parameter name only contains underscore, number or alphabet, return True."""
for uchar in param_name:
if not is_number(uchar) and not is_alphabet(uchar) and not is_allowed_symbols(uchar):
return False


+ 19
- 23
tests/ut/optimizer/common/validator/test_optimizer_config.py View File

@@ -69,18 +69,16 @@ class TestOptimizerConfig:
config_dict['parameters']['learning_rate']['choice'] = init_str
config_dict['parameters']['learning_rate']['type'] = init_list
expected_err = {
'command': ["Value type should be 'str'."],
'command': ['Value should be a string.'],
'parameters': {
'learning_rate': {
'bounds': ["Value type should be 'list'."],
'choice': ["Value type should be 'list'."],
'type': ["Value type should be 'str'."]
'type': "The value(s) should be float number, please config its type as 'float'."
}
},
'summary_base_dir': ["Value type should be 'str'."],
'summary_base_dir': ['Value should be a string.'],
'target': {
'name': ["Value type should be 'str'."],
'goal': ["Value type should be 'str'."]
'goal': ['Value should be a string.'],
'name': ['Value should be a string.']
}
}
err = OptimizerConfig().validate(config_dict)
@@ -101,18 +99,13 @@ class TestOptimizerConfig:
expected_err = {
'parameters': {
'learning_rate': {
'bounds': {
0: ['Value(s) should be integer or float.']
},
'choice': {
0: ['Value(s) should be integer or float.']
},
'type': ["It should be in ['int', 'float']."]
'type': "The value(s) should be float number, please config its type as 'float'."
}
},
'target': {
'goal': ["Value should be in ['maximize', 'minimize']. Current value is a."],
'group': ["Value should be in ['system_defined', 'metric']. Current value is a."]},
'goal': ["Value should be in ['maximize', 'minimize']. Current value is 'a'."],
'group': ["Value should be in ['system_defined', 'metric']. Current value is 'a'."]
},
'tuner': {
'name': ['Must be one of: gp.']
}
@@ -128,7 +121,7 @@ class TestOptimizerConfig:
config_dict['target']['name'] = 'a'
expected_err = {
'target': {
'group': 'This target is not system defined. Current group is: system_defined.'
'group': "This target is not system defined. Current group is 'system_defined'."
}
}
err = OptimizerConfig().validate(config_dict)
@@ -142,7 +135,7 @@ class TestOptimizerConfig:
expected_err = {
'parameters': {
'decay_step': {
'source': 'This param is not system defined. Current source is: system_defined.'
'source': "This param is not system defined. Current source is 'system_defined'."
}
}
}
@@ -165,7 +158,7 @@ class TestOptimizerConfig:
assert expected_err == err

def test_learning_rate(self):
"""Test parameters combination."""
"""Test learning rate with wrong value."""
config_dict = deepcopy(self._config_dict)

config_dict['parameters']['learning_rate']['bounds'] = [-0.1, 1]
@@ -190,11 +183,14 @@ class TestOptimizerConfig:
err = OptimizerConfig().validate(config_dict)
assert expected_err == err

def test_learning_rate_type(self):
"""Test learning rate with wrong type."""
config_dict = deepcopy(self._config_dict)
config_dict['parameters']['learning_rate']['type'] = 'int'
expected_err = {
'parameters': {
'learning_rate': {
'bounds': 'The upper bound should be less than and equal to 1.'
'type': "The value(s) should be float number, please config its type as 'float'."
}
}
}
@@ -205,9 +201,9 @@ class TestOptimizerConfig:
def test_batch_size_and_epoch(self, param_name):
"""Test parameters combination."""
config_dict = deepcopy(self._config_dict)
config_dict['parameters'] = {}
config_dict['parameters'][param_name] = {'choice': [-0.1, 1]}
config_dict['parameters'] = {
param_name: {'choice': [-0.1, 1]}
}
expected_err = {
'parameters': {
param_name: {


+ 3
- 3
tests/utils/tools.py View File

@@ -67,7 +67,7 @@ def compare_result_with_file(result, expected_file_path):
assert result == expected_results


def deal_float_for_dict(res: dict, expected_res: dict, decimal_num=5):
def deal_float_for_dict(res: dict, expected_res: dict, decimal_num):
"""
Deal float rounded to specified decimals in dict.

@@ -116,7 +116,7 @@ def deal_float_for_dict(res: dict, expected_res: dict, decimal_num=5):
value = res[key]
expected_value = expected_res[key]
if isinstance(value, dict):
deal_float_for_dict(value, expected_value)
deal_float_for_dict(value, expected_value, decimal_num)
elif isinstance(value, float):
res[key] = round(value, decimal_num)
expected_res[key] = round(expected_value, decimal_num)
@@ -130,7 +130,7 @@ def _deal_float_for_list(list1, list2, decimal_num):
index += 1


def assert_equal_lineages(lineages1, lineages2, assert_func, decimal_num=2):
def assert_equal_lineages(lineages1, lineages2, assert_func, decimal_num=5):
"""Assert lineages."""
if isinstance(lineages1, list) and isinstance(lineages2, list):
_deal_float_for_list(lineages1, lineages2, decimal_num)


Loading…
Cancel
Save