|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779 |
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Check parameters."""
- import re
- import inspect
- import math
- from enum import Enum
- from functools import reduce, wraps
- from itertools import repeat
- from collections.abc import Iterable
-
- import numpy as np
- from mindspore import log as logger
- from .common import dtype as mstype
-
-
- # Named string regular expression
- _name_re = r"^\w+[0-9a-zA-Z\_\.]*$"
-
-
class Rel(Enum):
    """Relations used by the validators: scalar comparison, range inclusion and membership.

    Every member has a predicate in ``rel_fns`` and a message template in ``rel_strs``.
    """
    # Binary comparisons, evaluated as predicate(value, reference)
    EQ = 1  # ==
    NE = 2  # !=
    LT = 3  # <
    LE = 4  # <=
    GT = 5  # >
    GE = 6  # >=
    # Range checks, evaluated as predicate(value, lower, upper)
    INC_NEITHER = 7  # (lower, upper): both bounds excluded
    INC_LEFT = 8  # [lower, upper): lower bound included
    INC_RIGHT = 9  # (lower, upper]: upper bound included
    INC_BOTH = 10  # [lower, upper]: both bounds included
    # Membership, evaluated as predicate(value, collection)
    IN = 11
    NOT_IN = 12

    @staticmethod
    def get_strs(rel):
        """Return the message template registered for `rel`, or "" when there is none."""
        template = rel_strs.get(rel, "")
        return template

    @staticmethod
    def get_fns(rel):
        """Return the predicate registered for `rel`.

        An unregistered relation gets a predicate that always returns False, so the check fails.
        """
        def _always_false(*_ignored):
            return False

        return rel_fns.get(rel, _always_false)
-
-
# Predicate for each Rel member. Binary relations are called as fn(value, reference);
# range relations are called as fn(value, lower, upper).
rel_fns = {
    # scalar comparisons
    Rel.EQ: lambda lhs, rhs: lhs == rhs,
    Rel.NE: lambda lhs, rhs: lhs != rhs,
    Rel.LT: lambda lhs, rhs: lhs < rhs,
    Rel.LE: lambda lhs, rhs: lhs <= rhs,
    Rel.GT: lambda lhs, rhs: lhs > rhs,
    Rel.GE: lambda lhs, rhs: lhs >= rhs,
    # range inclusion; the member name tells which bounds are inclusive
    Rel.INC_NEITHER: lambda val, low, high: low < val < high,
    Rel.INC_LEFT: lambda val, low, high: low <= val < high,
    Rel.INC_RIGHT: lambda val, low, high: low < val <= high,
    Rel.INC_BOTH: lambda val, low, high: low <= val <= high,
    # collection membership
    Rel.IN: lambda item, container: item in container,
    Rel.NOT_IN: lambda item, container: item not in container,
}
-
# Message template for each Rel member, filled in with str.format() when a check fails.
rel_strs = {
    # scalar comparisons: one placeholder for the reference value
    Rel.EQ: "equal to {}",
    Rel.NE: "not equal to {}",
    Rel.LT: "less than {}",
    Rel.LE: "less or equal to {}",
    Rel.GT: "greater than {}",
    Rel.GE: "greater or equal to {}",
    # ranges: placeholders for lower and upper bound, written in interval notation
    Rel.INC_NEITHER: "({}, {})",
    Rel.INC_LEFT: "[{}, {})",
    Rel.INC_RIGHT: "({}, {}]",
    Rel.INC_BOTH: "[{}, {}]",
    # membership: one placeholder for the collection
    Rel.IN: "in {}",
    Rel.NOT_IN: "not in {}",
}
-
-
class Validator:
    """Validator for checking input parameters.

    Each check raises an exception naming the offending argument (and the primitive, when
    `prim_name` is given). Most checks return the validated value on success.
    """

    @staticmethod
    def _get_msg_prefix(prim_name):
        """Return the leading words of an error message, naming the primitive when one is given."""
        return f'For \'{prim_name}\' the' if prim_name else "The"

    @staticmethod
    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
        """
        Method for judging relation between two int values or list/tuple made up of ints.

        This method is not suitable for judging relation between floats, since it does not consider float error.

        Raises:
            excp_cls: If `rel` does not hold between `arg_value` and `value`.
        """
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
            msg_prefix = Validator._get_msg_prefix(prim_name)
            raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')

    @staticmethod
    def check_integer(arg_name, arg_value, value, rel, prim_name):
        """Check that `arg_value` is an int (bool excluded) satisfying `rel` against `value`; return it.

        Raises:
            TypeError: If `arg_value` is not an int or is a bool.
            ValueError: If the relation does not hold.
        """
        rel_fn = Rel.get_fns(rel)
        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
        excp_cls = TypeError if type_mismatch else ValueError
        if type_mismatch or not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            msg_prefix = Validator._get_msg_prefix(prim_name)
            raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
                           f' with type `{type(arg_value).__name__}`.')
        return arg_value

    @staticmethod
    def check_number(arg_name, arg_value, value, rel, prim_name):
        """Check that a number satisfies `rel` against `value`; return it.

        Raises:
            ValueError: If the relation does not hold.
        """
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            msg_prefix = Validator._get_msg_prefix(prim_name)
            raise ValueError(f'{msg_prefix} `{arg_name}` must {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
        """Check that `arg_value` is an int (bool excluded) inside the range described by `rel`; return it.

        Raises:
            TypeError: If `arg_value` is not an int or is a bool.
            ValueError: If `arg_value` is outside the range.
        """
        rel_fn = Rel.get_fns(rel)
        # bool is a subclass of int; reject it here just like check_integer does.
        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
        excp_cls = TypeError if type_mismatch else ValueError
        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            msg_prefix = Validator._get_msg_prefix(prim_name)
            raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int in range {rel_str},'
                           f' but got `{arg_value}` with type `{type(arg_value).__name__}`.')
        return arg_value

    @staticmethod
    def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
        """Check that a numeric value lies inside the range described by `rel`; return it.

        Raises:
            ValueError: If `arg_value` is outside the range.
        """
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            msg_prefix = Validator._get_msg_prefix(prim_name)
            raise ValueError(f'{msg_prefix} `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_string(arg_name, arg_value, valid_values, prim_name):
        """Check that `arg_value` is a str contained in `valid_values`; return it.

        Raises:
            ValueError: If `arg_value` is not a str or not one of `valid_values`.
        """
        if isinstance(arg_value, str) and arg_value in valid_values:
            return arg_value
        msg_prefix = Validator._get_msg_prefix(prim_name)
        if len(valid_values) == 1:
            raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be {valid_values[0]},'
                             f' but got {arg_value}.')
        raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be one of {valid_values},'
                         f' but got {arg_value}.')

    @staticmethod
    def check_pad_value_by_mode(pad_mode, padding, prim_name):
        """Check that `padding` is zero unless `pad_mode` is 'pad'; return `padding`.

        Raises:
            ValueError: If `pad_mode` is not 'pad' and `padding` is non-zero.
        """
        if pad_mode != 'pad' and padding != 0:
            raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
        return padding

    @staticmethod
    def check_float_positive(arg_name, arg_value, prim_name):
        """Check that `arg_value` is a float greater than zero; return it.

        Raises:
            TypeError: If `arg_value` is not a float.
            ValueError: If `arg_value` is not positive.
        """
        msg_prefix = Validator._get_msg_prefix(prim_name)
        if isinstance(arg_value, float):
            if arg_value > 0:
                return arg_value
            raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.")
        raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")

    @staticmethod
    def check_subclass(arg_name, type_, template_type, prim_name):
        """Check that `type_` is a subclass (per mstype.issubclass_) of one of `template_type`.

        Raises:
            TypeError: If it is a subclass of none of them.
        """
        if not isinstance(template_type, Iterable):
            template_type = (template_type,)
        if not any(mstype.issubclass_(type_, x) for x in template_type):
            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
            msg_prefix = Validator._get_msg_prefix(prim_name)
            raise TypeError(f'{msg_prefix} type of `{arg_name}` should be subclass'
                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')

    @staticmethod
    def check_const_input(arg_name, arg_value, prim_name):
        """Check that `arg_value` is not None (i.e. was supplied as a constant).

        Raises:
            ValueError: If `arg_value` is None.
        """
        if arg_value is None:
            msg_prefix = Validator._get_msg_prefix(prim_name)
            raise ValueError(f'{msg_prefix} `{arg_name}` must be a const input, but got {arg_value}.')

    @staticmethod
    def check_type_same(args, valid_values, prim_name):
        """Check that every type in `args` is in `valid_values` and that all of them are equal.

        Args:
            args (dict): Maps argument names to their types.
            valid_values: Collection of accepted types.
            prim_name (str): Primitive name used in error messages.

        Raises:
            TypeError: If a type is not accepted or two types differ.
        """
        def _check_tensor_type(arg):
            arg_key, elem_type = arg
            if elem_type not in valid_values:
                types_info = '[' + ', '.join([str(t) for t in valid_values]) + ']'
                raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
                                f' but got {elem_type}.')
            return (arg_key, elem_type)

        def _check_types_same(arg1, arg2):
            arg1_name, arg1_type = arg1
            arg2_name, arg2_type = arg2
            if arg1_type != arg2_type:
                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.')
            # Keep the first argument as the reference for all later comparisons.
            return arg1

        elem_types = map(_check_tensor_type, args.items())
        reduce(_check_types_same, elem_types)

    @staticmethod
    def check_tensor_type_same(args, valid_values, prim_name):
        """Check that the element types of input tensors are the same and in `valid_values`."""
        tensor_types = [mstype.tensor_type(t) for t in valid_values]
        Validator.check_type_same(args, tensor_types, prim_name)

    @staticmethod
    def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
        """
        Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.

        If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.

        Raises:
            TypeError: If a type is not in `valid_values` or the types are not the same.
        """

        def _check_argument_type(arg):
            arg_key, arg_val = arg
            if isinstance(arg_val, type(mstype.tensor)):
                arg_val = arg_val.element_type()
            if arg_val not in valid_values:
                raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
                                f' but `{arg_key}` is {arg_val}.')
            # Return the unmodified pair: _check_types_same needs to know which ones are tensors.
            return arg

        def _check_types_same(arg1, arg2):
            arg1_name, arg1_type = arg1
            arg2_name, arg2_type = arg2
            except_flag = False
            if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
                arg1_type = arg1_type.element_type()
                arg2_type = arg2_type.element_type()
            elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
                pass
            elif allow_mix:
                arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
                arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
            else:
                # One tensor and one scalar while mixing is not allowed.
                except_flag = True

            if except_flag or arg1_type != arg2_type:
                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
            return arg1
        reduce(_check_types_same, map(_check_argument_type, args.items()))

    @staticmethod
    def check_value_type(arg_name, arg_value, valid_types, prim_name):
        """Check that `arg_value` is an instance of one of `valid_types`; return it.

        bool is a subclass of int, so ``check_value_type('x', True, [int])`` fails while
        ``check_value_type('x', True, [bool, int])`` passes.

        Raises:
            TypeError: If `arg_value` is not an instance of any accepted type.
        """
        valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)

        def raise_error_msg():
            """func for raising error message when check failed"""
            type_names = [t.__name__ for t in valid_types]
            num_types = len(valid_types)
            msg_prefix = Validator._get_msg_prefix(prim_name)
            raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')

        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
            raise_error_msg()
        if isinstance(arg_value, tuple(valid_types)):
            return arg_value
        raise_error_msg()

    @staticmethod
    def check_type_name(arg_name, arg_type, valid_types, prim_name):
        """Check that `arg_type` is one of `valid_types` (compared with ``in``); return it.

        Raises:
            TypeError: If `arg_type` is not among `valid_types`.
        """
        valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)

        def get_typename(t):
            return t.__name__ if hasattr(t, '__name__') else str(t)

        if arg_type in valid_types:
            return arg_type
        type_names = [get_typename(t) for t in valid_types]
        msg_prefix = Validator._get_msg_prefix(prim_name)
        if len(valid_types) == 1:
            raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
                            f' but got {get_typename(arg_type)}.')
        raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
                        f' but got {get_typename(arg_type)}.')

    @staticmethod
    def check_float_legal_value(arg_name, arg_value, prim_name):
        """Check that `arg_value` is a finite float (neither inf nor nan); return it.

        Raises:
            TypeError: If `arg_value` is not a float.
            ValueError: If `arg_value` is inf or nan.
        """
        msg_prefix = Validator._get_msg_prefix(prim_name)
        if isinstance(arg_value, float):
            if math.isinf(arg_value) or math.isnan(arg_value):
                raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.")
            return arg_value
        raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
-
-
class ParamValidator:
    """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""

    @staticmethod
    def equal(arg_name, arg_value, cond_str, cond):
        """Raise ValueError describing `arg_value` when the precomputed condition `cond` is falsy."""
        if not cond:
            raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')

    @staticmethod
    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
        """Check that `rel` holds between `arg_value` and `value`.

        Only meant for int values: float comparison would need to account for rounding error.

        Raises:
            ValueError: If the relation does not hold.
        """
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
            raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')

    @staticmethod
    def check_integer(arg_name, arg_value, value, rel):
        """Check that `arg_value` is an int (bool excluded) satisfying `rel` against `value`; return it.

        Raises:
            ValueError: On a type mismatch or when the relation does not hold.
        """
        rel_fn = Rel.get_fns(rel)
        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
        if type_mismatch or not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_shape_length(arg_name, arg_value, value, rel):
        """Check that a shape length is an int satisfying `rel` against `value`; return it.

        Raises:
            ValueError: On a type mismatch or when the relation does not hold.
        """
        rel_fn = Rel.get_fns(rel)
        type_mismatch = not isinstance(arg_value, int)
        if type_mismatch or not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
        return arg_value

    @staticmethod
    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
        """Check that `arg_value` is an int (bool excluded) inside the range described by `rel`; return it.

        Only meant for int values: float comparison would need to account for rounding error.

        Raises:
            ValueError: On a type mismatch or when the value is outside the range.
        """
        rel_fn = Rel.get_fns(rel)
        # bool is a subclass of int; reject it here just like check_integer does.
        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_isinstance(arg_name, arg_value, classes):
        """Check that `arg_value` is an instance of `classes`; return it.

        Raises:
            ValueError: If it is not.
        """
        if not isinstance(arg_value, classes):
            raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
        """Check that a numeric value lies inside the range described by `rel`; return it.

        No tolerance is applied when comparing floats.

        Raises:
            ValueError: If the value is outside the range.
        """
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_subclass(arg_name, type_, template_type, with_type_of=True):
        """Check that `type_` is a subclass (per mstype.issubclass_) of one of `template_type`.

        Raises:
            TypeError: If it is a subclass of none of them.
        """
        if not isinstance(template_type, Iterable):
            template_type = (template_type,)
        if not any(mstype.issubclass_(type_, x) for x in template_type):
            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
            raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')

    @staticmethod
    def check_args_tensor(args):
        """Check that every value in the dict `args` is a tensor type.

        Raises:
            TypeError: If `args` is not a dict or a value is not a tensor type.
        """
        if not isinstance(args, dict):
            raise TypeError("The args should be a dict.")
        for arg, value in args.items():
            ParamValidator.check_subclass(arg, value, mstype.tensor)

    @staticmethod
    def check_bool(arg_name, arg_value):
        """Check that `arg_value` is a bool; return it.

        Raises:
            ValueError: If it is not a bool.
        """
        if not isinstance(arg_value, bool):
            raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_type(arg_name, arg_value, valid_types):
        """Check that `arg_value` (a tensor type is replaced by its element type) is of an accepted type.

        bool is a subclass of int, so ``check_type('x', True, [int])`` fails while
        ``check_type('x', True, [bool, int])`` passes.

        Returns:
            The checked value (the element type for tensor types).

        Raises:
            TypeError: If the value is not of an accepted type.
        """
        def raise_error_msg():
            """func for raising error message when check failed"""
            type_names = [t.__name__ for t in valid_types]
            num_types = len(valid_types)
            raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')

        if isinstance(arg_value, type(mstype.tensor)):
            arg_value = arg_value.element_type()
        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
            raise_error_msg()
        if isinstance(arg_value, tuple(valid_types)):
            return arg_value
        raise_error_msg()

    @staticmethod
    def check_typename(arg_name, arg_type, valid_types):
        """Check that `arg_type` (a tensor type is replaced by its element type) is in `valid_types`.

        Raises:
            ValueError: If it is not.
        """

        def get_typename(t):
            return t.__name__ if hasattr(t, '__name__') else str(t)

        if isinstance(arg_type, type(mstype.tensor)):
            arg_type = arg_type.element_type()

        if arg_type in valid_types:
            return arg_type
        type_names = [get_typename(t) for t in valid_types]
        if len(valid_types) == 1:
            raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
                             f' but got {get_typename(arg_type)}.')
        raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
                         f' but got {get_typename(arg_type)}.')

    @staticmethod
    def check_string(arg_name, arg_value, valid_values):
        """Check that `arg_value` is a str contained in `valid_values`; return it.

        Raises:
            ValueError: If it is not.
        """
        if isinstance(arg_value, str) and arg_value in valid_values:
            return arg_value
        if len(valid_values) == 1:
            raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
                             f' but got {arg_value}.')
        raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
                         f' but got {arg_value}.')

    @staticmethod
    def check_type_same(args, valid_values):
        """Check that all types in `args` are in `valid_values` and equal to the first one.

        Tensor types are compared by their element type.

        Args:
            args (dict): Maps argument names to their types; the first entry is the reference.
            valid_values: Collection of accepted types.

        Raises:
            TypeError: If a type is not accepted or differs from the reference type.
        """
        if not args:
            # An empty mapping has nothing to compare.
            return
        name, value = next(iter(args.items()))
        if isinstance(value, type(mstype.tensor)):
            value = value.element_type()
        for arg_name, arg_value in args.items():
            if isinstance(arg_value, type(mstype.tensor)):
                arg_value = arg_value.element_type()

            if arg_value not in valid_values:
                raise TypeError(f'The `{arg_name}` should be in {valid_values},'
                                f' but `{arg_name}` is {arg_value}.')
            if arg_value != value:
                raise TypeError(f'`{arg_name}` should be same as `{name}`,'
                                f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')

    @staticmethod
    def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
        """Check that two types are equal.

        Raises:
            TypeError: If they differ.
        """
        if arg1_type != arg2_type:
            raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')

    @staticmethod
    def check_value_on_integer(arg_name, arg_value, value, rel):
        """If `arg_value` is an int, check that it satisfies `rel` against `value`; return `arg_value`.

        Non-int values are returned without any check.

        Raises:
            ValueError: If `arg_value` is an int and the relation does not hold.
        """
        rel_fn = Rel.get_fns(rel)
        type_match = isinstance(arg_value, int)
        if type_match and (not rel_fn(arg_value, value)):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_param_equal(param1_name, param1_value, param2_name, param2_value):
        """Check that two parameter values are equal.

        Raises:
            ValueError: If they differ.
        """
        if param1_value != param2_value:
            raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
                             f" but got `{param1_name}` = {param1_value},"
                             f" `{param2_name}` = {param2_value}.")

    @staticmethod
    def check_const_input(arg_name, arg_value):
        """Check that `arg_value` is not None.

        Raises:
            ValueError: If it is None.
        """
        if arg_value is None:
            raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')

    @staticmethod
    def check_float_positive(arg_name, arg_value):
        """Check that `arg_value` is a float greater than zero; return it.

        Raises:
            TypeError: If it is not a float.
            ValueError: If it is not positive.
        """
        if isinstance(arg_value, float):
            if arg_value > 0:
                return arg_value
            raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")

        raise TypeError(f"`{arg_name}` must be float!")

    @staticmethod
    def check_pad_value_by_mode(op_name, pad_mode, padding):
        """Check that `padding` is zero unless `pad_mode` is 'pad'; return `padding`.

        Raises:
            ValueError: If `pad_mode` is not 'pad' and `padding` is non-zero.
        """
        if pad_mode != 'pad' and padding != 0:
            raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
        return padding

    @staticmethod
    def check_empty_shape_input(arg_name, arg_value):
        """Check that no dimension of the shape `arg_value` is 0.

        Raises:
            ValueError: If a dimension equals 0.
        """
        if 0 in arg_value:
            raise ValueError(f"Input `{arg_name}` cannot be empty.")

    @staticmethod
    def check_scalar_shape_input(arg_name, arg_value):
        """Check that `arg_value` is a scalar shape, i.e. an empty list or an empty tuple.

        Raises:
            ValueError: If it is anything else.
        """
        # Accept () as well as []: both spell the shape of a scalar.
        is_scalar_shape = isinstance(arg_value, (list, tuple)) and len(arg_value) == 0
        if not is_scalar_shape:
            raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
-
-
def check_int(input_param):
    """Return `input_param` unchanged when it is an int; bools are rejected.

    Raises:
        TypeError: If `input_param` is not an int, or is a bool.
    """
    is_plain_int = isinstance(input_param, int) and not isinstance(input_param, bool)
    if not is_plain_int:
        raise TypeError("Input type must be int!")
    return input_param
-
-
def check_int_positive(input_param):
    """Return `input_param` when it is a strictly positive int.

    Raises:
        TypeError: If `input_param` is a bool or not an int.
        ValueError: If `input_param` is an int <= 0.
    """
    if isinstance(input_param, bool):
        raise TypeError("Input type must be int cannot be bool!")
    if not isinstance(input_param, int):
        raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
    if input_param <= 0:
        raise ValueError("The input_param must be positive, but got input_param {}.".format(input_param))
    return input_param
-
-
def check_int_non_negative(input_param):
    """Return `input_param` when it is an int >= 0.

    Raises:
        TypeError: If `input_param` is a bool or not an int.
        ValueError: If `input_param` is a negative int.
    """
    if isinstance(input_param, bool):
        raise TypeError("Input type must be int cannot be bool!")
    if not isinstance(input_param, int):
        raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
    if input_param < 0:
        raise ValueError("The input_param must be non_negative, but got input_param {}.".format(input_param))
    return input_param
-
-
def check_int_zero_one(input_param):
    """Return `input_param` when it compares equal to 0 or 1.

    Raises:
        ValueError: For any other value.
    """
    if input_param not in (0, 1):
        raise ValueError("The data must be 0 or 1.")
    return input_param
-
-
def check_bool(input_param):
    """Return `input_param` when it is a bool.

    Raises:
        TypeError: For any non-bool value.
    """
    if not isinstance(input_param, bool):
        raise TypeError("Input type must be bool!")
    return input_param
-
-
def check_input_format(input_param):
    """Return `input_param` when it is the data format "NCHW" (the only one supported).

    Raises:
        ValueError: For any other format.
    """
    if input_param != "NCHW":
        raise ValueError("The data format must be NCHW.")
    return input_param
-
-
def check_padding(padding):
    """Return `padding` when it is >= 0.

    Raises:
        ValueError: If `padding` is negative or not comparable as >= 0 (e.g. NaN).
    """
    is_valid = padding >= 0
    if not is_valid:
        raise ValueError("The padding must be at least 0, but got padding {}.".format(padding))
    return padding
-
-
def check_padmode(mode):
    """Return `mode` when it is one of the pad modes "same", "valid" or "pad".

    Raises:
        ValueError: For any other mode.
    """
    if mode not in ("same", "valid", "pad"):
        raise ValueError("The pad mode must be same or valid or pad, but got mode {}.".format(mode))
    return mode
-
-
def check_tensor_supported_type(dtype):
    """Return `dtype` when it is mstype.int32 or mstype.float32.

    Raises:
        ValueError: For any other dtype.
    """
    supported_dtypes = (mstype.int32, mstype.float32)
    if dtype not in supported_dtypes:
        raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype))
    return dtype
-
-
- def _expand_tuple(n_dimensions):
- """To expand a number to tuple."""
-
- def convert(m):
- if not isinstance(m, tuple):
- if isinstance(m, int):
- return tuple(repeat(m, n_dimensions))
- raise TypeError("Input type must be int or tuple.")
-
- if not len(m) is n_dimensions:
- raise TypeError("Input dimension is incorrect.")
-
- for i in m:
- if not isinstance(i, int):
- raise TypeError("Incorrect type inside of a tuple!")
- return m
-
- return convert
-
-
def check_input_data(*data, data_class):
    """Validate model inputs, descending recursively into lists and tuples.

    Every leaf must be an instance of `data_class` whose ``size()`` is non-zero.

    Raises:
        ValueError: If a leaf has the wrong type or is empty (the empty case is also logged).
    """
    for entry in data:
        if isinstance(entry, (list, tuple)):
            check_input_data(*entry, data_class=data_class)
            continue
        if not isinstance(entry, data_class):
            raise ValueError(f'Please provide as model inputs either a single or a list of {data_class.__name__},'
                             f' but got part data type is {str(type(entry))}.')
        if entry.size() == 0:
            msg = "Please provide non-empty data."
            logger.error(msg)
            raise ValueError(msg)
-
-
def check_output_data(data):
    """Raise RuntimeError when the executor's output `data` is falsy (e.g. None or empty)."""
    if data:
        return
    raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
-
-
# Ready-made converters: an int is repeated into a 1-, 2- or 3-tuple; a tuple of that
# length with int elements is passed through unchanged.
once, twice, triple = (_expand_tuple(n) for n in (1, 2, 3))

# Element types check_type accepts for the dtype of an ndarray (or of a list/tuple
# converted with np.array).
valid_data_types = (
    int, float,
    np.int8, np.int16, np.int32, np.int64,
    np.uint8, np.uint16, np.uint32, np.uint64,
    np.float16, np.float32, np.float64,
    bool, np.bool_,
)
-
-
def check_type(arg_name, arg_value, valid_types):
    """Check the type of `arg_value` and return the name of the accepted type.

    A Tensor type is replaced by its element type first. A value that is an instance of one of
    `valid_types` is accepted. Otherwise a list/tuple is converted with ``np.array``, and an
    ndarray is accepted when its dtype is in ``valid_data_types``.

    Since bool is a subclass of int, ``check_type('x', True, [int])`` fails and
    ``check_type('x', True, [bool, int])`` passes.

    Args:
        arg_name (str): Argument name used in error messages.
        arg_value: Value to check.
        valid_types (Iterable[type]): Accepted Python types.

    Returns:
        str: Type name of the accepted value ('ndarray' for a converted list/tuple).

    Raises:
        TypeError: If the value is not accepted.
    """
    valid_types = tuple(valid_types)
    # if input type is Tensor, get element type
    if isinstance(arg_value, type(mstype.tensor)):
        arg_value = arg_value.element_type()
    # Name of what the caller passed, before any numpy conversion, for the error message.
    got_type_name = type(arg_value).__name__

    # Reject bools up front unless bool is explicitly allowed; otherwise the isinstance() check
    # below would let True/False through as ints.
    is_disallowed_bool = isinstance(arg_value, bool) and bool not in valid_types
    if not is_disallowed_bool:
        # First, check if arg_value is an instance of valid_types
        if isinstance(arg_value, valid_types):
            return got_type_name
        # Second, wrap list/tuple with a numpy array so that its dtype can be checked
        if isinstance(arg_value, (list, tuple)):
            arg_value = np.array(arg_value)
        # Thirdly, check the data type by numpy's dtype api
        if isinstance(arg_value, np.ndarray) and arg_value.dtype in valid_data_types:
            return type(arg_value).__name__

    type_names = [t.__name__ for t in valid_types]
    if len(valid_types) == 1:
        raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
                        f' but got {got_type_name}.')
    raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
                    f' but got {got_type_name}.')
-
-
def check_typename(arg_name, arg_type, valid_types):
    """Check that `arg_type` is in `valid_types` or is an instance of one of its classes; return it.

    A Tensor type is replaced by its element type first.

    Args:
        arg_name (str): Argument name used in error messages.
        arg_type: Type (or value) to check.
        valid_types (Iterable): Accepted entries; may mix classes with non-class type objects.

    Returns:
        The checked `arg_type` (the element type for tensor types).

    Raises:
        TypeError: If `arg_type` is not accepted.
    """

    def get_typename(t):
        return t.__name__ if hasattr(t, '__name__') else str(t)

    valid_types = tuple(valid_types)
    if isinstance(arg_type, type(mstype.tensor)):
        arg_type = arg_type.element_type()

    if arg_type in valid_types:
        return arg_type
    # isinstance() raises its own, unrelated TypeError when given a non-class entry
    # (presumably mstype dtype objects), so only pass it the entries that are real classes.
    valid_classes = tuple(t for t in valid_types if isinstance(t, type))
    if isinstance(arg_type, valid_classes):
        return arg_type
    type_names = [get_typename(t) for t in valid_types]
    if len(valid_types) == 1:
        raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
                        f' but got {get_typename(arg_type)}.')
    raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
                    f' but got {get_typename(arg_type)}.')
-
-
def check_shape(arg_name, arg_value):
    """Validate a tensor shape given as a tuple of positive integers.

    Args:
        arg_name (str): Argument name used in error messages.
        arg_value: The shape to check.

    Raises:
        TypeError: If `arg_value` is not a tuple.
        ValueError: If the shape is empty, is not one-dimensional, or has a dimension that is
            not a positive integer.
    """
    if not isinstance(arg_value, tuple):
        raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},'
                        f' but got {type(arg_value).__name__}.')

    # Go through numpy so that size, rank and element types can be inspected uniformly.
    shape_arr = np.array(arg_value)

    if shape_arr.size == 0:
        raise ValueError('Shape can not be empty.')
    if shape_arr.ndim != 1:
        raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(shape_arr.ndim))

    int_types = (int, np.int8, np.int16, np.int32, np.int64,
                 np.uint8, np.uint16, np.uint32, np.uint64)
    for dim in shape_arr:
        if isinstance(dim, int_types) and dim > 0:
            continue
        raise ValueError('Every dimension size of the tensor shape should be a positive integer,'
                         ' but got {}.'.format(dim))
-
-
- def _check_str_by_regular(target, reg=None, flag=re.ASCII):
- if reg is None:
- reg = _name_re
- if re.match(reg, target, flag) is None:
- raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
- return True
-
-
def args_type_check(*type_args, **type_kwargs):
    """Decorator factory that checks argument types of the decorated function at call time.

    Expected types are given the way the function's arguments are passed: positionally and/or
    by keyword. A type given for a name that lands in the function's variadic keyword parameter
    (``**kwargs``) is matched against the keyword argument of that name. Arguments whose value
    is None are not checked.

    Raises:
        TypeError: At call time, if a checked argument is not an instance of its expected type.
    """

    def type_check(func):
        sig = inspect.signature(func)
        # Name of the function's variadic keyword parameter (e.g. ``kwargs``), if any.
        var_kw_name = next((param.name for param in sig.parameters.values()
                            if param.kind is inspect.Parameter.VAR_KEYWORD), None)

        def _flatten(arguments):
            # Lift the entries collected by the variadic keyword parameter to the top level, so
            # named parameters and free-form keyword arguments are checked together.
            flat = dict(arguments)
            if var_kw_name is not None and var_kw_name in flat:
                flat.update(flat.pop(var_kw_name))
            return flat

        # Computed once at decoration time and never mutated afterwards.
        bound_types = _flatten(sig.bind_partial(*type_args, **type_kwargs).arguments)

        @wraps(func)
        def wrapper(*args, **kwargs):
            argument_dict = _flatten(sig.bind(*args, **kwargs).arguments)
            for name, value in argument_dict.items():
                if name in bound_types and value is not None and not isinstance(value, bound_types[name]):
                    raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
            return func(*args, **kwargs)

        return wrapper

    return type_check
|