You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

_checkparam.py 40 kB

4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Check parameters."""
  16. import re
  17. import inspect
  18. import math
  19. from enum import Enum
  20. from functools import reduce, wraps
  21. from itertools import repeat, zip_longest
  22. from collections import deque
  23. from collections.abc import Iterable
  24. import numpy as np
  25. from mindspore import context
  26. from mindspore import log as logger
  27. from mindspore.common import dtype as mstype
  28. from mindspore._c_expression import Tensor as Tensor_
  29. class Rel(Enum):
  30. """Numerical relationship between variables, logical relationship enumeration definition of range."""
  31. # scalar compare
  32. EQ = 1 # ==
  33. NE = 2 # !=
  34. LT = 3 # <
  35. LE = 4 # <=
  36. GT = 5 # >
  37. GE = 6 # >=
  38. # scalar range check
  39. INC_NEITHER = 7 # (), include neither
  40. INC_LEFT = 8 # [), include left
  41. INC_RIGHT = 9 # (], include right
  42. INC_BOTH = 10 # [], include both
  43. # collection in, not in
  44. IN = 11
  45. NOT_IN = 12
  46. @staticmethod
  47. def get_strs(rel):
  48. """Get value from rel_strs."""
  49. return rel_strs.get(rel, "")
  50. @staticmethod
  51. def get_fns(rel):
  52. """Get value from rel_fns."""
  53. return rel_fns.get(rel, lambda *args: False)
  54. rel_fns = {
  55. # scalar compare
  56. Rel.EQ: lambda x, y: x == y,
  57. Rel.NE: lambda x, y: x != y,
  58. Rel.LT: lambda x, y: x < y,
  59. Rel.LE: lambda x, y: x <= y,
  60. Rel.GT: lambda x, y: x > y,
  61. Rel.GE: lambda x, y: x >= y,
  62. # scalar range check
  63. Rel.INC_NEITHER: lambda x, lower, upper: (lower < x < upper),
  64. Rel.INC_LEFT: lambda x, lower, upper: (lower <= x < upper),
  65. Rel.INC_RIGHT: lambda x, lower, upper: (lower < x <= upper),
  66. Rel.INC_BOTH: lambda x, lower, upper: (lower <= x <= upper),
  67. # collection in, not in
  68. Rel.IN: lambda x, y: x in y,
  69. Rel.NOT_IN: lambda x, y: x not in y,
  70. }
# Error-message template for every Rel member, filled in with str.format when a check
# fails. Binary relations take one placeholder (the bound); range relations take two
# (lower, upper) and are written in interval bracket notation.
rel_strs = {
    # scalar compare
    Rel.EQ: "= {}",
    Rel.NE: "!= {}",
    Rel.LT: "< {}",
    Rel.LE: "<= {}",
    Rel.GT: "> {}",
    Rel.GE: ">= {}",
    # scalar range check
    Rel.INC_NEITHER: "({}, {})",
    Rel.INC_LEFT: "[{}, {})",
    Rel.INC_RIGHT: "({}, {}]",
    Rel.INC_BOTH: "[{}, {}]",
    # collection in, not in
    Rel.IN: "in {}",
    Rel.NOT_IN: "not in {}",
}
def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret_five=False,
                           greater_zero=True, third_one=False, three_input=False):
    """
    Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements.

    Args:
        arg_name (str): Argument name, used in error messages.
        arg_value (Union[int, tuple]): Value to check. An int is repeated into every spatial slot.
        prim_name (str): Primitive name, used in error messages.
        allow_five (bool): Also accept a 5-element tuple.
        ret_five (bool): Return a 5-tuple (a 3-value result is prefixed with (1, 1));
            otherwise return a 3-tuple (a 5-tuple input keeps its elements at indices 1, 2, 3).
        greater_zero (bool): Require every element > 0; when False, elements >= 0 are accepted.
        third_one (bool): Require the element at index -3 of the result (the "depth" entry) to be 1.
        three_input (bool): Reject any tuple input whose length is not exactly 3.

    Returns:
        tuple: The normalized 3- or 5-element tuple.

    Raises:
        TypeError: If `arg_value` is not an int or tuple (bool is rejected too), via Validator.check_value_type.
        ValueError: If the tuple length or any element value is not allowed.
    """
    def _raise_message(third_one_flag=False, three_input_flag=False):
        # Always raises ValueError. The third_one branch reads `ret_value` from the enclosing
        # scope, so it is only called after `ret_value` has been assigned below.
        if third_one_flag:
            raise ValueError(f"For '{prim_name}' the depth of attr '{arg_name}' should be 1, but got {ret_value[-3]}")
        if three_input_flag:
            raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of "
                             f"three positive int numbers, but got {arg_value}")
        raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of three "
                         f"{'or five ' if allow_five else ''}positive int numbers, but got {arg_value}")

    def _get_return_value():
        # Normalize to the requested 3- or 5-element form. `ret` is left unbound only on
        # paths where _raise_message() has already raised.
        if isinstance(arg_value, int):
            ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
        elif len(arg_value) == 3:
            ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
        elif len(arg_value) == 5:
            if not allow_five:
                _raise_message()
            # Collapsing a 5-tuple to three values keeps indices 1, 2 and 3.
            ret = arg_value if ret_five else (arg_value[1], arg_value[2], arg_value[3])
        else:
            _raise_message()
        return ret

    Validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
    if three_input and isinstance(arg_value, tuple):
        if len(arg_value) != 3:
            _raise_message(three_input_flag=three_input)
    ret_value = _get_return_value()
    for item in ret_value:
        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(item, int) and not isinstance(item, bool):
            if greater_zero and item > 0:
                continue
            if not greater_zero and item >= 0:
                continue
        _raise_message()
    if third_one:
        if ret_value[-3] != 1:
            _raise_message(third_one_flag=third_one)
    return tuple(ret_value)
  129. def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
  130. """
  131. Check argument integer.
  132. - number = check_number(number, 0, Rel.GE, "number", None) # number >= 0
  133. """
  134. rel_fn = Rel.get_fns(rel)
  135. prim_name = f' in `{prim_name}`' if prim_name else ''
  136. arg_name = f'`{arg_name}`' if arg_name else ''
  137. prim_info = f'{arg_name}' + f'{prim_name}'
  138. if isinstance(arg_value, arg_type):
  139. if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
  140. raise ValueError(f'f{prim_info} must be a legal value, but got `{arg_value}`.')
  141. else:
  142. raise TypeError(f'{prim_info} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
  143. type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
  144. type_except = TypeError if type_mismatch else ValueError
  145. if type_mismatch or not rel_fn(arg_value, value):
  146. rel_str = Rel.get_strs(rel).format(value)
  147. raise type_except(f'{prim_info} should be {arg_type.__name__} and must {rel_str}, '
  148. f'but got `{arg_value}` with type `{type(arg_value).__name__}`.')
  149. return arg_value
  150. def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
  151. """
  152. Checks input value is float type or not.
  153. Usage:
  154. - number = check_is_number(number, int)
  155. - number = check_is_number(number, int, "bias")
  156. - number = check_is_number(number, int, "bias", "bias_class")
  157. """
  158. prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
  159. arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
  160. if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
  161. if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
  162. raise ValueError(f'{prim_name} {arg_name} must be a legal float, but got `{arg_value}`.')
  163. return arg_value
  164. raise TypeError(f'{prim_name} type of {arg_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
  165. def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
  166. """
  167. Method for checking whether an int value is in some range.
  168. Usage:
  169. - number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number", float) # number in [0.0, 1.0]
  170. - number = check_number_range(number, 0, 1, Rel.INC_NEITHER, "number", int) # number in [0, 1]
  171. """
  172. rel_fn = Rel.get_fns(rel)
  173. prim_name = f'in `{prim_name}`' if prim_name else ''
  174. arg_name = f'`{arg_name}`' if arg_name else ''
  175. type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
  176. if type_mismatch:
  177. raise TypeError("{} {} must be `{}`, but got `{}`.".format(
  178. arg_name, prim_name, value_type.__name__, type(arg_value).__name__))
  179. if not rel_fn(arg_value, lower_limit, upper_limit):
  180. rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
  181. raise ValueError("{} {} should be in range of {}, but got {} with type `{}`.".format(
  182. arg_name, prim_name, rel_str, arg_value, type(arg_value).__name__))
  183. return arg_value
  184. class Validator:
  185. """validator for checking input parameters"""
  186. @staticmethod
  187. def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
  188. """
  189. Method for judging relation between two int values or list/tuple made up of ints.
  190. This method is not suitable for judging relation between floats, since it does not consider float error.
  191. """
  192. rel_fn = Rel.get_fns(rel)
  193. if not rel_fn(arg_value, value):
  194. rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
  195. msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
  196. raise excp_cls(f'{msg_prefix} \'{arg_name}\' should be {rel_str}, but got {arg_value}.')
  197. return arg_value
  198. @staticmethod
  199. def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
  200. """
  201. Checks input integer value `arg_value` compare to `value`.
  202. Usage:
  203. - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
  204. """
  205. return check_number(arg_value, value, rel, int, arg_name, prim_name)
  206. @staticmethod
  207. def check_is_int(arg_value, arg_name=None, prim_name=None):
  208. """
  209. Checks input value is float type or not.
  210. Usage:
  211. - number = check_is_int(number, int)
  212. - number = check_is_int(number, int, "bias")
  213. - number = check_is_int(number, int, "bias", "bias_class")
  214. """
  215. return check_is_number(arg_value, int, arg_name, prim_name)
  216. @staticmethod
  217. def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
  218. """
  219. Checks input integer value `arg_value` compare to `value`.
  220. Usage:
  221. - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
  222. """
  223. return check_number(arg_value, value, Rel.EQ, int, arg_name, prim_name)
  224. @staticmethod
  225. def check_positive_int(arg_value, arg_name=None, prim_name=None):
  226. """
  227. Check argument is positive integer, which mean arg_value > 0.
  228. Usage:
  229. - number = check_positive_int(number)
  230. - number = check_positive_int(number, "bias")
  231. """
  232. return check_number(arg_value, 0, Rel.GT, int, arg_name, prim_name)
  233. @staticmethod
  234. def check_negative_int(arg_value, arg_name=None, prim_name=None):
  235. """
  236. Check argument is negative integer, which mean arg_value < 0.
  237. Usage:
  238. - number = check_negative_int(number)
  239. - number = check_negative_int(number, "bias")
  240. """
  241. return check_number(arg_value, 0, Rel.LT, int, arg_name, prim_name)
  242. @staticmethod
  243. def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
  244. """
  245. Check argument is non-negative integer, which mean arg_value <= 0.
  246. Usage:
  247. - number = check_non_positive_int(number)
  248. - number = check_non_positive_int(number, "bias")
  249. """
  250. return check_number(arg_value, 0, Rel.LE, int, arg_name, prim_name)
  251. @staticmethod
  252. def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
  253. """
  254. Check argument is non-negative integer, which mean arg_value >= 0.
  255. Usage:
  256. - number = check_non_negative_int(number)
  257. - number = check_non_negative_int(number, "bias")
  258. """
  259. return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name)
  260. @staticmethod
  261. def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
  262. """
  263. Checks input float value `arg_value` compare to `value`.
  264. Usage:
  265. - number = check_float(number, 0.0, Rel.GE, "number", None) # number >= 0
  266. """
  267. return check_number(arg_value, value, rel, float, arg_name, prim_name)
  268. @staticmethod
  269. def check_is_float(arg_value, arg_name=None, prim_name=None):
  270. """
  271. Checks input value is float type or not.
  272. Usage:
  273. - number = check_is_float(number, int)
  274. - number = check_is_float(number, int, "bias")
  275. - number = check_is_float(number, int, "bias", "bias_class")
  276. """
  277. return check_is_number(arg_value, float, arg_name, prim_name)
  278. @staticmethod
  279. def check_positive_float(arg_value, arg_name=None, prim_name=None):
  280. """
  281. Check argument is positive float, which mean arg_value > 0.
  282. Usage:
  283. - number = check_positive_float(number)
  284. - number = check_positive_float(number, "bias")
  285. - number = check_positive_float(number, "bias", "bias_class")
  286. """
  287. return check_number(arg_value, 0, Rel.GT, float, arg_name, prim_name)
  288. @staticmethod
  289. def check_negative_float(arg_value, arg_name=None, prim_name=None):
  290. """
  291. Check argument is negative float, which mean arg_value < 0.
  292. Usage:
  293. - number = check_negative_float(number)
  294. - number = check_negative_float(number, "bias")
  295. """
  296. return check_number(arg_value, 0, Rel.LT, float, arg_name, prim_name)
  297. @staticmethod
  298. def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
  299. """
  300. Check argument is non-negative float, which mean arg_value <= 0.
  301. Usage:
  302. - number = check_non_positive_float(number)
  303. - number = check_non_positive_float(number, "bias")
  304. """
  305. return check_number(arg_value, 0, Rel.LE, float, arg_name, prim_name)
  306. @staticmethod
  307. def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
  308. """
  309. Check argument is non-negative float, which mean arg_value >= 0.
  310. Usage:
  311. - number = check_non_negative_float(number)
  312. - number = check_non_negative_float(number, "bias")
  313. """
  314. return check_number(arg_value, 0, Rel.GE, float, arg_name, prim_name)
  315. @staticmethod
  316. def check_number(arg_name, arg_value, value, rel, prim_name):
  317. """Number value judgment."""
  318. rel_fn = Rel.get_fns(rel)
  319. if not rel_fn(arg_value, value):
  320. rel_str = Rel.get_strs(rel).format(value)
  321. raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
  322. return arg_value
  323. @staticmethod
  324. def check_isinstance(arg_name, arg_value, classes):
  325. """Check arg isinstance of classes"""
  326. if not isinstance(arg_value, classes):
  327. raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
  328. return arg_value
  329. @staticmethod
  330. def check_bool(arg_value, arg_name=None, prim_name=None):
  331. """
  332. Check argument is instance of bool.
  333. Usage:
  334. - has_bias = check_bool(has_bias)
  335. - has_bias = check_bool(has_bias, "has_bias")
  336. """
  337. if not isinstance(arg_value, bool):
  338. prim_name = f"For '{prim_name}', the" if prim_name else 'The'
  339. arg_name = f"'{arg_name}'" if arg_name else 'input value'
  340. raise TypeError(f"{prim_name} {arg_name} should be a bool, but got {type(arg_value).__name__}.")
  341. return arg_value
  342. @staticmethod
  343. def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
  344. """
  345. Method for checking whether input value is in int range.
  346. Usage:
  347. - number = check_int_range(number, 0, 1, Rel.INC_NEITHER) # number in [0, 1]
  348. - number = check_int_range(number, 0, 1, Rel.INC_NEITHER, "number") # number in [0, 1]
  349. """
  350. return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
  351. @staticmethod
  352. def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
  353. """
  354. Method for checking whether input value is in float range.
  355. Usage:
  356. - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER) # number in [0.0, 1.0]
  357. - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number") # number in [0.0, 1.0]
  358. """
  359. return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
  360. @staticmethod
  361. def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
  362. """
  363. Check whether string is in some value list.
  364. Usage:
  365. - method = check_string(method, ["string1", "string2", "string3"], "method")
  366. """
  367. if isinstance(arg_value, str) and arg_value in valid_values:
  368. return arg_value
  369. arg_name = arg_name if arg_name else "Parameter"
  370. msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
  371. raise ValueError(f"{msg_prefix} '{arg_name}' should be str and must be in '{valid_values}',"
  372. f" but got '{arg_value}'.")
  373. @staticmethod
  374. def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
  375. if reg is None:
  376. # Named string regular expression
  377. reg = r"^\w+[0-9a-zA-Z\_\.]*$"
  378. if re.match(reg, target, flag) is None:
  379. prim_name = f'in `{prim_name}`' if prim_name else ""
  380. raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}.'".format(
  381. target, prim_name, reg, flag))
  382. return True
  383. @staticmethod
  384. def check_file_name_by_regular(target, reg=None, prim_name=None):
  385. """Check whether file name is legitimate."""
  386. if not isinstance(target, str):
  387. raise ValueError("Args file_name {} must be string, please check it".format(target))
  388. if target.endswith("\\") or target.endswith("/"):
  389. raise ValueError("File name cannot be a directory path.")
  390. if reg is None:
  391. reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$"
  392. if re.match(reg, target) is None:
  393. prim_name = f'in `{prim_name}`' if prim_name else ""
  394. raise ValueError("'{}' {} is illegal, it should be match regular'{}'.".format(
  395. target, prim_name, reg))
  396. return True
  397. @staticmethod
  398. def check_pad_value_by_mode(pad_mode, padding, prim_name):
  399. """Validates value of padding according to pad_mode"""
  400. if pad_mode != 'pad' and padding != 0:
  401. raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
  402. return padding
  403. @staticmethod
  404. def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None):
  405. """Checks whether some type is subclass of another type"""
  406. if not isinstance(template_types, Iterable):
  407. template_types = (template_types,)
  408. hit = False
  409. for template_type in template_types:
  410. if isinstance(template_type, mstype.Type):
  411. if mstype.issubclass_(type_, template_type):
  412. hit = True
  413. break
  414. elif type_ is template_type:
  415. hit = True
  416. break
  417. if not hit:
  418. if addition_error_info is None:
  419. addition_error_info = ''
  420. type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
  421. raise TypeError(f"For '{prim_name}', the type of '{arg_name}'"
  422. f" should be {'one of ' if len(template_types) > 1 else ''}"
  423. f"{', '.join((str(x) for x in template_types))}, but got {type_str}"
  424. f" {addition_error_info}. The supported data types depend on the hardware that"
  425. f" executes the operator, please refer the official api document to get"
  426. f" more information about the data type.")
  427. @staticmethod
  428. def check_valid_input(arg_name, arg_value, prim_name):
  429. """Checks valid value."""
  430. if arg_value is None:
  431. raise ValueError(f"For \'{prim_name}\', the '{arg_name}' can not be None, but got {arg_value}.")
  432. return arg_value
  433. @staticmethod
  434. def check_types_same_and_valid(args, valid_values, prim_name):
  435. """Checks whether the types of inputs are the same and valid."""
  436. def _check_type_valid(arg):
  437. arg_key, arg_val = arg
  438. elem_type = arg_val
  439. Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
  440. return (arg_key, elem_type)
  441. def _check_types_same(arg1, arg2):
  442. arg1_name, arg1_type = arg1
  443. arg2_name, arg2_type = arg2
  444. if arg1_type != arg2_type:
  445. raise TypeError(f"For '{prim_name}', type of '{arg2_name}' should be same as '{arg1_name}',"
  446. f" but got '{arg1_name}' with type {arg1_type}"
  447. f" and '{arg2_name}' with type {arg2_type}.")
  448. return arg1
  449. elem_types = map(_check_type_valid, args.items())
  450. reduce(_check_types_same, elem_types)
  451. @staticmethod
  452. def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
  453. """Checks whether the element types of input tensors are the same and valid."""
  454. valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
  455. tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
  456. Validator.check_types_same_and_valid(args, tensor_types, prim_name)
  457. @staticmethod
  458. def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
  459. """Checks whether the element types of input tensors are valid."""
  460. valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
  461. tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
  462. Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name)
  463. @staticmethod
  464. def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
  465. """
  466. Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
  467. If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
  468. """
  469. def _check_argument_type(arg):
  470. arg_key, arg_val = arg
  471. if isinstance(arg_val, type(mstype.tensor)):
  472. arg_val = arg_val.element_type()
  473. if not arg_val in valid_values:
  474. raise TypeError(f'For \'{prim_name}\', the type of `{arg_key}` should be in {valid_values},'
  475. f' but got {arg_val}.')
  476. return arg
  477. def _check_types_same(arg1, arg2):
  478. arg1_name, arg1_type = arg1
  479. arg2_name, arg2_type = arg2
  480. except_flag = False
  481. if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
  482. arg1_type = arg1_type.element_type()
  483. arg2_type = arg2_type.element_type()
  484. elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
  485. pass
  486. elif allow_mix:
  487. arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
  488. arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
  489. else:
  490. except_flag = True
  491. if except_flag or arg1_type != arg2_type:
  492. raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
  493. f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
  494. return arg1
  495. reduce(_check_types_same, map(_check_argument_type, args.items()))
  496. @staticmethod
  497. def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
  498. """Checks whether a value is instance of some types."""
  499. valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
  500. def raise_error_msg():
  501. """func for raising error message when check failed"""
  502. type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
  503. num_types = len(valid_types)
  504. msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  505. raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
  506. f'\'{type_names if num_types > 1 else type_names[0]}\', '
  507. f'but got \'{arg_value}\' with type \'{type(arg_value).__name__}\'.')
  508. # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
  509. # `check_value_type('x', True, [bool, int])` will check pass
  510. if isinstance(arg_value, bool) and bool not in tuple(valid_types):
  511. raise_error_msg()
  512. if not isinstance(arg_value, tuple(valid_types)):
  513. raise_error_msg()
  514. return arg_value
  515. @staticmethod
  516. def check_type_name(arg_name, arg_type, valid_types, prim_name):
  517. """Checks whether a type in some specified types"""
  518. valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
  519. def raise_error_msg():
  520. """func for raising error message when check failed"""
  521. type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
  522. num_types = len(valid_types)
  523. msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  524. raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
  525. f"{type_names if num_types > 1 else type_names[0]}, "
  526. f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.")
  527. if isinstance(arg_type, type(mstype.tensor)):
  528. arg_type = arg_type.element_type()
  529. if arg_type not in valid_types:
  530. raise_error_msg()
  531. return arg_type
  532. @staticmethod
  533. def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2):
  534. """Checks whether shape is ori_shape reduced on axis"""
  535. axis_origin = axis
  536. axis = axis if isinstance(axis, Iterable) else (axis,)
  537. exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
  538. if list(shape) != exp_shape:
  539. raise ValueError(f"For '{prim_name}', the '{arg_name1}'.shape reduce on 'axis': {axis_origin} should "
  540. f"be equal to '{arg_name2}'.shape: {shape}, but got {ori_shape}.")
  541. @staticmethod
  542. def check_astype_dtype(dtype):
  543. """Check whether dtype is a valid input, and convert to mstype"""
  544. all_types = mstype.__dtype__ + ["int", "float", "bool"]
  545. if isinstance(dtype, str):
  546. if dtype.lower() not in all_types:
  547. raise TypeError(f"`{dtype}` not understood.")
  548. dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower()))
  549. elif isinstance(dtype, type):
  550. dtype = mstype.pytype_to_dtype(dtype)
  551. elif not dtype in mstype.number_type + (mstype.bool_,):
  552. raise TypeError(f"`{dtype}` not understood.")
  553. return dtype
  554. @staticmethod
  555. def check_transpose_axis(axes, ndim):
  556. """Check the axis argument for tensor.transpose"""
  557. if not axes or (len(axes) == 1 and axes[0] is None):
  558. return tuple(range(ndim-1, -1, -1))
  559. if len(axes) == 1:
  560. perm = axes[0]
  561. # if only one argument provided, it must be tuple or list
  562. if isinstance(perm, list):
  563. perm = tuple(perm)
  564. else:
  565. if not isinstance(perm, tuple):
  566. raise TypeError(f"The `axes` should be a tuple/list, or series of int, but got {type(axes[0])}")
  567. return perm
  568. # if multiple arguments provided, it must be `ndim` number of ints
  569. if len(axes) != ndim:
  570. raise ValueError("The number of axes must be equal to the dimension of tensor.")
  571. return axes
  572. @staticmethod
  573. def check_reshape_shp(shp):
  574. """Check the shape argument for tensor.reshape"""
  575. if len(shp) == 1:
  576. new_shape = shp[0]
  577. # if only one argument provided, it must be int, tuple or list
  578. if isinstance(new_shape, int):
  579. return shp
  580. if isinstance(new_shape, list):
  581. new_shape = tuple(new_shape)
  582. else:
  583. if not isinstance(new_shape, tuple):
  584. raise TypeError(
  585. f"The `shape` should be an int, or tuple/list, or series of int, but got {type(shp[0])}")
  586. return new_shape
  587. return shp
  588. @staticmethod
  589. def check_flatten_order(order):
  590. """Check flatten function input order"""
  591. if not isinstance(order, str):
  592. raise TypeError(f"The order variable should be a string, but got {type(order)}")
  593. if order not in ('C', 'F'):
  594. raise ValueError(f"only `C` and `F` are supported as order, but got {order}")
  595. return order
  596. @staticmethod
  597. def check_swapaxes_axis(axes, ndim):
  598. """Check all the axes argument for tensor.swapaxes"""
  599. if isinstance(axes, int):
  600. Validator.check_axis_in_range(axes, ndim)
  601. return axes % ndim
  602. if isinstance(axes, (tuple, list)):
  603. for axis in axes:
  604. if not isinstance(axis, int):
  605. raise TypeError(f"The axis argument should be integer, but got {type(axis)}.")
  606. Validator.check_axis_in_range(axis, ndim)
  607. axes = tuple(map(lambda x: x % ndim, axes))
  608. return axes
  609. raise TypeError(f"The axes should be integer, list or tuple for check, but got {type(axes)}.")
  610. @staticmethod
  611. def prepare_shape_for_squeeze(shape, axes):
  612. """
  613. Creates the squeezed new shape based on the tensor and given axes.
  614. Args:
  615. shape (tuple): the shape of the tensor
  616. axes Union[int, tuple(int), list(int)]: the axes with dimensions need to
  617. be squeezed.
  618. Returns:
  619. new_shape(tuple): the shape with dimensions squeezed.
  620. """
  621. new_shape = []
  622. ndim = len(shape)
  623. # Convert to set
  624. if isinstance(axes, int):
  625. if axes >= ndim or axes < -ndim:
  626. raise ValueError(f"The axis {axes} is out of bounds for tensor of dimension {ndim}")
  627. axes = {axes}
  628. elif isinstance(axes, (list, tuple)):
  629. for axis in axes:
  630. if axis >= ndim or axis < -ndim:
  631. raise ValueError(f"The axis {axis} is out of bounds for tensor of dimension {ndim}")
  632. axes = set(axes)
  633. else:
  634. raise TypeError(f"Only int, tuple and list are allowed for axes, but got {type(axes)}")
  635. for idx, s in enumerate(shape):
  636. if s != 1 or (idx not in axes) and (idx - ndim not in axes):
  637. new_shape.append(s)
  638. # if an axis is selected with shape entry greater than one, an error is raised.
  639. if s != 1 and ((idx in axes) or (idx - ndim in axes)):
  640. raise ValueError(f"The axis {axes} has shape entry {s} > 1, cannot be squeezed.")
  641. return tuple(new_shape)
  642. @staticmethod
  643. def check_axis_in_range(axis, ndim):
  644. """Checks axes are with the bounds of ndim"""
  645. if not isinstance(axis, int):
  646. raise TypeError(f'The axes should be integers, not {type(axis)}')
  647. if not -ndim <= axis < ndim:
  648. raise ValueError(f'The axis {axis} is out of bounds for array of dimension {ndim}')
  649. return axis % ndim
  650. @staticmethod
  651. def check_axis_valid(axes, ndim):
  652. """
  653. Checks axes are valid given ndim, and returns axes that can be passed
  654. to the built-in operator (non-negative, int or tuple)
  655. """
  656. if axes is None:
  657. axes = tuple(range(ndim))
  658. return axes
  659. if isinstance(axes, (tuple, list)):
  660. for axis in axes:
  661. Validator.check_axis_in_range(axis, ndim)
  662. axes = tuple(map(lambda x: x % ndim, axes))
  663. if any(axes.count(el) > 1 for el in axes):
  664. raise ValueError('duplicate value in "axis"')
  665. return axes
  666. Validator.check_axis_in_range(axes, ndim)
  667. return (axes % ndim,)
  668. @staticmethod
  669. def max_(*args):
  670. return max(*args)
  671. @staticmethod
  672. def min_(*args):
  673. return min(*args)
  674. @staticmethod
  675. def expanded_shape(ndim, axis_size, axis):
  676. """
  677. Returns a shape with size = 1 for all dimensions
  678. except at axis.
  679. """
  680. return tuple(axis_size if i == axis else 1 for i in range(ndim))
  681. @staticmethod
  682. def tuple_slice(tup, start, end):
  683. """get sliced tuple from start and end."""
  684. return tup[start:end]
  685. @staticmethod
  686. def infer_out_shape(*shapes):
  687. """
  688. Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
  689. """
  690. shape_out = deque()
  691. reversed_shapes = map(reversed, shapes)
  692. for items in zip_longest(*reversed_shapes, fillvalue=1):
  693. max_size = 0 if 0 in items else max(items)
  694. if any(item not in (1, max_size) for item in items):
  695. raise ValueError(f'The operands could not be broadcast together with shapes {*shapes,}')
  696. shape_out.appendleft(max_size)
  697. return tuple(shape_out)
  698. @staticmethod
  699. def get_log2_size(size):
  700. return math.ceil(math.log2(size))
  701. @staticmethod
  702. def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
  703. """Check axis argument type."""
  704. if type_int and isinstance(axis, int):
  705. return True
  706. if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
  707. for ax in axis:
  708. if not isinstance(ax, int):
  709. raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axis}.")
  710. return True
  711. type_str = ""
  712. if type_int:
  713. type_str += "int, "
  714. if type_tuple:
  715. type_str += "tuple, "
  716. if type_list:
  717. type_str += "list, "
  718. raise TypeError(f"The axis should be {type_str}but got {type(axis)}.")
  719. @staticmethod
  720. def check_and_canonicalize_axes(axes, ndim):
  721. """Check whether the types and values of input axes are valid."""
  722. axes = axes if isinstance(axes, tuple) else (axes,)
  723. new_axes = ()
  724. for ax in axes:
  725. if not isinstance(ax, int):
  726. raise TypeError((f"Each axis should be integer, but got {type(ax)} in {axes}."))
  727. if not -ndim <= ax < ndim:
  728. raise ValueError(f'The axis {ax} is out of bounds for array of dimension {ndim}')
  729. ax = ax if ax >= 0 else ax + ndim
  730. new_axes += (ax,)
  731. if any(new_axes.count(el) > 1 for el in new_axes):
  732. raise ValueError('duplicate value in "axis"')
  733. return new_axes
  734. @staticmethod
  735. def empty_compile(dtype, shape):
  736. """Returns an empty Tensor."""
  737. return Tensor_(dtype, shape)
  738. @staticmethod
  739. def check_type_support(dtype, device, supported_dtypes):
  740. return dtype in supported_dtypes or not context.get_context('device_target') == device
  741. def check_input_format(input_param):
  742. """Judge input format."""
  743. if input_param == "NCHW":
  744. return input_param
  745. raise ValueError("The data format must be NCHW.")
  746. def _expand_tuple(n_dimensions):
  747. """To expand a int number to tuple."""
  748. def convert(m):
  749. if not isinstance(m, tuple):
  750. if isinstance(m, int) and not isinstance(m, bool):
  751. return tuple(repeat(m, n_dimensions))
  752. raise TypeError("Input type must be int or tuple[int].")
  753. if not len(m) is n_dimensions:
  754. raise TypeError("Input tuple dimension is incorrect.")
  755. for i in m:
  756. if not isinstance(i, int) or isinstance(i, bool):
  757. raise TypeError("Incorrect type inside of a tuple, must be int!")
  758. return m
  759. return convert
  760. def _check_data_type_valid(data, valid_type):
  761. """Check data type valid."""
  762. if valid_type is None:
  763. return data is None
  764. if isinstance(data, valid_type):
  765. if hasattr(data, 'size') and data.size == 0:
  766. msg = "Please provide non-empty data."
  767. logger.critical(msg)
  768. raise ValueError(msg)
  769. return True
  770. return False
  771. def check_input_data(*data, data_class):
  772. """Input data check."""
  773. for item in data:
  774. if isinstance(item, (list, tuple)):
  775. for v in item:
  776. check_input_data(v, data_class=data_class)
  777. elif isinstance(item, dict):
  778. for v in item.values():
  779. check_input_data(v, data_class=data_class)
  780. else:
  781. if isinstance(data_class, (tuple, list)):
  782. ret = True in tuple(_check_data_type_valid(item, data_type) for data_type in data_class)
  783. else:
  784. ret = _check_data_type_valid(item, data_class)
  785. if not ret:
  786. data_class_str = tuple(i.__name__ if hasattr(i, '__name__') else i for i in data_class) \
  787. if isinstance(data_class, (tuple, list)) else \
  788. (data_class if data_class is None else data_class.__name__)
  789. raise ValueError(f'Please provide as model inputs either a single or '
  790. f'a tuple or a list or a dict of {data_class_str}, '
  791. f'but got part data type is {item if item is None else type(item).__name__}.')
  792. def check_output_data(data):
  793. """Output data check."""
  794. if data is None:
  795. raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
  796. once = _expand_tuple(1)
  797. twice = _expand_tuple(2)
  798. triple = _expand_tuple(3)
  799. def args_type_check(*type_args, **type_kwargs):
  800. """Check whether input data type is correct."""
  801. def type_check(func):
  802. sig = inspect.signature(func)
  803. bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
  804. @wraps(func)
  805. def wrapper(*args, **kwargs):
  806. nonlocal bound_types
  807. bound_values = sig.bind(*args, **kwargs)
  808. argument_dict = bound_values.arguments
  809. if "kwargs" in bound_types:
  810. bound_types = bound_types["kwargs"]
  811. if "kwargs" in argument_dict:
  812. argument_dict = argument_dict["kwargs"]
  813. for name, value in argument_dict.items():
  814. if name in bound_types:
  815. if value is not None and not isinstance(value, bound_types[name]):
  816. raise TypeError('The argument {} must be {}'.format(name, bound_types[name]))
  817. return func(*args, **kwargs)
  818. return wrapper
  819. return type_check
  820. _set_record = {}
  821. def args_unreset_check(*unreset_args, **unreset_kwargs):
  822. """Check the entered non repeatable setting properties."""
  823. def unreset_check(func):
  824. sig = inspect.signature(func)
  825. bound_unreset = sig.bind_partial(*unreset_args, **unreset_kwargs).arguments
  826. @wraps(func)
  827. def wrapper(*args, **kwargs):
  828. nonlocal bound_unreset
  829. bound_values = sig.bind(*args, **kwargs)
  830. argument_dict = bound_values.arguments
  831. if "kwargs" in bound_unreset:
  832. bound_unreset = bound_unreset["kwargs"]
  833. if "kwargs" in argument_dict:
  834. argument_dict = argument_dict["kwargs"]
  835. for name, value in argument_dict.items():
  836. if name in _set_record.keys():
  837. raise TypeError('Argument {} is non-renewable parameter {}.'.format(name, bound_unreset[name]))
  838. if name in bound_unreset:
  839. _set_record[name] = value
  840. return func(*args, **kwargs)
  841. return wrapper
  842. return unreset_check