|
|
|
@@ -24,8 +24,8 @@ from mindinsight.optimizer.utils.utils import is_param_name_valid |
|
|
|
|
|
|
|
_BOUND_LEN = 2 |
|
|
|
_NUMBER_ERR_MSG = "Value(s) should be integer or float." |
|
|
|
_TYPE_ERR_MSG = "Value type should be %r." |
|
|
|
_VALUE_ERR_MSG = "Value should be in %s. Current value is %s." |
|
|
|
_TYPE_ERR_MSG = "Value should be a %s." |
|
|
|
_VALUE_ERR_MSG = "Value should be in %s. Current value is %r." |
|
|
|
|
|
|
|
|
|
|
|
def _generate_schema_err_msg(err_msg, *args): |
|
|
|
@@ -87,7 +87,7 @@ class ParameterSchema(Schema): |
|
|
|
"""Schema for parameter.""" |
|
|
|
number_err_msg = _generate_schema_err_msg(_NUMBER_ERR_MSG) |
|
|
|
list_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "list") |
|
|
|
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str") |
|
|
|
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "string") |
|
|
|
|
|
|
|
bounds = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg) |
|
|
|
choice = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg) |
|
|
|
@@ -146,7 +146,7 @@ class ParameterSchema(Schema): |
|
|
|
|
|
|
|
class TargetSchema(Schema): |
|
|
|
"""Schema for target.""" |
|
|
|
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str") |
|
|
|
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "string") |
|
|
|
|
|
|
|
group = fields.Str(error_messages=str_err_msg) |
|
|
|
name = fields.Str(required=True, error_messages=str_err_msg) |
|
|
|
@@ -174,13 +174,13 @@ class TargetSchema(Schema): |
|
|
|
group = data.get(TargetKey.GROUP.value) |
|
|
|
if group == TargetGroup.SYSTEM_DEFINED.value and name not in SystemDefinedTargets.list_members(): |
|
|
|
raise ValidationError({ |
|
|
|
TargetKey.GROUP.value: "This target is not system defined. Current group is: %s." % group}) |
|
|
|
TargetKey.GROUP.value: "This target is not system defined. Current group is %r." % group}) |
|
|
|
|
|
|
|
|
|
|
|
class OptimizerConfig(Schema): |
|
|
|
"""Define the search model condition parameter schema.""" |
|
|
|
dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict") |
|
|
|
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str") |
|
|
|
str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "string") |
|
|
|
|
|
|
|
summary_base_dir = fields.Str(required=True, error_messages=str_err_msg) |
|
|
|
command = fields.Str(required=True, error_messages=str_err_msg) |
|
|
|
@@ -192,7 +192,6 @@ class OptimizerConfig(Schema): |
|
|
|
"""Check tunable system parameters.""" |
|
|
|
bound = value.get(HyperParamKey.BOUND.value) |
|
|
|
choice = value.get(HyperParamKey.CHOICE.value) |
|
|
|
param_type = value.get(HyperParamKey.TYPE.value) |
|
|
|
|
|
|
|
err_msg = "The value(s) should be positive number." |
|
|
|
if bound is not None and bound[0] <= 0: |
|
|
|
@@ -207,18 +206,25 @@ class OptimizerConfig(Schema): |
|
|
|
if choice is not None and max(choice) >= 1: |
|
|
|
err_msg = "The values should be float number less than to 1." |
|
|
|
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value)) |
|
|
|
if param_type == HyperParamType.INT.value: |
|
|
|
err_msg = "The value(s) should be float number, please config it as %s." % HyperParamType.FLOAT.value |
|
|
|
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value)) |
|
|
|
else: |
|
|
|
if choice is not None and list(filter(lambda x: not isinstance(x, int), choice)): |
|
|
|
# if the choice contains value(s) which is not integer |
|
|
|
err_msg = "The value(s) should be integer." |
|
|
|
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value)) |
|
|
|
if bound is not None and param_type != HyperParamType.INT.value: |
|
|
|
# if bound is configured, need to config its type as int. |
|
|
|
err_msg = "The value(s) should be integer, please config its type as %r." % HyperParamType.INT.value |
|
|
|
elif choice is not None and list(filter(lambda x: not isinstance(x, int), choice)): |
|
|
|
# if the choice contains value(s) which is not integer |
|
|
|
err_msg = "The value(s) should be integer." |
|
|
|
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value)) |
|
|
|
|
|
|
|
def _pre_check_tunable_system_parameters(self, name, value): |
|
|
|
"""Check tunable system parameters.""" |
|
|
|
param_type = value.get(HyperParamKey.TYPE.value) |
|
|
|
if param_type is None: |
|
|
|
return |
|
|
|
|
|
|
|
if name == TunableSystemDefinedParams.LEARNING_RATE.value: |
|
|
|
if param_type != HyperParamType.FLOAT.value: |
|
|
|
err_msg = "The value(s) should be float number, " \ |
|
|
|
"please config its type as %r." % HyperParamType.FLOAT.value |
|
|
|
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value)) |
|
|
|
elif param_type != HyperParamType.INT.value: |
|
|
|
err_msg = "The value(s) should be integer, please config its type as %r." % HyperParamType.INT.value |
|
|
|
raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value)) |
|
|
|
|
|
|
|
@validates("tuner") |
|
|
|
def check_tuner(self, data): |
|
|
|
@@ -234,15 +240,21 @@ class OptimizerConfig(Schema): |
|
|
|
if not is_param_name_valid(name): |
|
|
|
raise ValidationError("Parameter name %r is not a valid name, only number(0-9), alphabet(a-z, A-Z) " |
|
|
|
"and underscore(_) characters are allowed in name." % name) |
|
|
|
is_system_param = False |
|
|
|
|
|
|
|
source = value.get(HyperParamKey.SOURCE.value) |
|
|
|
if source in [None, HyperParamSource.SYSTEM_DEFINED.value] and \ |
|
|
|
name in TunableSystemDefinedParams.list_members(): |
|
|
|
is_system_param = True |
|
|
|
|
|
|
|
if is_system_param: |
|
|
|
self._pre_check_tunable_system_parameters(name, value) |
|
|
|
|
|
|
|
err = ParameterSchema().validate(value) |
|
|
|
if err: |
|
|
|
raise ValidationError({name: err}) |
|
|
|
|
|
|
|
source = value.get(HyperParamKey.SOURCE.value) |
|
|
|
|
|
|
|
if source in [None, HyperParamSource.SYSTEM_DEFINED.value] and \ |
|
|
|
name in TunableSystemDefinedParams.list_members(): |
|
|
|
if is_system_param: |
|
|
|
self._check_tunable_system_parameters(name, value) |
|
|
|
|
|
|
|
if source is None: |
|
|
|
@@ -252,7 +264,7 @@ class OptimizerConfig(Schema): |
|
|
|
if source == HyperParamSource.SYSTEM_DEFINED.value and \ |
|
|
|
name not in TunableSystemDefinedParams.list_members(): |
|
|
|
raise ValidationError({ |
|
|
|
name: {"source": "This param is not system defined. Current source is: %s." % source}}) |
|
|
|
name: {"source": "This param is not system defined. Current source is %r." % source}}) |
|
|
|
|
|
|
|
@validates("target") |
|
|
|
def check_target(self, target): |
|
|
|
|