|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """validation check functions"""
- from functools import wraps, reduce
- from _akg.utils.format_transform import get_shape
-
- MAX_DATA_SIZE = 2 ** 31
-
- def check_input_type_dict(input_dict, input_key, input_name):
- """
- check input parameter type for new type: dict.
-
- Note:
- rule1: key of input_dict should be in the input_key
- rule2: type of input_dict[shape] should be in (list, tuple), if have shape
- rule3: type of input_dict[dtype] should be in (str), if have dtype
-
- Args:
- input_dict (dict): input_dict
- input_key (list or tuple): all input key list, the key of input must in input_key
- input_name (str): input param name, only used for error print
-
- Returns:
- None
- """
- def _check_input_type(input_key, input_type):
- if not isinstance(input_dict[input_key], input_type):
- raise RuntimeError(
- "the input parameter %s[%s] must be %s, while type of input is %s" %
- (input_name, input_key, input_type, type(input_dict[input_key])))
-
- for key in input_dict.keys():
- if key not in input_key:
- raise RuntimeError(
- "the input parameter %s must have arrt <%s>" %
- (input_name, key))
-
- # check shape's type of input_dict, if have shape
- if key == "shape":
- _check_input_type(key, (list, tuple))
-
- # check dtype's type of input_dict, if have dtype
- if key == "dtype":
- _check_input_type(key, (str,))
-
-
- def check_input_type_list_tuple(inputs, expect):
- """check inputs by a list or tuple of expected types."""
- if not isinstance(inputs, expect[1][0]):
- raise RuntimeError("the input parameter %s must be (list, tuple), while"
- " type of input is %s" % (expect[0], type(inputs)))
- for inp in inputs:
- if not isinstance(inp, expect[1][1]):
- raise RuntimeError("The element in parameter %s must be %s, while "
- "type of input is %s" % (
- expect[0], expect[1][1], type(inp)))
-
-
- def check_input_type(*type_args, **_type_kwargs):
- """check input parameter type."""
- def out_wrapper(func):
- """outer wrapper function."""
- formal_parameter = func.__code__.co_varnames
- formal_parameter_list = list(zip(formal_parameter, type_args))
-
- @wraps(func)
- def in_wrapper(*args, **kwargs):
- """inner wrapper function."""
- for i, arg_v in enumerate(args):
- # add for new input dict, if dict, will check shape and dtype
- if isinstance(arg_v, dict):
- check_input_type_dict(arg_v, arg_v.keys(),
- formal_parameter_list[i][0])
-
- if isinstance(formal_parameter_list[i][1], tuple):
- if isinstance(formal_parameter_list[i][1][0], tuple) \
- and len(formal_parameter_list[i][1]) == 2:
- check_input_type_list_tuple(arg_v, formal_parameter_list[i])
- continue
-
- if not isinstance(arg_v, formal_parameter_list[i][1]):
- raise RuntimeError("the %sth input parameter %s must be %s, "
- "while type of input is %s" % (str(i), formal_parameter_list[i][0],
- formal_parameter_list[i][1],
- type(arg_v)))
- for i in kwargs:
- for j in formal_parameter_list:
- if i in j:
- if not isinstance(kwargs[i], j[1]):
- raise RuntimeError("the input parameter %s must be "
- "%s, while type of input is %s"
- "" % (i, j[1], type(kwargs[i])))
- break
- return func(*args, **kwargs)
-
- return in_wrapper
-
- return out_wrapper
-
-
- def shape_dtype_max_size_check(shape):
- """check validation of tensor's shape."""
- if shape:
- mul = int(reduce(lambda x, y: int(x) * int(y), shape))
- if mul > MAX_DATA_SIZE:
- error_msg = "*".join([str(sh) for sh in shape])
- raise RuntimeError("Invalid shape, data is {} bytes ({}), which "
- "exceed max data size {} bytes"
- .format(mul, error_msg, MAX_DATA_SIZE))
-
-
- def check_shape(tensor, length=None, tensor_name=""):
- """The common check rule for placeholder data."""
- shape = get_shape(tensor)
- if not shape:
- raise RuntimeError("The ndim of input tensor {} must more than 0, "
- "actual input is {}".format(tensor_name, len(shape)))
-
- for shape_v in shape:
- if not isinstance(shape_v, int) or shape_v <= 0:
- raise RuntimeError("The type of tensor {} axis value must be "
- "positive int and value more than 0,"
- "actual input is ({}) {}".
- format(tensor_name, type(shape_v), shape_v))
-
- if length and len(shape) != length:
- raise ValueError('The length of {} should be {}, while actual length is {}'.
- format(tensor_name, length, len(shape)))
-
-
- def ops_dtype_check(dtype, args):
- """check validation of op's dtype."""
- expected_dtype = list()
-
- def _get_expect_dtype(expected_dtype, arg):
- if isinstance(arg, str):
- expected_dtype.append(arg)
- elif isinstance(arg, (list, tuple)):
- for t in arg:
- _get_expect_dtype(expected_dtype, t)
- else:
- raise TypeError("arg should be either a string, "
- "or a list/tuple of string, "
- "while current is {}".format(type(arg)))
-
- _get_expect_dtype(expected_dtype, args)
-
- if isinstance(dtype, (list, tuple)):
- checking_dtype = [d.lower() for d in dtype]
- elif isinstance(dtype, str):
- checking_dtype = [dtype.lower()]
- else:
- raise TypeError("dtype should be either a string or a tuple/list of string")
- error_msg = "Supported dtype: {}, while received dtype: {}"
- if not set(checking_dtype).issubset(set(expected_dtype)):
- raise RuntimeError(error_msg.format(expected_dtype, checking_dtype))
-
-
- def reduce_axis_check(reduce_shape, reduce_axis):
- """check validation of reduce axis for certain reduce shape."""
- dim = len(reduce_shape)
- if dim == 1 and int(reduce_shape[0]) == 1:
- raise RuntimeError("Error, reduce shape is 1. Scalar is not supported "
- "for reduction, please input a vector.")
- if isinstance(reduce_axis, int):
- if reduce_axis not in range(-dim, dim):
- raise RuntimeError("Reduce axis should be in range [%d. %d)"
- "" % (-dim, dim))
- elif isinstance(reduce_axis, (tuple, list)):
- if len(reduce_axis) > len(reduce_shape):
- raise RuntimeError("Reduce axis list exceed reduce shape length: "
- "%d vs %d, error" % (len(reduce_axis), len(reduce_shape)))
- processed_axis = []
- for axis in reduce_axis:
- processed_axis.append(int(axis + dim) if axis < 0 else int(axis))
- if len(set(processed_axis)) < len(processed_axis):
- raise RuntimeError("Reduce axis list contains %d duplicated element, please check"
- % (len(processed_axis) - len(set(processed_axis))))
- for axis in processed_axis:
- if axis >= dim:
- raise RuntimeError("Invalid reduce axis, axis should less than %d" % dim)
- elif reduce_axis is not None:
- raise RuntimeError("axis should be a list, tuple or int.")
-
-
- def elemwise_dtype_check(dtype_a, dtype_b, supported_type=None):
- """check validation of tensor's dtype for element-wise op."""
- if supported_type:
- ops_dtype_check(dtype_a, supported_type)
- ops_dtype_check(dtype_b, supported_type)
- if dtype_a.lower() != dtype_b.lower():
- raise RuntimeError("Element-wise operation needs same data type, while "
- "current is %s vs %s" % (dtype_a.lower(), dtype_b.lower()))
-
-
- def auto_broadcast_check(shape_a, shape_b):
- """automatic broadcast check."""
- shape_l = get_shape(shape_a)
- shape_r = get_shape(shape_b)
-
- if len(shape_l) <= len(shape_r):
- shape_short = shape_l
- shape_long = shape_r
- else:
- shape_short = shape_r
- shape_long = shape_l
-
- dim_diff = len(shape_long) - len(shape_short)
- for i in range(dim_diff):
- shape_short.insert(0, 1)
- for i, shp in enumerate(shape_short):
- if int(shp) != int(shape_long[i]) and 1 not in [int(shp), int(shape_long[i])]:
- raise RuntimeError("Invalid auto broadcast, dim %d should be 1 or equal, "
- "while now is %d vs %d" % (i, shp, shape_long[i]))
-
-
- def check_int_list(array, array_name):
- """check whether all the elements are integers."""
- for num in array:
- if not isinstance(num, int):
- raise RuntimeError("Type of value in %s should be int, but got type %s" % (array_name, type(num)))
|