You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validation_check.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """validation check functions"""
  15. from functools import wraps, reduce
  16. from _akg.utils.format_transform import get_shape
  17. MAX_DATA_SIZE = 2 ** 31
  18. def check_input_type_dict(input_dict, input_key, input_name):
  19. """
  20. check input parameter type for new type: dict.
  21. Note:
  22. rule1: key of input_dict should be in the input_key
  23. rule2: type of input_dict[shape] should be in (list, tuple), if have shape
  24. rule3: type of input_dict[dtype] should be in (str), if have dtype
  25. Args:
  26. input_dict (dict): input_dict
  27. input_key (list or tuple): all input key list, the key of input must in input_key
  28. input_name (str): input param name, only used for error print
  29. Returns:
  30. None
  31. """
  32. def _check_input_type(input_key, input_type):
  33. if not isinstance(input_dict[input_key], input_type):
  34. raise RuntimeError(
  35. "the input parameter %s[%s] must be %s, while type of input is %s" %
  36. (input_name, input_key, input_type, type(input_dict[input_key])))
  37. for key in input_dict.keys():
  38. if key not in input_key:
  39. raise RuntimeError(
  40. "the input parameter %s must have arrt <%s>" %
  41. (input_name, key))
  42. # check shape's type of input_dict, if have shape
  43. if key == "shape":
  44. _check_input_type(key, (list, tuple))
  45. # check dtype's type of input_dict, if have dtype
  46. if key == "dtype":
  47. _check_input_type(key, (str,))
  48. def check_input_type_list_tuple(inputs, expect):
  49. """check inputs by a list or tuple of expected types."""
  50. if not isinstance(inputs, expect[1][0]):
  51. raise RuntimeError("the input parameter %s must be (list, tuple), while"
  52. " type of input is %s" % (expect[0], type(inputs)))
  53. for inp in inputs:
  54. if not isinstance(inp, expect[1][1]):
  55. raise RuntimeError("The element in parameter %s must be %s, while "
  56. "type of input is %s" % (
  57. expect[0], expect[1][1], type(inp)))
  58. def check_input_type(*type_args, **_type_kwargs):
  59. """check input parameter type."""
  60. def out_wrapper(func):
  61. """outer wrapper function."""
  62. formal_parameter = func.__code__.co_varnames
  63. formal_parameter_list = list(zip(formal_parameter, type_args))
  64. @wraps(func)
  65. def in_wrapper(*args, **kwargs):
  66. """inner wrapper function."""
  67. for i, arg_v in enumerate(args):
  68. # add for new input dict, if dict, will check shape and dtype
  69. if isinstance(arg_v, dict):
  70. check_input_type_dict(arg_v, arg_v.keys(),
  71. formal_parameter_list[i][0])
  72. if isinstance(formal_parameter_list[i][1], tuple):
  73. if isinstance(formal_parameter_list[i][1][0], tuple) \
  74. and len(formal_parameter_list[i][1]) == 2:
  75. check_input_type_list_tuple(arg_v, formal_parameter_list[i])
  76. continue
  77. if not isinstance(arg_v, formal_parameter_list[i][1]):
  78. raise RuntimeError("the %sth input parameter %s must be %s, "
  79. "while type of input is %s" % (str(i), formal_parameter_list[i][0],
  80. formal_parameter_list[i][1],
  81. type(arg_v)))
  82. for i in kwargs:
  83. for j in formal_parameter_list:
  84. if i in j:
  85. if not isinstance(kwargs[i], j[1]):
  86. raise RuntimeError("the input parameter %s must be "
  87. "%s, while type of input is %s"
  88. "" % (i, j[1], type(kwargs[i])))
  89. break
  90. return func(*args, **kwargs)
  91. return in_wrapper
  92. return out_wrapper
  93. def shape_dtype_max_size_check(shape):
  94. """check validation of tensor's shape."""
  95. if shape:
  96. mul = int(reduce(lambda x, y: int(x) * int(y), shape))
  97. if mul > MAX_DATA_SIZE:
  98. error_msg = "*".join([str(sh) for sh in shape])
  99. raise RuntimeError("Invalid shape, data is {} bytes ({}), which "
  100. "exceed max data size {} bytes"
  101. .format(mul, error_msg, MAX_DATA_SIZE))
  102. def check_shape(tensor, length=None, tensor_name=""):
  103. """The common check rule for placeholder data."""
  104. shape = get_shape(tensor)
  105. if not shape:
  106. raise RuntimeError("The ndim of input tensor {} must more than 0, "
  107. "actual input is {}".format(tensor_name, len(shape)))
  108. for shape_v in shape:
  109. if not isinstance(shape_v, int) or shape_v <= 0:
  110. raise RuntimeError("The type of tensor {} axis value must be "
  111. "positive int and value more than 0,"
  112. "actual input is ({}) {}".
  113. format(tensor_name, type(shape_v), shape_v))
  114. if length and len(shape) != length:
  115. raise ValueError('The length of {} should be {}, while actual length is {}'.
  116. format(tensor_name, length, len(shape)))
  117. def ops_dtype_check(dtype, args):
  118. """check validation of op's dtype."""
  119. expected_dtype = list()
  120. def _get_expect_dtype(expected_dtype, arg):
  121. if isinstance(arg, str):
  122. expected_dtype.append(arg)
  123. elif isinstance(arg, (list, tuple)):
  124. for t in arg:
  125. _get_expect_dtype(expected_dtype, t)
  126. else:
  127. raise TypeError("arg should be either a string, "
  128. "or a list/tuple of string, "
  129. "while current is {}".format(type(arg)))
  130. _get_expect_dtype(expected_dtype, args)
  131. if isinstance(dtype, (list, tuple)):
  132. checking_dtype = [d.lower() for d in dtype]
  133. elif isinstance(dtype, str):
  134. checking_dtype = [dtype.lower()]
  135. else:
  136. raise TypeError("dtype should be either a string or a tuple/list of string")
  137. error_msg = "Supported dtype: {}, while received dtype: {}"
  138. if not set(checking_dtype).issubset(set(expected_dtype)):
  139. raise RuntimeError(error_msg.format(expected_dtype, checking_dtype))
  140. def reduce_axis_check(reduce_shape, reduce_axis):
  141. """check validation of reduce axis for certain reduce shape."""
  142. dim = len(reduce_shape)
  143. if dim == 1 and int(reduce_shape[0]) == 1:
  144. raise RuntimeError("Error, reduce shape is 1. Scalar is not supported "
  145. "for reduction, please input a vector.")
  146. if isinstance(reduce_axis, int):
  147. if reduce_axis not in range(-dim, dim):
  148. raise RuntimeError("Reduce axis should be in range [%d. %d)"
  149. "" % (-dim, dim))
  150. elif isinstance(reduce_axis, (tuple, list)):
  151. if len(reduce_axis) > len(reduce_shape):
  152. raise RuntimeError("Reduce axis list exceed reduce shape length: "
  153. "%d vs %d, error" % (len(reduce_axis), len(reduce_shape)))
  154. processed_axis = []
  155. for axis in reduce_axis:
  156. processed_axis.append(int(axis + dim) if axis < 0 else int(axis))
  157. if len(set(processed_axis)) < len(processed_axis):
  158. raise RuntimeError("Reduce axis list contains %d duplicated element, please check"
  159. % (len(processed_axis) - len(set(processed_axis))))
  160. for axis in processed_axis:
  161. if axis >= dim:
  162. raise RuntimeError("Invalid reduce axis, axis should less than %d" % dim)
  163. elif reduce_axis is not None:
  164. raise RuntimeError("axis should be a list, tuple or int.")
  165. def elemwise_dtype_check(dtype_a, dtype_b, supported_type=None):
  166. """check validation of tensor's dtype for element-wise op."""
  167. if supported_type:
  168. ops_dtype_check(dtype_a, supported_type)
  169. ops_dtype_check(dtype_b, supported_type)
  170. if dtype_a.lower() != dtype_b.lower():
  171. raise RuntimeError("Element-wise operation needs same data type, while "
  172. "current is %s vs %s" % (dtype_a.lower(), dtype_b.lower()))
  173. def auto_broadcast_check(shape_a, shape_b):
  174. """automatic broadcast check."""
  175. shape_l = get_shape(shape_a)
  176. shape_r = get_shape(shape_b)
  177. if len(shape_l) <= len(shape_r):
  178. shape_short = shape_l
  179. shape_long = shape_r
  180. else:
  181. shape_short = shape_r
  182. shape_long = shape_l
  183. dim_diff = len(shape_long) - len(shape_short)
  184. for i in range(dim_diff):
  185. shape_short.insert(0, 1)
  186. for i, shp in enumerate(shape_short):
  187. if int(shp) != int(shape_long[i]) and 1 not in [int(shp), int(shape_long[i])]:
  188. raise RuntimeError("Invalid auto broadcast, dim %d should be 1 or equal, "
  189. "while now is %d vs %d" % (i, shp, shape_long[i]))
  190. def check_int_list(array, array_name):
  191. """check whether all the elements are integers."""
  192. for num in array:
  193. if not isinstance(num, int):
  194. raise RuntimeError("Type of value in %s should be int, but got type %s" % (array_name, type(num)))