- #!/usr/bin/env python3
- # coding: utf-8
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """validation check functions"""
- from functools import wraps, reduce
- from enum import Enum
- import akg.tvm
- import akg.topi
- from akg.utils.format_transform import get_bytes, get_shape
-
- MAX_DATA_SIZE = 2 ** 31
-
-
- class DtypeForDavinci(Enum):
- """Davinci supported dtype."""
- ALL_TYPES = ["float16", "float32", "int32", "int8", "uint8"]
- ALL_FLOAT = ["float16", "float32"]
- ALL_INT = ["int8", "int32"]
- ALL_UINT = ["uint8"]
- FLOAT16 = ["float16"]
- FLOAT32 = ["float32"]
- INT8 = ["int8"]
- INT16 = ["int16"]
- INT32 = ["int32"]
- INT64 = ["int64"]
- UINT8 = ["uint8"]
- UINT16 = ["uint16"]
- UINT32 = ["uint32"]
- UINT64 = ["uint64"]
- BOOL = ["bool"]
-
-
- # map a dtype bit-width suffix (or "bool") to its size in bytes
- supported_bits = {
-     "8": 1, "16": 2, "32": 4, "64": 8, "bool": 1
- }
-
-
- def check_input_type_dict(input_dict, input_key, input_name):
- """
-     check input parameter type for the dict type.
-
-     Note:
-         rule1: every key of input_dict should be in input_key
-         rule2: input_dict["shape"] should be a list or tuple, if present
-         rule3: input_dict["dtype"] should be a str, if present
-
-     Args:
-         input_dict (dict): the dict to check
-         input_key (list or tuple): all allowed keys; every key of input_dict must be in input_key
-         input_name (str): input param name, only used in error messages
-
- Returns:
- None
- """
- def _check_input_type(input_key, input_type):
- if not isinstance(input_dict[input_key], input_type):
- raise RuntimeError(
- "the input parameter %s[%s] must be %s, while type of input is %s" %
- (input_name, input_key, input_type, type(input_dict[input_key])))
-
- for key in input_dict.keys():
- if key not in input_key:
-             raise RuntimeError(
-                 "the input parameter %s contains an unsupported key <%s>" %
-                 (input_name, key))
-
-         # check shape's type of input_dict, if present
- if key == "shape":
- _check_input_type(key, (list, tuple))
-
-         # check dtype's type of input_dict, if present
- if key == "dtype":
- _check_input_type(key, (str,))
-
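- # A minimal usage sketch (hypothetical inputs, kept as comments so the module
- # stays side-effect free on import):
- #   check_input_type_dict({"shape": (16, 16), "dtype": "float16"},
- #                         ("shape", "dtype"), "x")         # passes
- #   check_input_type_dict({"shape": 16}, ("shape",), "x")  # raises: shape not list/tuple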
-
- def check_input_type_list_tuple(inputs, expect):
- """check inputs by a list or tuple of expected types."""
-     if not isinstance(inputs, expect[1][0]):
-         raise RuntimeError("the input parameter %s must be %s, while"
-                            " type of input is %s" % (expect[0], expect[1][0], type(inputs)))
- for inp in inputs:
- if not isinstance(inp, expect[1][1]):
- raise RuntimeError("The element in parameter %s must be %s, while "
- "type of input is %s" % (
- expect[0], expect[1][1], type(inp)))
-
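- # Usage sketch with a hypothetical expect tuple; check_input_type below builds
- # these as (param_name, (container_types, element_type)):
- #   expect = ("axes", ((list, tuple), int))
- #   check_input_type_list_tuple([0, 1], expect)    # passes
- #   check_input_type_list_tuple([0, "1"], expect)  # raises: element not int
-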
- def check_input_type(*type_args, **_type_kwargs):
- """check input parameter type."""
- def out_wrapper(func):
- """outer wrapper function."""
-         # co_varnames lists the positional parameters first, so zip pairs each
-         # parameter name with its expected type and drops any trailing locals
-         formal_parameter = func.__code__.co_varnames
-         formal_parameter_list = list(zip(formal_parameter, type_args))
-
- @wraps(func)
- def in_wrapper(*args, **kwargs):
- """inner wrapper function."""
- for i, arg_v in enumerate(args):
-                 # for dict inputs, additionally validate the shape and dtype entries
-                 if isinstance(arg_v, dict):
-                     check_input_type_dict(arg_v, arg_v.keys(),
-                                           formal_parameter_list[i][0])
-
- if isinstance(formal_parameter_list[i][1], tuple):
- if isinstance(formal_parameter_list[i][1][0], tuple) \
- and len(formal_parameter_list[i][1]) == 2:
- check_input_type_list_tuple(arg_v, formal_parameter_list[i])
- continue
-
- if not isinstance(arg_v, formal_parameter_list[i][1]):
- raise RuntimeError("the %sth input parameter %s must be %s, "
- "while type of input is %s" % (str(i), formal_parameter_list[i][0],
- formal_parameter_list[i][1],
- type(arg_v)))
- for i in kwargs:
- for j in formal_parameter_list:
- if i in j:
- if not isinstance(kwargs[i], j[1]):
- raise RuntimeError("the input parameter %s must be "
- "%s, while type of input is %s"
- "" % (i, j[1], type(kwargs[i])))
- break
- return func(*args, **kwargs)
-
- return in_wrapper
-
- return out_wrapper
-
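- # Decorator usage sketch (hypothetical op signature):
- #   @check_input_type((list, tuple), str)
- #   def my_op(shape, dtype):
- #       ...
- #   my_op([16, 16], "float16")  # passes
- #   my_op("16x16", "float16")   # raises: shape is not a list/tuple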
-
- def shape_dtype_max_size_check(shape, dtype):
- """check validation of tensor's shape."""
- if shape:
-         # skip the size check when the shape contains symbolic (non-int) dims
-         for x in shape:
-             if not isinstance(x, int):
-                 return
- mul = get_bytes(dtype) * int(reduce(lambda x, y: int(x) * int(y), shape))
- if mul > MAX_DATA_SIZE:
- error_msg = "*".join([str(sh) for sh in shape])
- raise RuntimeError("Invalid shape, data is {} bytes ({}), which "
- "exceed max data size {} bytes"
- .format(mul, error_msg, MAX_DATA_SIZE))
-
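- # Usage sketch: float32 takes 4 bytes, so 1024**3 elements occupy 2**32 bytes,
- # which exceeds MAX_DATA_SIZE (2**31 bytes):
- #   shape_dtype_max_size_check((1024, 1024), "float32")        # passes, 4 MiB
- #   shape_dtype_max_size_check((1024, 1024, 1024), "float32")  # raises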
-
- def tensor_max_size_check(tensor):
- """check validation of tensor's shape."""
- if not isinstance(tensor, akg.tvm.tensor.Tensor):
- raise RuntimeError("tensor should be an akg.tvm.tensor.Tensor, but got "
- "type {}".format(type(tensor)))
- shape = get_shape(tensor)
- dtype = tensor.dtype
- shape_dtype_max_size_check(shape, dtype)
-
-
- def check_shape(tensor, length=None, tensor_name=""):
- """The common check rule for placeholder data."""
- shape = get_shape(tensor)
- if not shape:
- raise RuntimeError("The ndim of input tensor {} must more than 0, "
- "actual input is {}".format(tensor_name, len(shape)))
-
- for shape_v in shape:
- if isinstance(shape_v, (akg.tvm.expr.Var, akg.tvm.expr.Mul, akg.tvm.expr.FloorDiv, akg.tvm.expr.IntImm)):
- continue
- if not isinstance(shape_v, int) or shape_v <= 0:
- raise RuntimeError("The type of tensor {} axis value must be "
- "positive int and value more than 0,"
- "actual input is ({}) {}".
- format(tensor_name, type(shape_v), shape_v))
-
- if length and len(shape) != length:
- raise ValueError('The length of {} should be {}, while actual length is {}'.
- format(tensor_name, length, len(shape)))
-
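- # Usage sketch (plain shapes are accepted, as elemwise_shape_check below relies on):
- #   check_shape((16, 16), length=2, tensor_name="data")  # passes
- #   check_shape((16, 0), tensor_name="data")             # raises: non-positive dim
- #   check_shape((16,), length=2, tensor_name="data")     # raises: wrong length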
-
- def ops_dtype_check(dtype, args):
- """check validation of op's dtype."""
- expected_dtype = list()
-
- def _get_expect_dtype(expected_dtype, arg):
- if isinstance(arg, str):
- expected_dtype.append(arg)
- elif isinstance(arg, DtypeForDavinci):
- expected_dtype += arg.value
- elif isinstance(arg, (list, tuple)):
- for t in arg:
- _get_expect_dtype(expected_dtype, t)
- else:
- raise TypeError("arg should be either a string, a DtypeForDavinci "
- "or a list/tuple of string or DtypeForDavinci, "
- "while current is {}".format(type(arg)))
-
- _get_expect_dtype(expected_dtype, args)
-
- if isinstance(dtype, (list, tuple)):
- checking_dtype = [d.lower() for d in dtype]
- elif isinstance(dtype, str):
- checking_dtype = [dtype.lower()]
- else:
- raise TypeError("dtype should be either a string or a tuple/list of string")
- error_msg = "Supported dtype: {}, while received dtype: {}"
- if not set(checking_dtype).issubset(set(expected_dtype)):
- raise RuntimeError(error_msg.format(expected_dtype, checking_dtype))
-
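- # Usage sketch:
- #   ops_dtype_check("float16", DtypeForDavinci.ALL_FLOAT)        # passes
- #   ops_dtype_check(["int8", "int32"], DtypeForDavinci.ALL_INT)  # passes
- #   ops_dtype_check("int64", DtypeForDavinci.ALL_FLOAT)          # raises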
-
- def reduce_axis_check(reduce_shape, reduce_axis):
- """check validation of reduce axis for certain reduce shape."""
- dim = len(reduce_shape)
- if dim == 1 and isinstance(reduce_shape[0], int) and int(reduce_shape[0]) == 1:
- raise RuntimeError("Error, reduce shape is 1. Scalar is not supported "
- "for reduction, please input a vector.")
- if isinstance(reduce_axis, int):
- if reduce_axis not in range(-dim, dim):
- raise RuntimeError("Reduce axis should be in range [%d. %d)"
- "" % (-dim, dim))
- elif isinstance(reduce_axis, (tuple, list)):
- if len(reduce_axis) > len(reduce_shape):
- raise RuntimeError("Reduce axis list exceed reduce shape length: "
- "%d vs %d, error" % (len(reduce_axis), len(reduce_shape)))
- processed_axis = []
- for axis in reduce_axis:
- processed_axis.append(int(axis + dim) if axis < 0 else int(axis))
- if len(set(processed_axis)) < len(processed_axis):
- raise RuntimeError("Reduce axis list contains %d duplicated element, please check"
- % (len(processed_axis) - len(set(processed_axis))))
- for axis in processed_axis:
- if axis >= dim:
- raise RuntimeError("Invalid reduce axis, axis should less than %d" % dim)
- elif reduce_axis is not None:
- raise RuntimeError("axis should be a list, tuple or int.")
-
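- # Usage sketch for a rank-2 reduce shape (valid axes lie in [-2, 2)):
- #   reduce_axis_check((32, 16), 1)        # passes
- #   reduce_axis_check((32, 16), [0, -1])  # passes: -1 normalizes to 1
- #   reduce_axis_check((32, 16), 2)        # raises: out of range
-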
- def elemwise_shape_check(shape_a, shape_b):
- """check validation of tensor's shape for element-wise op."""
- check_shape(shape_a)
- check_shape(shape_b)
- if len(shape_a) != len(shape_b):
- raise RuntimeError("Element-wise operation needs same data length, "
- "while current is %s vs %s" % (len(shape_a), len(shape_b)))
- for i, shp in enumerate(shape_a):
- if int(shp) != int(shape_b[i]):
- raise RuntimeError("Element-wise operation needs same data shape, "
- "while current is %s vs %s" % (shp, shape_b[i]))
-
-
- def elemwise_dtype_check(dtype_a, dtype_b, supported_type=None):
- """check validation of tensor's dtype for element-wise op."""
- if supported_type:
- ops_dtype_check(dtype_a, supported_type)
- ops_dtype_check(dtype_b, supported_type)
- if dtype_a.lower() != dtype_b.lower():
- raise RuntimeError("Element-wise operation needs same data type, while "
- "current is %s vs %s" % (dtype_a.lower(), dtype_b.lower()))
-
-
- def auto_broadcast_check(shape_a, shape_b):
- """automatic broadcast check."""
- shape_l = get_shape(shape_a)
- shape_r = get_shape(shape_b)
-
- if len(shape_l) <= len(shape_r):
- shape_short = shape_l
- shape_long = shape_r
- else:
- shape_short = shape_r
- shape_long = shape_l
-
- dim_diff = len(shape_long) - len(shape_short)
-     for _ in range(dim_diff):
-         shape_short.insert(0, 1)
- for i, shp in enumerate(shape_short):
- if int(shp) != int(shape_long[i]) and 1 not in [int(shp), int(shape_long[i])]:
- raise RuntimeError("Invalid auto broadcast, dim %d should be 1 or equal, "
- "while now is %d vs %d" % (i, shp, shape_long[i]))
-
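- # Usage sketch: dims must be equal or one of them must be 1 (numpy-style):
- #   auto_broadcast_check((16, 1, 8), (16, 4, 8))  # passes
- #   auto_broadcast_check((3,), (2, 4))            # raises: 3 vs 4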
-
- def broadcast_check(ori_shape, dst_shape):
- """check valid broadcast from ori_shape to dst_shape."""
- shape_l = get_shape(ori_shape)
- shape_r = get_shape(dst_shape)
-
- if len(shape_l) <= len(shape_r):
- dim_diff = len(shape_r) - len(shape_l)
- shape_l = ([1] * dim_diff) + shape_l
- else:
- raise RuntimeError("Cannot broadcast from shape %s to %s" % (str(ori_shape), str(dst_shape)))
-
- for i, shp in enumerate(shape_l):
- if int(shp) != int(shape_r[i]) and int(shp) != 1:
- raise RuntimeError("Cannot broadcast from shape %s to %s" % (str(ori_shape), str(dst_shape)))
-
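- # Usage sketch: unlike auto_broadcast_check, only ori_shape may carry 1s and
- # it must not be longer than dst_shape:
- #   broadcast_check((1, 4), (3, 3, 4))  # passes
- #   broadcast_check((3, 4), (1, 4))     # raises: 3 cannot broadcast to 1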
-
- def gemm_format_check(lhs_input, rhs_input, lhs_trans=False, rhs_trans=False):
- """check gemm format (shape length and value)."""
- dim = len(lhs_input)
- if len(rhs_input) != dim:
- raise RuntimeError("Dimensions are different, lhs input is of %d dimension "
- "while rhs input is of %d dimension, " % (dim, len(rhs_input)))
-
- b_pos = [0] if dim == 3 else [0, 1]
- lhs_k_pos = -2 if lhs_trans else -1
- rhs_k_pos = -1 if rhs_trans else -2
-
- def length_check(tensor):
- if len(tensor) < 2 or len(tensor) > 4:
- raise RuntimeError("Gemm only support 2d shape (height, weight) "
- "or 3d shape (batch, height, weight) "
- "or 4d shape (batch_o, batch_i, height, weight) "
- " while shape length is %d!" % (len(tensor)))
-
- def value_check(loc):
- if loc == "B":
- if len(lhs_input) > 2:
- for pos in b_pos:
- value = int(lhs_input[pos])
- cmp_value = int(rhs_input[pos])
- if value != cmp_value:
- raise RuntimeError("%s size is not compatible, lhs "
- "input: %d and rhs input: %d" %
- (loc, value, cmp_value))
- if loc == "K":
- if isinstance(lhs_input[lhs_k_pos], akg.tvm.expr.Var) or isinstance(rhs_input[rhs_k_pos], akg.tvm.expr.Var):
- return
- value = int(lhs_input[lhs_k_pos])
- cmp_value = int(rhs_input[rhs_k_pos])
- if cmp_value != value:
- raise RuntimeError("%s size is not compatible, lhs :%d, "
- "rhs input: %d " % (loc, value, cmp_value))
-
- for data in [lhs_input, rhs_input]:
- length_check(data)
- for location in ["B", "K"]:
- value_check(location)
-
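- # Usage sketch for 3d (batch, height, width) inputs, K being the contraction dim:
- #   gemm_format_check((8, 64, 32), (8, 32, 48))                  # passes, K = 32
- #   gemm_format_check((8, 64, 32), (8, 48, 32), rhs_trans=True)  # passes, K = 32
- #   gemm_format_check((8, 64, 32), (8, 48, 16))                  # raises: K mismatch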
-
- def convolution_format_check(x_shape, w_shape, pad, stride, dilation):
- """check convolution format."""
- def conv_shape_check(shape):
- if (not isinstance(shape, (tuple, list))) or (len(shape) != 4):
- raise RuntimeError("conv tensor shape should be 4d list or tuple")
-
- conv_dtype = "float16"
- size = get_bytes(conv_dtype)
- for i in shape:
- if (not isinstance(i, int)) or (i <= 0):
- raise RuntimeError("conv tensor shape should be 4d list or "
- "tuple of positive integer")
- size *= i
-
- if size > MAX_DATA_SIZE:
- raise RuntimeError("runtime can not support tensor more than 2G size")
-
- def conv_pad_check(pad):
- if (not isinstance(pad, (tuple, list))) or (len(pad) != 4):
- raise RuntimeError("conv pad should be 4d list or tuple")
-
- for i in pad:
- if (not isinstance(i, int)) or (i < 0):
- raise RuntimeError("conv pad should be 4d list or tuple of "
- "positive integer or zero")
-
- def conv_stride_check(stride):
- if (not isinstance(stride, (tuple, list))) or (len(stride) != 2):
- raise RuntimeError("conv stride should be 2d list or tuple")
-
- for i in stride:
- if (not isinstance(i, int)) or (i <= 0):
- raise RuntimeError("conv stride should be 2d list or tuple of positive integer")
-
- def conv_dilation_check(dilation):
- if (not isinstance(dilation, (tuple, list))) or (len(dilation) != 2):
- raise RuntimeError("conv dilation should be 2d list or tuple")
-
- for i in dilation:
- if (not isinstance(i, int)) or (i <= 0):
- raise RuntimeError("conv dilation should be 2d list or tuple of positive integer")
-
- conv_shape_check(x_shape)
- conv_shape_check(w_shape)
- conv_pad_check(pad)
- conv_stride_check(stride)
- conv_dilation_check(dilation)
-
- if x_shape[1] != w_shape[1]:
- raise RuntimeError("conv's feature_map and filter tensor should be the same channel")
-
- if x_shape[2] + pad[0] + pad[1] < w_shape[2]:
- raise RuntimeError("kernel_h should be <= h + pad_left + pad_right: %d"
- "" % (x_shape[2] + pad[0] + pad[1]))
-
- if x_shape[3] + pad[2] + pad[3] < w_shape[3]:
- raise RuntimeError("kernel_w should be <= w + pad_top + pad_bottom: %d"
- "" % (x_shape[3] + pad[2] + pad[3]))
-
- if (pad[0] >= w_shape[2]) or (pad[1] >= w_shape[2]) \
- or (pad[2] >= w_shape[3]) or (pad[3] >= w_shape[3]):
- raise RuntimeError("pad value cannot be more than the filter value")
-
-
- def davinci_format_check(shape, tensor_format, dim=-1):
- """check validation of tensor's shape for certain format used in davinci chip."""
- all_format_shape = {"NCHW": 4,
- "NHWC": 4,
- "NC1HWC0": 5,
- "DefaultFormat": [2, 4]}
- if dim not in [-1, 2, 4, 5]:
- raise RuntimeError("Only support 2d, 4d, 5d format check, please set "
- "dim to the dim want to check "
- "or use default value -1 to check both all the dim")
- if dim == -1:
- support_format_shape = all_format_shape
- else:
- support_format_shape = {}
- for k, v in all_format_shape.items():
- if isinstance(v, int) and v == dim:
- support_format_shape[k] = v
- if isinstance(v, list) and dim in v:
- support_format_shape[k] = v
-
- support_shape = {"NC1HWC0": (4, 16)}
- if not isinstance(tensor_format, str):
- raise RuntimeError("Invalid davinci format, should be a string, "
- "but get %s" % (type(tensor_format)))
- if tensor_format not in support_format_shape.keys():
- raise RuntimeError("Invalid davinci format {}, davinci support {}"
- .format(tensor_format, support_format_shape.keys()))
- if isinstance(support_format_shape[tensor_format], int):
- if len(shape) != support_format_shape[tensor_format]:
- raise RuntimeError("Invalid shape {} for davinci format {}, needs "
- "{} dim shape, current length{}"
- .format(shape, tensor_format,
- support_format_shape[tensor_format], len(shape)))
- if isinstance(support_format_shape[tensor_format], list):
- if len(shape) not in support_format_shape[tensor_format]:
- raise RuntimeError("Invalid shape {} for davinci format {}, needs {} dim shape"
- .format(shape, tensor_format,
- support_format_shape[tensor_format]))
- if tensor_format in support_shape.keys():
- check_dim = support_shape[tensor_format][0]
- expect_shape = support_shape[tensor_format][1]
- if int(shape[check_dim]) != expect_shape:
- raise RuntimeError("Invalid shape {} for davinci format {}, dim {} "
- "should be {}, while current is {}"
- .format(shape, tensor_format, check_dim,
- expect_shape, shape[check_dim]))
-
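- # Usage sketch: NC1HWC0 needs 5 dims with the last one (C0) fixed to 16:
- #   davinci_format_check((32, 4, 14, 14, 16), "NC1HWC0")  # passes
- #   davinci_format_check((32, 4, 14, 14, 8), "NC1HWC0")   # raises: C0 != 16
- #   davinci_format_check((32, 16, 14, 14), "NCHW")        # passes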
-
- def is_valid_reduce_axis(tensor, reduce_axis):
- """
-     If shape[axis] is 1 for any reduce axis, the shape cannot be refined, or the refined reduce axis would be wrong.
-
- Args:
- tensor (tvm.tensor.Tensor): input tensor.
- reduce_axis (Union[list, tuple, int]): axis want to reduce.
-
- Returns:
- True or False.
- """
- shape = get_shape(tensor)
-     # list and tuple inputs expose an "index" attribute, a bare int does not
-     if hasattr(reduce_axis, 'index'):
- for id_ite in reduce_axis:
- if shape[id_ite] == 1:
- return False
- else:
- if shape[reduce_axis] == 1:
- return False
-
- return True
-
-
- def axis_check(shape_len, axis):
- """Check the value of axis and return the sorted axis."""
- def _axis_value_type_check(value):
- if not isinstance(value, int):
- raise RuntimeError("type of axis value should be int")
- if value >= shape_len or value < -shape_len:
- raise RuntimeError(
- "input axis is out of range, axis value can be from %d to %d" %
- (-shape_len, shape_len - 1))
- if value < 0:
- value = shape_len + value
- return value
-     if not hasattr(axis, 'index'):
-         return _axis_value_type_check(axis)
-     # build a new list instead of assigning in place, so tuple inputs work too
-     axis = [_axis_value_type_check(axs) for axs in axis]
-     return sorted(set(axis))
-
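- # Usage sketch for a rank-4 shape:
- #   axis_check(4, -1)          # returns 3
- #   axis_check(4, [2, -1, 2])  # returns [2, 3]: normalized, deduplicated, sorted
- #   axis_check(4, 4)           # raises: out of range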
-
- def check_value_on_integer(arg_name, arg_value, low=None, high=None):
- """Judging integer type."""
- type_match = isinstance(arg_value, int) and not isinstance(arg_value, bool)
-     if not type_match:
-         raise ValueError("%s should be an int, but got type %s"
-                          "" % (arg_name, type(arg_value)))
-     # compare against None so that a bound of 0 is not silently ignored
-     if low is not None and arg_value < low:
-         raise ValueError("%s should be greater than or equal to %f, but got %f"
-                          "" % (arg_name, low, arg_value))
-     if high is not None and arg_value >= high:
-         raise ValueError("%s should be less than %f, but got %f"
-                          "" % (arg_name, high, arg_value))
-
-
- def check_typename(arg_name, arg_type, valid_types):
- """Does it contain the _name_ attribute."""
- def get_typename(t):
- return t.__name__ if hasattr(t, '__name__') else str(t)
-
- if arg_type in valid_types:
- return arg_type
- type_names = [get_typename(t) for t in valid_types]
- if len(valid_types) == 1:
- raise ValueError('type of {} should be {}, but got {}'.format(
- arg_name, type_names[0], get_typename(arg_type)))
- raise ValueError('type of {} should be one of {}, but got {}'.format(
- arg_name, type_names, get_typename(arg_type)))
-
-
- def check_equal(arg_name1, arg_name2, arg1, arg2):
- """Check equal."""
- if arg1 != arg2:
- raise ValueError('{} should be equal to {}'.format(arg_name1, arg_name2))
-
-
- def check_greater(arg_name1, arg_name2, arg1, arg2):
- """Check greater."""
- if arg1 <= arg2:
- raise ValueError('{} should be greater than {}'.format(arg_name1, arg_name2))
-
-
- def check_5d(arg_name, shape5d, shape4d):
- """Check 5D shape."""
- blocksize = 16
- if len(shape4d) != 4:
- raise ValueError('invalid 4D shape of {}'.format(arg_name))
- if len(shape5d) != 5:
- raise ValueError('invalid 5D shape of {}'.format(arg_name))
- d1, d2, d3, d4 = shape4d
- if [x.value for x in shape5d] != [d1, (d2 + blocksize - 1) // blocksize, d3, d4, blocksize]:
- raise ValueError('the 4D shape and 5D shape of {} do not match'.format(arg_name))
-
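- # The expected layout is NC1HWC0 with C1 = ceil(C / blocksize) and C0 = blocksize,
- # e.g. a 4D (32, 17, 14, 14) tensor must come in as 5D (32, 2, 14, 14, 16).
- # Note shape5d is assumed to hold tvm IntImm dims (hence the .value access above).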
-
- def check_shape_length_equal(tensor_name, tensor_shape, shape_length):
- """Shape length equal judgment."""
- if isinstance(shape_length, (tuple, list)):
-         if len(tensor_shape) not in shape_length:
-             raise ValueError("The shape length of {tensor_name} should be one of "
-                              "{shape_length}, but got {tensor_shape_len}"
- "".format(
- tensor_name=tensor_name,
- shape_length=shape_length,
- tensor_shape_len=len(tensor_shape)))
- elif len(tensor_shape) != shape_length:
- raise ValueError("The shape length of {tensor_name} should be "
- "{shape_length}, but get {tensor_shape_len}"
- "".format(
- tensor_name=tensor_name,
- shape_length=shape_length,
- tensor_shape_len=len(tensor_shape)))
-
-
- def check_shape_length_greater(tensor_name, tensor_shape, shape_length):
- """Shape length greater judgment."""
- if len(tensor_shape) <= shape_length:
- raise ValueError("The shape length of {tensor_name} should be greater "
- "than {shape_length}, but get {tensor_shape_len}".format(
- tensor_name=tensor_name, shape_length=shape_length,
- tensor_shape_len=len(tensor_shape)))
-
-
- def judge_var(num):
- """judge var if a tvm.var, tvm.const or python data type."""
- var_dict = {
- "python_const": [int, float],
- "tvm_const": [
- akg.tvm.expr.IntImm, akg.tvm.expr.UIntImm, akg.tvm.expr.FloatImm],
- "tvm_var": [akg.tvm.expr.Var]}
- num_type = type(num)
- for i in var_dict:
- if num_type in var_dict[i]:
- return i
- raise RuntimeError("Input var dtype {} error".format(type(num)))
-
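- # Usage sketch (assuming the usual tvm const/var constructors):
- #   judge_var(2)                            # "python_const"
- #   judge_var(akg.tvm.const(2, "float16"))  # "tvm_const" (FloatImm)
- #   judge_var(akg.tvm.var("n"))             # "tvm_var"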
-
- def check_pad(arg_name, pad, length=None):
- """Check pad."""
- if not pad:
- raise ValueError("{} should not be None".format(arg_name))
- if not isinstance(pad, (tuple, list)):
- raise ValueError("{} should be tuple or list".format(arg_name))
- for i in pad:
- if not isinstance(i, int):
- raise ValueError("Elements in {} should be int".format(arg_name))
- if i < 0:
- raise ValueError("Elements in {} should not be less than zero"
- "".format(arg_name))
- if length:
- if length != len(pad):
- raise ValueError("The length of {} should be {}".format(
- arg_name, length))
-
-
- def check_int_list(array, array_name):
- """check whether all the elements are integers."""
- for num in array:
- if not isinstance(num, int):
- raise RuntimeError("Type of value in %s should be int, but got type %s" % (array_name, type(num)))