Browse Source

enhance the checking for system defined parameters

Add pre-checking of the param type in choice. If a user uses batch_size:
[2, 4, ff] in the optimizer config file, an Exception will be raised
stating that the values must be integers.
tags/v1.1.0
luopengting 5 years ago
parent
commit
1a023e2ef0
2 changed files with 56 additions and 33 deletions
  1. +38
    -25
      mindinsight/optimizer/common/validator/optimizer_config.py
  2. +18
    -8
      tests/ut/optimizer/common/validator/test_optimizer_config.py

+ 38
- 25
mindinsight/optimizer/common/validator/optimizer_config.py View File

@@ -188,8 +188,43 @@ class OptimizerConfig(Schema):
target = fields.Dict(required=True, error_messages=dict_err_msg)
parameters = fields.Dict(required=True, error_messages=dict_err_msg)

def _check_tunable_system_parameters(self, name, value):
"""Check tunable system parameters."""
def _pre_check_tunable_system_parameters(self, name, value):
    """Run all pre-checks for one tunable system-defined parameter.

    Delegates to three checks in a fixed, significant order: first the
    declared ``type`` key, then the element types inside ``choice``, and
    only then the value/bound checks, which assume types already passed.

    Args:
        name: parameter name (e.g. a ``TunableSystemDefinedParams`` value).
        value (dict): the parameter's config mapping (type/bound/choice keys).

    Raises:
        ValidationError: propagated from any delegated check that fails.
    """
    self._check_param_type_tunable_system_parameters(name, value)
    # need to check param type in choice before checking the value
    self._check_param_type_choice_tunable_system_parameters(name, value)
    self._check_param_value_tunable_system_parameters(name, value)

def _check_param_type_tunable_system_parameters(self, name, value):
    """Check param type for tunable system parameters.

    ``learning_rate`` must declare type float; every other tunable
    system-defined parameter must declare type int. A missing ``type``
    key is accepted unchanged.
    """
    declared_type = value.get(HyperParamKey.TYPE.value)
    if declared_type is None:
        # Type is optional; nothing to validate when it is absent.
        return

    is_learning_rate = name == TunableSystemDefinedParams.LEARNING_RATE.value
    if is_learning_rate and declared_type != HyperParamType.FLOAT.value:
        err_msg = ("The value(s) should be float number, "
                   "please config its type as %r." % HyperParamType.FLOAT.value)
        raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
    if not is_learning_rate and declared_type != HyperParamType.INT.value:
        err_msg = "The value(s) should be integer, please config its type as %r." % HyperParamType.INT.value
        raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))

def _check_param_type_choice_tunable_system_parameters(self, name, value):
    """Check param type in choice for tunable system parameters.

    ``learning_rate`` choices must all be floats; choices of any other
    tunable system-defined parameter must all be integers. A missing
    ``choice`` key is accepted unchanged.

    Raises:
        ValidationError: if any element of ``choice`` has the wrong type.
    """
    choice = value.get(HyperParamKey.CHOICE.value)
    if choice is None:
        return

    if name == TunableSystemDefinedParams.LEARNING_RATE.value:
        if any(not isinstance(x, float) for x in choice):
            err_msg = "The value(s) should be float number."
            raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
    # bool is a subclass of int, so isinstance(True, int) is True; exclude
    # bool explicitly so a config such as [true, false] is rejected instead
    # of silently passing the integer check.
    elif any(isinstance(x, bool) or not isinstance(x, int) for x in choice):
        err_msg = "The value(s) should be integer."
        raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))

def _check_param_value_tunable_system_parameters(self, name, value):
"""Check param value for tunable system parameters."""
bound = value.get(HyperParamKey.BOUND.value)
choice = value.get(HyperParamKey.CHOICE.value)

@@ -204,27 +239,8 @@ class OptimizerConfig(Schema):
err_msg = "The upper bound should be less than and equal to 1."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.BOUND.value))
if choice is not None and max(choice) >= 1:
err_msg = "The values should be float number less than to 1."
err_msg = "The value(s) should be float number less than 1."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
elif choice is not None and list(filter(lambda x: not isinstance(x, int), choice)):
# if the choice contains value(s) which is not integer
err_msg = "The value(s) should be integer."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))

def _pre_check_tunable_system_parameters(self, name, value):
    """Check tunable system parameters.

    Validates the declared ``type`` key: float for ``learning_rate``,
    int for every other tunable system-defined parameter. A missing
    ``type`` key is accepted unchanged.
    """
    param_type = value.get(HyperParamKey.TYPE.value)
    if param_type is None:
        return

    if name == TunableSystemDefinedParams.LEARNING_RATE.value:
        if param_type == HyperParamType.FLOAT.value:
            # Correctly declared as float; nothing more to do.
            return
        err_msg = "The value(s) should be float number, " \
                  "please config its type as %r." % HyperParamType.FLOAT.value
    else:
        if param_type == HyperParamType.INT.value:
            # Correctly declared as int; nothing more to do.
            return
        err_msg = "The value(s) should be integer, please config its type as %r." % HyperParamType.INT.value
    raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))

@validates("tuner")
def check_tuner(self, data):
@@ -254,9 +270,6 @@ class OptimizerConfig(Schema):
if err:
raise ValidationError({name: err})

if is_system_param:
self._check_tunable_system_parameters(name, value)

if source is None:
# if params is in system_defined keys, group will be 'system_defined', else will be 'user_defined'.
continue


+ 18
- 8
tests/ut/optimizer/common/validator/test_optimizer_config.py View File

@@ -183,6 +183,18 @@ class TestOptimizerConfig:
err = OptimizerConfig().validate(config_dict)
assert expected_err == err

config_dict['parameters']['learning_rate'].pop('bounds')
config_dict['parameters']['learning_rate']['choice'] = [0.1, 1.1]
expected_err = {
'parameters': {
'learning_rate': {
'choice': 'The value(s) should be float number less than 1.'
}
}
}
err = OptimizerConfig().validate(config_dict)
assert expected_err == err

def test_learning_rate_type(self):
"""Test learning rate with wrong type."""
config_dict = deepcopy(self._config_dict)
@@ -202,7 +214,7 @@ class TestOptimizerConfig:
"""Test parameters combination."""
config_dict = deepcopy(self._config_dict)
config_dict['parameters'] = {
param_name: {'choice': [-0.1, 1]}
param_name: {'choice': [-1, 1]}
}
expected_err = {
'parameters': {
@@ -214,14 +226,12 @@ class TestOptimizerConfig:
err = OptimizerConfig().validate(config_dict)
assert expected_err == err

expected_err['parameters'][param_name]['choice'] = 'The value(s) should be integer.'
config_dict['parameters'][param_name]['choice'] = [1, 'hello']
err = OptimizerConfig().validate(config_dict)
assert expected_err == err

config_dict['parameters'][param_name] = {'choice': [0.1, 0.2]}
expected_err = {
'parameters': {
param_name: {
'choice': 'The value(s) should be integer.'
}
}
}
err = OptimizerConfig().validate(config_dict)
assert expected_err == err



Loading…
Cancel
Save