You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

validation_check.py 26 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """validation check functions"""
  17. from functools import wraps, reduce
  18. from enum import Enum
  19. import akg.tvm
  20. import akg.topi
  21. from akg.utils.format_transform import get_bytes, get_shape
# Upper bound on a single tensor's total size in bytes (2 GiB); enforced by
# shape_dtype_max_size_check and convolution_format_check below.
MAX_DATA_SIZE = 2 ** 31
class DtypeForDavinci(Enum):
    """Davinci supported dtype.

    Each member's value is a list of dtype-name strings so that members can
    be flattened directly into a list of accepted dtypes (see how
    ``ops_dtype_check`` consumes ``arg.value``).
    """
    # Aggregate groups covering several dtypes at once.
    ALL_TYPES = ["float16", "float32", "int32", "int8", "uint8"]
    ALL_FLOAT = ["float16", "float32"]
    ALL_INT = ["int8", "int32"]
    ALL_UINT = ["uint8"]
    # Single-dtype members.
    FLOAT16 = ["float16"]
    FLOAT32 = ["float32"]
    INT8 = ["int8"]
    INT16 = ["int16"]
    INT32 = ["int32"]
    INT64 = ["int64"]
    UINT8 = ["uint8"]
    UINT16 = ["uint16"]
    UINT32 = ["uint32"]
    UINT64 = ["uint64"]
    BOOL = ["bool"]
# Element width in bytes, keyed by the bit-width substring of a dtype name
# ("8" -> 1 byte, ..., "64" -> 8 bytes); "bool" is stored as one byte.
# NOTE(review): not referenced elsewhere in this file — presumably consumed
# by callers; confirm before removing.
supported_bits = {
    "8": 1, "16": 2, "32": 4, "64": 8, "bool": 1
}
  43. def check_input_type_dict(input_dict, input_key, input_name):
  44. """
  45. check input parameter type for new type: dict.
  46. Note:
  47. rule1: key of input_dict should be in the input_key
  48. rule2: type of input_dict[shape] should be in (list, tuple), if have shape
  49. rule3: type of input_dict[dtype] should be in (str), if have dtype
  50. Args:
  51. input_dict (dict): input_dict
  52. input_key (list or tuple): all input key list, the key of input must in input_key
  53. input_name (str): input param name, only used for error print
  54. Returns:
  55. None
  56. """
  57. def _check_input_type(input_key, input_type):
  58. if not isinstance(input_dict[input_key], input_type):
  59. raise RuntimeError(
  60. "the input parameter %s[%s] must be %s, while type of input is %s" %
  61. (input_name, input_key, input_type, type(input_dict[input_key])))
  62. for key in input_dict.keys():
  63. if key not in input_key:
  64. raise RuntimeError(
  65. "the input parameter %s must have arrt <%s>" %
  66. (input_name, key))
  67. # check shape's type of input_dict, if have shape
  68. if key == "shape":
  69. _check_input_type(key, (list, tuple))
  70. # check dtype's type of input_dict, if have dtype
  71. if key == "dtype":
  72. _check_input_type(key, (str,))
  73. def check_input_type_list_tuple(inputs, expect):
  74. """check inputs by a list or tuple of expected types."""
  75. if not isinstance(inputs, expect[1][0]):
  76. raise RuntimeError("the input parameter %s must be (list, tuple), while"
  77. " type of input is %s" % (expect[0], type(inputs)))
  78. for inp in inputs:
  79. if not isinstance(inp, expect[1][1]):
  80. raise RuntimeError("The element in parameter %s must be %s, while "
  81. "type of input is %s" % (
  82. expect[0], expect[1][1], type(inp)))
def check_input_type(*type_args, **_type_kwargs):
    """check input parameter type.

    Returns a decorator that validates, at call time, each positional and
    keyword argument of the wrapped function against ``type_args`` (matched
    positionally with the function's parameter names).
    """
    def out_wrapper(func):
        """outer wrapper function."""
        # Pair each formal parameter name with its expected type(s).
        # zip() stops at the shorter sequence, so trailing locals in
        # co_varnames beyond type_args are ignored.
        formal_parameter = func.__code__.co_varnames
        formal_parameter_list = list(zip(formal_parameter, type_args))
        @wraps(func)
        def in_wrapper(*args, **kwargs):
            """inner wrapper function."""
            for i, arg_v in enumerate(args):
                # add for new input dict, if dict, will check shape and dtype
                if isinstance(arg_v, dict):
                    check_input_type_dict(arg_v, arg_v.keys(),
                                          formal_parameter_list[i][0])
                if isinstance(formal_parameter_list[i][1], tuple):
                    # A 2-tuple whose first item is itself a tuple of types
                    # means "container of element type": delegate the check
                    # and skip the plain isinstance test below.
                    if isinstance(formal_parameter_list[i][1][0], tuple) \
                            and len(formal_parameter_list[i][1]) == 2:
                        check_input_type_list_tuple(arg_v, formal_parameter_list[i])
                        continue
                if not isinstance(arg_v, formal_parameter_list[i][1]):
                    raise RuntimeError("the %sth input parameter %s must be %s, "
                                       "while type of input is %s" % (str(i), formal_parameter_list[i][0],
                                                                      formal_parameter_list[i][1],
                                                                      type(arg_v)))
            for i in kwargs:
                # Look up the expected type by parameter name; only the
                # first matching (name, type) pair is checked.
                for j in formal_parameter_list:
                    if i in j:
                        if not isinstance(kwargs[i], j[1]):
                            raise RuntimeError("the input parameter %s must be "
                                               "%s, while type of input is %s"
                                               "" % (i, j[1], type(kwargs[i])))
                        break
            return func(*args, **kwargs)
        return in_wrapper
    return out_wrapper
  118. def shape_dtype_max_size_check(shape, dtype):
  119. """check validation of tensor's shape."""
  120. if shape:
  121. for x in shape:
  122. if not isinstance(x, int):
  123. return
  124. mul = get_bytes(dtype) * int(reduce(lambda x, y: int(x) * int(y), shape))
  125. if mul > MAX_DATA_SIZE:
  126. error_msg = "*".join([str(sh) for sh in shape])
  127. raise RuntimeError("Invalid shape, data is {} bytes ({}), which "
  128. "exceed max data size {} bytes"
  129. .format(mul, error_msg, MAX_DATA_SIZE))
  130. def tensor_max_size_check(tensor):
  131. """check validation of tensor's shape."""
  132. if not isinstance(tensor, akg.tvm.tensor.Tensor):
  133. raise RuntimeError("tensor should be an akg.tvm.tensor.Tensor, but got "
  134. "type {}".format(type(tensor)))
  135. shape = get_shape(tensor)
  136. dtype = tensor.dtype
  137. shape_dtype_max_size_check(shape, dtype)
  138. def check_shape(tensor, length=None, tensor_name=""):
  139. """The common check rule for placeholder data."""
  140. shape = get_shape(tensor)
  141. if not shape:
  142. raise RuntimeError("The ndim of input tensor {} must more than 0, "
  143. "actual input is {}".format(tensor_name, len(shape)))
  144. for shape_v in shape:
  145. if isinstance(shape_v, (akg.tvm.expr.Var, akg.tvm.expr.Mul, akg.tvm.expr.FloorDiv, akg.tvm.expr.IntImm)):
  146. continue
  147. if not isinstance(shape_v, int) or shape_v <= 0:
  148. raise RuntimeError("The type of tensor {} axis value must be "
  149. "positive int and value more than 0,"
  150. "actual input is ({}) {}".
  151. format(tensor_name, type(shape_v), shape_v))
  152. if length and len(shape) != length:
  153. raise ValueError('The length of {} should be {}, while actual length is {}'.
  154. format(tensor_name, length, len(shape)))
  155. def ops_dtype_check(dtype, args):
  156. """check validation of op's dtype."""
  157. expected_dtype = list()
  158. def _get_expect_dtype(expected_dtype, arg):
  159. if isinstance(arg, str):
  160. expected_dtype.append(arg)
  161. elif isinstance(arg, DtypeForDavinci):
  162. expected_dtype += arg.value
  163. elif isinstance(arg, (list, tuple)):
  164. for t in arg:
  165. _get_expect_dtype(expected_dtype, t)
  166. else:
  167. raise TypeError("arg should be either a string, a DtypeForDavinci "
  168. "or a list/tuple of string or DtypeForDavinci, "
  169. "while current is {}".format(type(arg)))
  170. _get_expect_dtype(expected_dtype, args)
  171. if isinstance(dtype, (list, tuple)):
  172. checking_dtype = [d.lower() for d in dtype]
  173. elif isinstance(dtype, str):
  174. checking_dtype = [dtype.lower()]
  175. else:
  176. raise TypeError("dtype should be either a string or a tuple/list of string")
  177. error_msg = "Supported dtype: {}, while received dtype: {}"
  178. if not set(checking_dtype).issubset(set(expected_dtype)):
  179. raise RuntimeError(error_msg.format(expected_dtype, checking_dtype))
  180. def reduce_axis_check(reduce_shape, reduce_axis):
  181. """check validation of reduce axis for certain reduce shape."""
  182. dim = len(reduce_shape)
  183. if dim == 1 and isinstance(reduce_shape[0], int) and int(reduce_shape[0]) == 1:
  184. raise RuntimeError("Error, reduce shape is 1. Scalar is not supported "
  185. "for reduction, please input a vector.")
  186. if isinstance(reduce_axis, int):
  187. if reduce_axis not in range(-dim, dim):
  188. raise RuntimeError("Reduce axis should be in range [%d. %d)"
  189. "" % (-dim, dim))
  190. elif isinstance(reduce_axis, (tuple, list)):
  191. if len(reduce_axis) > len(reduce_shape):
  192. raise RuntimeError("Reduce axis list exceed reduce shape length: "
  193. "%d vs %d, error" % (len(reduce_axis), len(reduce_shape)))
  194. processed_axis = []
  195. for axis in reduce_axis:
  196. processed_axis.append(int(axis + dim) if axis < 0 else int(axis))
  197. if len(set(processed_axis)) < len(processed_axis):
  198. raise RuntimeError("Reduce axis list contains %d duplicated element, please check"
  199. % (len(processed_axis) - len(set(processed_axis))))
  200. for axis in processed_axis:
  201. if axis >= dim:
  202. raise RuntimeError("Invalid reduce axis, axis should less than %d" % dim)
  203. elif reduce_axis is not None:
  204. raise RuntimeError("axis should be a list, tuple or int.")
  205. def elemwise_shape_check(shape_a, shape_b):
  206. """check validation of tensor's shape for element-wise op."""
  207. check_shape(shape_a)
  208. check_shape(shape_b)
  209. if len(shape_a) != len(shape_b):
  210. raise RuntimeError("Element-wise operation needs same data length, "
  211. "while current is %s vs %s" % (len(shape_a), len(shape_b)))
  212. for i, shp in enumerate(shape_a):
  213. if int(shp) != int(shape_b[i]):
  214. raise RuntimeError("Element-wise operation needs same data shape, "
  215. "while current is %s vs %s" % (shp, shape_b[i]))
  216. def elemwise_dtype_check(dtype_a, dtype_b, supported_type=None):
  217. """check validation of tensor's dtype for element-wise op."""
  218. if supported_type:
  219. ops_dtype_check(dtype_a, supported_type)
  220. ops_dtype_check(dtype_b, supported_type)
  221. if dtype_a.lower() != dtype_b.lower():
  222. raise RuntimeError("Element-wise operation needs same data type, while "
  223. "current is %s vs %s" % (dtype_a.lower(), dtype_b.lower()))
  224. def auto_broadcast_check(shape_a, shape_b):
  225. """automatic broadcast check."""
  226. shape_l = get_shape(shape_a)
  227. shape_r = get_shape(shape_b)
  228. if len(shape_l) <= len(shape_r):
  229. shape_short = shape_l
  230. shape_long = shape_r
  231. else:
  232. shape_short = shape_r
  233. shape_long = shape_l
  234. dim_diff = len(shape_long) - len(shape_short)
  235. for i in range(dim_diff):
  236. shape_short.insert(0, 1)
  237. for i, shp in enumerate(shape_short):
  238. if int(shp) != int(shape_long[i]) and 1 not in [int(shp), int(shape_long[i])]:
  239. raise RuntimeError("Invalid auto broadcast, dim %d should be 1 or equal, "
  240. "while now is %d vs %d" % (i, shp, shape_long[i]))
  241. def broadcast_check(ori_shape, dst_shape):
  242. """check valid broadcast from ori_shape to dst_shape."""
  243. shape_l = get_shape(ori_shape)
  244. shape_r = get_shape(dst_shape)
  245. if len(shape_l) <= len(shape_r):
  246. dim_diff = len(shape_r) - len(shape_l)
  247. shape_l = ([1] * dim_diff) + shape_l
  248. else:
  249. raise RuntimeError("Cannot broadcast from shape %s to %s" % (str(ori_shape), str(dst_shape)))
  250. for i, shp in enumerate(shape_l):
  251. if int(shp) != int(shape_r[i]) and int(shp) != 1:
  252. raise RuntimeError("Cannot broadcast from shape %s to %s" % (str(ori_shape), str(dst_shape)))
def gemm_format_check(lhs_input, rhs_input, lhs_trans=False, rhs_trans=False):
    """check gemm format (shape length and value).

    Args:
        lhs_input (Union[list, tuple]): shape of the left matrix.
        rhs_input (Union[list, tuple]): shape of the right matrix.
        lhs_trans (bool): True if the left matrix is transposed.
        rhs_trans (bool): True if the right matrix is transposed.

    Raises:
        RuntimeError: on rank mismatch, unsupported rank (only 2D-4D are
            allowed), or incompatible batch ("B") / contraction ("K") dims.
    """
    dim = len(lhs_input)
    if len(rhs_input) != dim:
        raise RuntimeError("Dimensions are different, lhs input is of %d dimension "
                           "while rhs input is of %d dimension, " % (dim, len(rhs_input)))
    # Batch axes: one leading axis for 3D shapes, axes 0 and 1 otherwise
    # (only actually compared when rank > 2, see value_check).
    b_pos = [0] if dim == 3 else [0, 1]
    # K sits on the last axis of lhs (second-to-last when transposed) and on
    # the second-to-last axis of rhs (last when transposed).
    lhs_k_pos = -2 if lhs_trans else -1
    rhs_k_pos = -1 if rhs_trans else -2
    def length_check(tensor):
        # Only 2D/3D/4D shapes are supported.
        if len(tensor) < 2 or len(tensor) > 4:
            raise RuntimeError("Gemm only support 2d shape (height, weight) "
                               "or 3d shape (batch, height, weight) "
                               "or 4d shape (batch_o, batch_i, height, weight) "
                               " while shape length is %d!" % (len(tensor)))
    def value_check(loc):
        # "B": batch axes of lhs and rhs must match.
        if loc == "B":
            if len(lhs_input) > 2:
                for pos in b_pos:
                    value = int(lhs_input[pos])
                    cmp_value = int(rhs_input[pos])
                    if value != cmp_value:
                        raise RuntimeError("%s size is not compatible, lhs "
                                           "input: %d and rhs input: %d" %
                                           (loc, value, cmp_value))
        # "K": contraction axes must match; symbolic (Var) dims are skipped.
        if loc == "K":
            if isinstance(lhs_input[lhs_k_pos], akg.tvm.expr.Var) or isinstance(rhs_input[rhs_k_pos], akg.tvm.expr.Var):
                return
            value = int(lhs_input[lhs_k_pos])
            cmp_value = int(rhs_input[rhs_k_pos])
            if cmp_value != value:
                raise RuntimeError("%s size is not compatible, lhs :%d, "
                                   "rhs input: %d " % (loc, value, cmp_value))
    for data in [lhs_input, rhs_input]:
        length_check(data)
    for location in ["B", "K"]:
        value_check(location)
  290. def convolution_format_check(x_shape, w_shape, pad, stride, dilation):
  291. """check convolution format."""
  292. def conv_shape_check(shape):
  293. if (not isinstance(shape, (tuple, list))) or (len(shape) != 4):
  294. raise RuntimeError("conv tensor shape should be 4d list or tuple")
  295. conv_dtype = "float16"
  296. size = get_bytes(conv_dtype)
  297. for i in shape:
  298. if (not isinstance(i, int)) or (i <= 0):
  299. raise RuntimeError("conv tensor shape should be 4d list or "
  300. "tuple of positive integer")
  301. size *= i
  302. if size > MAX_DATA_SIZE:
  303. raise RuntimeError("runtime can not support tensor more than 2G size")
  304. def conv_pad_check(pad):
  305. if (not isinstance(pad, (tuple, list))) or (len(pad) != 4):
  306. raise RuntimeError("conv pad should be 4d list or tuple")
  307. for i in pad:
  308. if (not isinstance(i, int)) or (i < 0):
  309. raise RuntimeError("conv pad should be 4d list or tuple of "
  310. "positive integer or zero")
  311. def conv_stride_check(stride):
  312. if (not isinstance(stride, (tuple, list))) or (len(stride) != 2):
  313. raise RuntimeError("conv stride should be 2d list or tuple")
  314. for i in stride:
  315. if (not isinstance(i, int)) or (i <= 0):
  316. raise RuntimeError("conv stride should be 2d list or tuple of positive integer")
  317. def conv_dilation_check(dilation):
  318. if (not isinstance(dilation, (tuple, list))) or (len(dilation) != 2):
  319. raise RuntimeError("conv dilation should be 2d list or tuple")
  320. for i in dilation:
  321. if (not isinstance(i, int)) or (i <= 0):
  322. raise RuntimeError("conv dilation should be 2d list or tuple of positive integer")
  323. conv_shape_check(x_shape)
  324. conv_shape_check(w_shape)
  325. conv_pad_check(pad)
  326. conv_stride_check(stride)
  327. conv_dilation_check(dilation)
  328. if x_shape[1] != w_shape[1]:
  329. raise RuntimeError("conv's feature_map and filter tensor should be the same channel")
  330. if x_shape[2] + pad[0] + pad[1] < w_shape[2]:
  331. raise RuntimeError("kernel_h should be <= h + pad_left + pad_right: %d"
  332. "" % (x_shape[2] + pad[0] + pad[1]))
  333. if x_shape[3] + pad[2] + pad[3] < w_shape[3]:
  334. raise RuntimeError("kernel_w should be <= w + pad_top + pad_bottom: %d"
  335. "" % (x_shape[3] + pad[2] + pad[3]))
  336. if (pad[0] >= w_shape[2]) or (pad[1] >= w_shape[2]) \
  337. or (pad[2] >= w_shape[3]) or (pad[3] >= w_shape[3]):
  338. raise RuntimeError("pad value cannot be more than the filter value")
  339. def davinci_format_check(shape, tensor_format, dim=-1):
  340. """check validation of tensor's shape for certain format used in davinci chip."""
  341. all_format_shape = {"NCHW": 4,
  342. "NHWC": 4,
  343. "NC1HWC0": 5,
  344. "DefaultFormat": [2, 4]}
  345. if dim not in [-1, 2, 4, 5]:
  346. raise RuntimeError("Only support 2d, 4d, 5d format check, please set "
  347. "dim to the dim want to check "
  348. "or use default value -1 to check both all the dim")
  349. if dim == -1:
  350. support_format_shape = all_format_shape
  351. else:
  352. support_format_shape = {}
  353. for k, v in all_format_shape.items():
  354. if isinstance(v, int) and v == dim:
  355. support_format_shape[k] = v
  356. if isinstance(v, list) and dim in v:
  357. support_format_shape[k] = v
  358. support_shape = {"NC1HWC0": (4, 16)}
  359. if not isinstance(tensor_format, str):
  360. raise RuntimeError("Invalid davinci format, should be a string, "
  361. "but get %s" % (type(tensor_format)))
  362. if tensor_format not in support_format_shape.keys():
  363. raise RuntimeError("Invalid davinci format {}, davinci support {}"
  364. .format(tensor_format, support_format_shape.keys()))
  365. if isinstance(support_format_shape[tensor_format], int):
  366. if len(shape) != support_format_shape[tensor_format]:
  367. raise RuntimeError("Invalid shape {} for davinci format {}, needs "
  368. "{} dim shape, current length{}"
  369. .format(shape, tensor_format,
  370. support_format_shape[tensor_format], len(shape)))
  371. if isinstance(support_format_shape[tensor_format], list):
  372. if len(shape) not in support_format_shape[tensor_format]:
  373. raise RuntimeError("Invalid shape {} for davinci format {}, needs {} dim shape"
  374. .format(shape, tensor_format,
  375. support_format_shape[tensor_format]))
  376. if tensor_format in support_shape.keys():
  377. check_dim = support_shape[tensor_format][0]
  378. expect_shape = support_shape[tensor_format][1]
  379. if int(shape[check_dim]) != expect_shape:
  380. raise RuntimeError("Invalid shape {} for davinci format {}, dim {} "
  381. "should be {}, while current is {}"
  382. .format(shape, tensor_format, check_dim,
  383. expect_shape, shape[check_dim]))
  384. def is_valid_reduce_axis(tensor, reduce_axis):
  385. """
  386. if the reduce axis correspond to shape[axis] is 1, we can not refine the shape,or the reduce axis will be wrong.
  387. Args:
  388. tensor (tvm.tensor.Tensor): input tensor.
  389. reduce_axis (Union[list, tuple, int]): axis want to reduce.
  390. Returns:
  391. True or False.
  392. """
  393. # if the reduce axis correspond to shape[axis] is 1, we can not refine the
  394. # shape,or the reduce axis will be wrong
  395. # need_shape_refine = True
  396. shape = get_shape(tensor)
  397. if hasattr(reduce_axis, 'index'):
  398. for id_ite in reduce_axis:
  399. if shape[id_ite] == 1:
  400. return False
  401. else:
  402. if shape[reduce_axis] == 1:
  403. return False
  404. return True
  405. def axis_check(shape_len, axis):
  406. """Check the value of axis and return the sorted axis."""
  407. def _axis_value_type_check(value):
  408. if not isinstance(value, int):
  409. raise RuntimeError("type of axis value should be int")
  410. if value >= shape_len or value < -shape_len:
  411. raise RuntimeError(
  412. "input axis is out of range, axis value can be from %d to %d" %
  413. (-shape_len, shape_len - 1))
  414. if value < 0:
  415. value = shape_len + value
  416. return value
  417. if not hasattr(axis, 'index'):
  418. axis = _axis_value_type_check(axis)
  419. return axis
  420. for i, axs in enumerate(axis):
  421. axis[i] = _axis_value_type_check(axs)
  422. axis = sorted(set(axis))
  423. return axis
  424. def check_value_on_integer(arg_name, arg_value, low=None, high=None):
  425. """Judging integer type."""
  426. type_match = isinstance(arg_value, int) and not isinstance(arg_value, bool)
  427. if not type_match:
  428. raise ValueError("%s should be an int , but got type %s"
  429. "" % (arg_name, type(arg_value)))
  430. if low and arg_value < low:
  431. raise ValueError("%s should be greater than or equal to %f, but got %f"
  432. "" % (arg_name, low, arg_value))
  433. if high and arg_value >= high:
  434. raise ValueError("%s should be less than %f, but got %f"
  435. "" % (arg_name, high, arg_value))
  436. def check_typename(arg_name, arg_type, valid_types):
  437. """Does it contain the _name_ attribute."""
  438. def get_typename(t):
  439. return t.__name__ if hasattr(t, '__name__') else str(t)
  440. if arg_type in valid_types:
  441. return arg_type
  442. type_names = [get_typename(t) for t in valid_types]
  443. if len(valid_types) == 1:
  444. raise ValueError('type of {} should be {}, but got {}'.format(
  445. arg_name, type_names[0], get_typename(arg_type)))
  446. raise ValueError('type of {} should be one of {}, but got {}'.format(
  447. arg_name, type_names, get_typename(arg_type)))
  448. def check_equal(arg_name1, arg_name2, arg1, arg2,):
  449. """Check equal."""
  450. if arg1 != arg2:
  451. raise ValueError('{} should be equal to {}'.format(arg_name1, arg_name2))
  452. def check_greater(arg_name1, arg_name2, arg1, arg2,):
  453. """Check greater."""
  454. if arg1 <= arg2:
  455. raise ValueError('{} should be greater than {}'.format(arg_name1, arg_name2))
  456. def check_5d(arg_name, shape5d, shape4d):
  457. """Check 5D shape."""
  458. blocksize = 16
  459. if len(shape4d) != 4:
  460. raise ValueError('invalid 4D shape of {}'.format(arg_name))
  461. if len(shape5d) != 5:
  462. raise ValueError('invalid 5D shape of {}'.format(arg_name))
  463. d1, d2, d3, d4 = shape4d
  464. if [x.value for x in shape5d] != [d1, (d2 + blocksize - 1) // blocksize, d3, d4, blocksize]:
  465. raise ValueError('the 4D shape and 5D shape of {} do not match'.format(arg_name))
  466. def check_shape_length_equal(tensor_name, tensor_shape, shape_length):
  467. """Shape length equal judgment."""
  468. if isinstance(shape_length, (tuple, list)):
  469. if not len(tensor_shape) in shape_length:
  470. raise ValueError("The shape length of {tensor_name} should be one of "
  471. "{shape_length}, but get {tensor_shape_len}"
  472. "".format(
  473. tensor_name=tensor_name,
  474. shape_length=shape_length,
  475. tensor_shape_len=len(tensor_shape)))
  476. elif len(tensor_shape) != shape_length:
  477. raise ValueError("The shape length of {tensor_name} should be "
  478. "{shape_length}, but get {tensor_shape_len}"
  479. "".format(
  480. tensor_name=tensor_name,
  481. shape_length=shape_length,
  482. tensor_shape_len=len(tensor_shape)))
  483. def check_shape_length_greater(tensor_name, tensor_shape, shape_length):
  484. """Shape length greater judgment."""
  485. if len(tensor_shape) <= shape_length:
  486. raise ValueError("The shape length of {tensor_name} should be greater "
  487. "than {shape_length}, but get {tensor_shape_len}".format(
  488. tensor_name=tensor_name, shape_length=shape_length,
  489. tensor_shape_len=len(tensor_shape)))
  490. def judge_var(num):
  491. """judge var if a tvm.var, tvm.const or python data type."""
  492. var_dict = {
  493. "python_const": [int, float],
  494. "tvm_const": [
  495. akg.tvm.expr.IntImm, akg.tvm.expr.UIntImm, akg.tvm.expr.FloatImm],
  496. "tvm_var": [akg.tvm.expr.Var]}
  497. num_type = type(num)
  498. for i in var_dict:
  499. if num_type in var_dict[i]:
  500. return i
  501. raise RuntimeError("Input var dtype {} error".format(type(num)))
  502. def check_pad(arg_name, pad, length=None):
  503. """Check pad."""
  504. if not pad:
  505. raise ValueError("{} should not be None".format(arg_name))
  506. if not isinstance(pad, (tuple, list)):
  507. raise ValueError("{} should be tuple or list".format(arg_name))
  508. for i in pad:
  509. if not isinstance(i, int):
  510. raise ValueError("Elements in {} should be int".format(arg_name))
  511. if i < 0:
  512. raise ValueError("Elements in {} should not be less than zero"
  513. "".format(arg_name))
  514. if length:
  515. if length != len(pad):
  516. raise ValueError("The length of {} should be {}".format(
  517. arg_name, length))
  518. def check_int_list(array, array_name):
  519. """check whether all the elements are integers."""
  520. for num in array:
  521. if not isinstance(num, int):
  522. raise RuntimeError("Type of value in %s should be int, but got type %s" % (array_name, type(num)))