@@ -188,8 +188,43 @@ class OptimizerConfig(Schema):
target = fields.Dict(required=True, error_messages=dict_err_msg)
parameters = fields.Dict(required=True, error_messages=dict_err_msg)
def _check_tunable_system_parameters(self, name, value):
"""Check tunable system parameters."""
def _pre_check_tunable_system_parameters(self, name, value):
self._check_param_type_tunable_system_parameters(name, value)
# need to check param type in choice before checking the value
self._check_param_type_choice_tunable_system_parameters(name, value)
self._check_param_value_tunable_system_parameters(name, value)
def _check_param_type_tunable_system_parameters(self, name, value):
"""Check param type for tunable system parameters."""
param_type = value.get(HyperParamKey.TYPE.value)
if param_type is None:
return
if name == TunableSystemDefinedParams.LEARNING_RATE.value:
if param_type != HyperParamType.FLOAT.value:
err_msg = "The value(s) should be float number, " \
"please config its type as %r." % HyperParamType.FLOAT.value
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
elif param_type != HyperParamType.INT.value:
err_msg = "The value(s) should be integer, please config its type as %r." % HyperParamType.INT.value
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
def _check_param_type_choice_tunable_system_parameters(self, name, value):
"""Check param type in choice for tunable system parameters."""
choice = value.get(HyperParamKey.CHOICE.value)
if choice is None:
return
if name == TunableSystemDefinedParams.LEARNING_RATE.value:
if list(filter(lambda x: not isinstance(x, float), choice)):
err_msg = "The value(s) should be float number."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
elif list(filter(lambda x: not isinstance(x, int), choice)):
err_msg = "The value(s) should be integer."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
def _check_param_value_tunable_system_parameters(self, name, value):
"""Check param value for tunable system parameters."""
bound = value.get(HyperParamKey.BOUND.value)
choice = value.get(HyperParamKey.CHOICE.value)
@@ -204,27 +239,8 @@ class OptimizerConfig(Schema):
err_msg = "The upper bound should be less than and equal to 1."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.BOUND.value))
if choice is not None and max(choice) >= 1:
err_msg = "The values should be float number less than to 1."
err_msg = "The value( s) should be float number less than 1."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
elif choice is not None and list(filter(lambda x: not isinstance(x, int), choice)):
# if the choice contains value(s) which is not integer
err_msg = "The value(s) should be integer."
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
def _pre_check_tunable_system_parameters(self, name, value):
"""Check tunable system parameters."""
param_type = value.get(HyperParamKey.TYPE.value)
if param_type is None:
return
if name == TunableSystemDefinedParams.LEARNING_RATE.value:
if param_type != HyperParamType.FLOAT.value:
err_msg = "The value(s) should be float number, " \
"please config its type as %r." % HyperParamType.FLOAT.value
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
elif param_type != HyperParamType.INT.value:
err_msg = "The value(s) should be integer, please config its type as %r." % HyperParamType.INT.value
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
@validates("tuner")
def check_tuner(self, data):
@@ -254,9 +270,6 @@ class OptimizerConfig(Schema):
if err:
raise ValidationError({name: err})
if is_system_param:
self._check_tunable_system_parameters(name, value)
if source is None:
# if params is in system_defined keys, group will be 'system_defined', else will be 'user_defined'.
continue