You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_checkparam.py 28 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Check parameters."""
  16. import re
  17. import inspect
  18. import math
  19. from enum import Enum
  20. from functools import reduce, wraps
  21. from itertools import repeat
  22. from collections.abc import Iterable
  23. import numpy as np
  24. from mindspore import log as logger
  25. from mindspore.common import dtype as mstype
  26. class Rel(Enum):
  27. """Numerical relationship between variables, logical relationship enumeration definition of range."""
  28. # scalar compare
  29. EQ = 1 # ==
  30. NE = 2 # !=
  31. LT = 3 # <
  32. LE = 4 # <=
  33. GT = 5 # >
  34. GE = 6 # >=
  35. # scalar range check
  36. INC_NEITHER = 7 # (), include neither
  37. INC_LEFT = 8 # [), include left
  38. INC_RIGHT = 9 # (], include right
  39. INC_BOTH = 10 # [], include both
  40. # collection in, not in
  41. IN = 11
  42. NOT_IN = 12
  43. @staticmethod
  44. def get_strs(rel):
  45. """Get value from rel_strs."""
  46. return rel_strs.get(rel, "")
  47. @staticmethod
  48. def get_fns(rel):
  49. """Get value from rel_fns."""
  50. return rel_fns.get(rel, lambda *args: False)
  51. rel_fns = {
  52. # scalar compare
  53. Rel.EQ: lambda x, y: x == y,
  54. Rel.NE: lambda x, y: x != y,
  55. Rel.LT: lambda x, y: x < y,
  56. Rel.LE: lambda x, y: x <= y,
  57. Rel.GT: lambda x, y: x > y,
  58. Rel.GE: lambda x, y: x >= y,
  59. # scalar range check
  60. Rel.INC_NEITHER: lambda x, lower, upper: (lower < x < upper),
  61. Rel.INC_LEFT: lambda x, lower, upper: (lower <= x < upper),
  62. Rel.INC_RIGHT: lambda x, lower, upper: (lower < x <= upper),
  63. Rel.INC_BOTH: lambda x, lower, upper: (lower <= x <= upper),
  64. # collection in, not in
  65. Rel.IN: lambda x, y: x in y,
  66. Rel.NOT_IN: lambda x, y: x not in y,
  67. }
  68. rel_strs = {
  69. # scalar compare
  70. Rel.EQ: "equal to {}",
  71. Rel.NE: "not equal to {}",
  72. Rel.LT: "less than {}",
  73. Rel.LE: "less or equal to {}",
  74. Rel.GT: "greater than {}",
  75. Rel.GE: "greater or equal to {}",
  76. # scalar range check
  77. Rel.INC_NEITHER: "({}, {})",
  78. Rel.INC_LEFT: "[{}, {})",
  79. Rel.INC_RIGHT: "({}, {}]",
  80. Rel.INC_BOTH: "[{}, {}]",
  81. # collection in, not in
  82. Rel.IN: "in {}",
  83. Rel.NOT_IN: "not in {}",
  84. }
  85. def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
  86. """
  87. Check argument integer.
  88. Usage:
  89. - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
  90. """
  91. rel_fn = Rel.get_fns(rel)
  92. type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
  93. type_except = TypeError if type_mismatch else ValueError
  94. if type_mismatch or not rel_fn(arg_value, value):
  95. rel_str = Rel.get_strs(rel).format(value)
  96. arg_name = arg_name if arg_name else "parameter"
  97. msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
  98. raise type_except(f'{msg_prefix} `{arg_name}` should be an {arg_type} and must {rel_str}, but got `{arg_value}`'
  99. f' with type `{type(arg_value).__name__}`.')
  100. return arg_value
  101. def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
  102. """
  103. Checks input value is float type or not.
  104. Usage:
  105. - number = check_is_number(number, int)
  106. - number = check_is_number(number, int, "bias")
  107. - number = check_is_number(number, int, "bias", "bias_class")
  108. """
  109. prim_name = f'in \'{prim_name}\'' if prim_name else ''
  110. arg_name = f'\'{prim_name}\'' if arg_name else 'Input value'
  111. if isinstance(arg_value, arg_type):
  112. if math.isinf(arg_value) or math.isnan(arg_value):
  113. raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
  114. return arg_value
  115. raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`')
  116. class Validator:
  117. """validator for checking input parameters"""
  118. @staticmethod
  119. def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
  120. """
  121. Method for judging relation between two int values or list/tuple made up of ints.
  122. This method is not suitable for judging relation between floats, since it does not consider float error.
  123. """
  124. rel_fn = Rel.get_fns(rel)
  125. if not rel_fn(arg_value, value):
  126. rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
  127. msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
  128. raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
  129. return arg_value
  130. @staticmethod
  131. def check_integer(arg_name, arg_value, value, rel, prim_name=None):
  132. """Check argument is integer"""
  133. rel_fn = Rel.get_fns(rel)
  134. type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
  135. excp_cls = TypeError if type_mismatch else ValueError
  136. if type_mismatch or not rel_fn(arg_value, value):
  137. rel_str = Rel.get_strs(rel).format(value)
  138. msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
  139. raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
  140. f' with type `{type(arg_value).__name__}`.')
  141. return arg_value
  142. @staticmethod
  143. def check_is_int(arg_value, arg_name=None, prim_name=None):
  144. """
  145. Checks input value is float type or not.
  146. Usage:
  147. - number = check_is_int(number, int)
  148. - number = check_is_int(number, int, "bias")
  149. - number = check_is_int(number, int, "bias", "bias_class")
  150. """
  151. check_is_number(arg_value, int, arg_name, prim_name)
  152. @staticmethod
  153. def check_positive_int(arg_value, arg_name=None, prim_name=None):
  154. """
  155. Check argument is positive integer, which mean arg_value > 0.
  156. Usage:
  157. - number = check_positive_int(number)
  158. - number = check_positive_int(number, "bias")
  159. """
  160. return check_number(arg_value, 0, Rel.GT, int, arg_name, prim_name)
  161. @staticmethod
  162. def check_negative_int(arg_value, arg_name=None, prim_name=None):
  163. """
  164. Check argument is negative integer, which mean arg_value < 0.
  165. Usage:
  166. - number = check_negative_int(number)
  167. - number = check_negative_int(number, "bias")
  168. """
  169. return check_number(arg_value, 0, Rel.LT, int, arg_name, prim_name)
  170. @staticmethod
  171. def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
  172. """
  173. Check argument is non-negative integer, which mean arg_value <= 0.
  174. Usage:
  175. - number = check_non_positive_int(number)
  176. - number = check_non_positive_int(number, "bias")
  177. """
  178. return check_number(arg_value, 0, Rel.LE, int, arg_name, prim_name)
  179. @staticmethod
  180. def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
  181. """
  182. Check argument is non-negative integer, which mean arg_value >= 0.
  183. Usage:
  184. - number = check_non_negative_int(number)
  185. - number = check_non_negative_int(number, "bias")
  186. """
  187. return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name)
  188. @staticmethod
  189. def check_is_float(arg_value, arg_name=None, prim_name=None):
  190. """
  191. Checks input value is float type or not.
  192. Usage:
  193. - number = check_is_float(number, int)
  194. - number = check_is_float(number, int, "bias")
  195. - number = check_is_float(number, int, "bias", "bias_class")
  196. """
  197. check_is_number(arg_value, float, arg_name, prim_name)
  198. @staticmethod
  199. def check_positive_float(arg_value, arg_name=None, prim_name=None):
  200. """
  201. Check argument is positive float, which mean arg_value > 0.
  202. Usage:
  203. - number = check_positive_float(number)
  204. - number = check_positive_float(number, "bias")
  205. - number = check_positive_float(number, "bias", "bias_class")
  206. """
  207. return check_number(arg_value, 0, Rel.GT, float, arg_name, prim_name)
  208. @staticmethod
  209. def check_negative_float(arg_value, arg_name=None, prim_name=None):
  210. """
  211. Check argument is negative float, which mean arg_value < 0.
  212. Usage:
  213. - number = check_negative_float(number)
  214. - number = check_negative_float(number, "bias")
  215. """
  216. return check_number(arg_value, 0, Rel.LT, float, arg_name, prim_name)
  217. @staticmethod
  218. def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
  219. """
  220. Check argument is non-negative float, which mean arg_value <= 0.
  221. Usage:
  222. - number = check_non_positive_float(number)
  223. - number = check_non_positive_float(number, "bias")
  224. """
  225. return check_number(arg_value, 0, Rel.LE, float, arg_name, prim_name)
  226. @staticmethod
  227. def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
  228. """
  229. Check argument is non-negative float, which mean arg_value >= 0.
  230. Usage:
  231. - number = check_non_negative_float(number)
  232. - number = check_non_negative_float(number, "bias")
  233. """
  234. return check_number(arg_value, 0, Rel.GE, float, arg_name, prim_name)
  235. @staticmethod
  236. def check_number(arg_name, arg_value, value, rel, prim_name):
  237. """Number value judgment."""
  238. rel_fn = Rel.get_fns(rel)
  239. if not rel_fn(arg_value, value):
  240. rel_str = Rel.get_strs(rel).format(value)
  241. raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
  242. return arg_value
  243. @staticmethod
  244. def check_isinstance(arg_name, arg_value, classes):
  245. """Check arg isinstance of classes"""
  246. if not isinstance(arg_value, classes):
  247. raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
  248. return arg_value
  249. @staticmethod
  250. def check_bool(arg_value, arg_name=None):
  251. """
  252. Check argument is instance of bool.
  253. Usage:
  254. - has_bias = check_bool(has_bias)
  255. - has_bias = check_bool(has_bias, "has_bias")
  256. """
  257. if not isinstance(arg_value, bool):
  258. arg_name = arg_name if arg_name else "Parameter"
  259. raise TypeError(f'`{arg_name}` should be isinstance of bool, but got `{arg_value}`.')
  260. return arg_value
  261. @staticmethod
  262. def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
  263. """Method for checking whether an int value is in some range."""
  264. rel_fn = Rel.get_fns(rel)
  265. type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
  266. excp_cls = TypeError if type_mismatch else ValueError
  267. if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
  268. rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
  269. raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
  270. f' but got `{arg_value}` with type `{type(arg_value).__name__}`.')
  271. return arg_value
  272. @staticmethod
  273. def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
  274. """Method for checking whether a numeric value is in some range."""
  275. rel_fn = Rel.get_fns(rel)
  276. if not rel_fn(arg_value, lower_limit, upper_limit):
  277. rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
  278. raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
  279. return arg_value
  280. @staticmethod
  281. def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
  282. """
  283. Check whether string is in some value list.
  284. Usage:
  285. - method = check_string(method, ["string1", "string2", "string3"], "method")
  286. """
  287. if isinstance(arg_value, str) and arg_value in valid_values:
  288. return arg_value
  289. arg_name = arg_name if arg_name else "Parameter"
  290. msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
  291. raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be in `{valid_values}`,'
  292. f' but got `{arg_value}`.')
  293. @staticmethod
  294. def check_pad_value_by_mode(pad_mode, padding, prim_name):
  295. """Validates value of padding according to pad_mode"""
  296. if pad_mode != 'pad' and padding != 0:
  297. raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
  298. return padding
  299. @staticmethod
  300. def check_subclass(arg_name, type_, template_types, prim_name):
  301. """Checks whether some type is subclass of another type"""
  302. if not isinstance(template_types, Iterable):
  303. template_types = (template_types,)
  304. hit = False
  305. for template_type in template_types:
  306. if isinstance(template_type, mstype.Type):
  307. if mstype.issubclass_(type_, template_type):
  308. hit = True
  309. break
  310. elif type_ is template_type:
  311. hit = True
  312. break
  313. if not hit:
  314. type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
  315. raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
  316. f' of {",".join((str(x) for x in template_types))}, but got {type_str}.')
  317. @staticmethod
  318. def check_const_input(arg_name, arg_value, prim_name):
  319. """Checks valid value."""
  320. if arg_value is None:
  321. raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
  322. return arg_value
  323. @staticmethod
  324. def check_type(arg_name, arg_value, valid_types):
  325. """Type checking."""
  326. def raise_error_msg():
  327. """func for raising error message when check failed"""
  328. raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.')
  329. if isinstance(arg_value, type(mstype.tensor)):
  330. arg_value = arg_value.element_type()
  331. if isinstance(arg_value, bool) and bool not in tuple(valid_types):
  332. raise_error_msg()
  333. if arg_value in valid_types:
  334. return arg_value
  335. if isinstance(arg_value, tuple(valid_types)):
  336. return arg_value
  337. raise_error_msg()
  338. @staticmethod
  339. def check_type_same(args, valid_values, prim_name):
  340. """Checks whether the types of inputs are the same."""
  341. def _check_tensor_type(arg):
  342. arg_key, arg_val = arg
  343. elem_type = arg_val
  344. Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
  345. return (arg_key, elem_type)
  346. def _check_types_same(arg1, arg2):
  347. arg1_name, arg1_type = arg1
  348. arg2_name, arg2_type = arg2
  349. if arg1_type != arg2_type:
  350. raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
  351. f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.')
  352. return arg1
  353. elem_types = map(_check_tensor_type, args.items())
  354. reduce(_check_types_same, elem_types)
  355. @staticmethod
  356. def check_tensor_type_same(args, valid_values, prim_name):
  357. """Checks whether the element types of input tensors are the same."""
  358. tensor_types = [mstype.tensor_type(t) for t in valid_values]
  359. Validator.check_type_same(args, tensor_types, prim_name)
  360. @staticmethod
  361. def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
  362. """
  363. Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
  364. If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
  365. """
  366. def _check_argument_type(arg):
  367. arg_key, arg_val = arg
  368. if isinstance(arg_val, type(mstype.tensor)):
  369. arg_val = arg_val.element_type()
  370. if not arg_val in valid_values:
  371. raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
  372. f' but `{arg_key}` is {arg_val}.')
  373. return arg
  374. def _check_types_same(arg1, arg2):
  375. arg1_name, arg1_type = arg1
  376. arg2_name, arg2_type = arg2
  377. except_flag = False
  378. if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
  379. arg1_type = arg1_type.element_type()
  380. arg2_type = arg2_type.element_type()
  381. elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
  382. pass
  383. elif allow_mix:
  384. arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
  385. arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
  386. else:
  387. except_flag = True
  388. if except_flag or arg1_type != arg2_type:
  389. raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
  390. f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
  391. return arg1
  392. reduce(_check_types_same, map(_check_argument_type, args.items()))
  393. @staticmethod
  394. def check_value_type(arg_name, arg_value, valid_types, prim_name):
  395. """Checks whether a value is instance of some types."""
  396. valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
  397. def raise_error_msg():
  398. """func for raising error message when check failed"""
  399. type_names = [t.__name__ for t in valid_types]
  400. num_types = len(valid_types)
  401. msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
  402. raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
  403. f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
  404. # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
  405. # `check_value_type('x', True, [bool, int])` will check pass
  406. if isinstance(arg_value, bool) and bool not in tuple(valid_types):
  407. raise_error_msg()
  408. if isinstance(arg_value, tuple(valid_types)):
  409. return arg_value
  410. raise_error_msg()
  411. @staticmethod
  412. def check_type_name(arg_name, arg_type, valid_types, prim_name):
  413. """Checks whether a type in some specified types"""
  414. valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
  415. def get_typename(t):
  416. return t.__name__ if hasattr(t, '__name__') else str(t)
  417. if isinstance(arg_type, type(mstype.tensor)):
  418. arg_type = arg_type.element_type()
  419. if arg_type in valid_types:
  420. return arg_type
  421. type_names = [get_typename(t) for t in valid_types]
  422. msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
  423. if len(valid_types) == 1:
  424. raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
  425. f' but got {get_typename(arg_type)}.')
  426. raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
  427. f' but got {get_typename(arg_type)}.')
  428. @staticmethod
  429. def check_reduce_shape(ori_shape, shape, axis, prim_name):
  430. """Checks whether shape is ori_shape reduced on axis"""
  431. axis = axis if isinstance(axis, Iterable) else (axis,)
  432. exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
  433. if list(shape) != exp_shape:
  434. raise ValueError(f'For {prim_name}, {ori_shape} reduce on {axis} should be '
  435. f'{tuple(exp_shape)}, but got {shape}.')
  436. def check_int(input_param):
  437. """Int type judgment."""
  438. if isinstance(input_param, int) and not isinstance(input_param, bool):
  439. return input_param
  440. raise TypeError("Input type must be int!")
  441. def check_int_zero_one(input_param):
  442. """Judge whether it is 0 or 1."""
  443. if input_param in (0, 1):
  444. return input_param
  445. raise ValueError("The data must be 0 or 1.")
  446. def check_input_format(input_param):
  447. """Judge input format."""
  448. if input_param == "NCHW":
  449. return input_param
  450. raise ValueError("The data format must be NCHW.")
  451. def check_padding(padding):
  452. """Check padding."""
  453. if padding >= 0:
  454. return padding
  455. raise ValueError("The padding must be at least 0,"" but got padding {}.".format(padding))
  456. def check_padmode(mode):
  457. """Check padmode."""
  458. if mode in ("same", "valid", "pad"):
  459. return mode
  460. raise ValueError("The pad mode must be same or valid or pad,"" but got mode {}.".format(mode))
  461. def check_tensor_supported_type(dtype):
  462. """Check tensor dtype."""
  463. if dtype in (mstype.int32, mstype.float32):
  464. return dtype
  465. raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype))
  466. def _expand_tuple(n_dimensions):
  467. """To expand a number to tuple."""
  468. def convert(m):
  469. if not isinstance(m, tuple):
  470. if isinstance(m, int):
  471. return tuple(repeat(m, n_dimensions))
  472. raise TypeError("Input type must be int or tuple.")
  473. if not len(m) is n_dimensions:
  474. raise TypeError("Input dimension is incorrect.")
  475. for i in m:
  476. if not isinstance(i, int):
  477. raise TypeError("Incorrect type inside of a tuple!")
  478. return m
  479. return convert
  480. def check_input_data(*data, data_class):
  481. """Input data check."""
  482. for item in data:
  483. if isinstance(item, (list, tuple)):
  484. for v in item:
  485. check_input_data(v, data_class=data_class)
  486. else:
  487. if not isinstance(item, data_class):
  488. raise ValueError(f'Please provide as model inputs'
  489. f' either a single'
  490. f' or a list of {data_class.__name__},'
  491. f' but got part data type is {str(type(item))}.')
  492. if item.size() == 0:
  493. msg = "Please provide non-empty data."
  494. logger.error(msg)
  495. raise ValueError(msg)
  496. def check_output_data(data):
  497. """Output data check."""
  498. if data is None:
  499. raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
  500. once = _expand_tuple(1)
  501. twice = _expand_tuple(2)
  502. triple = _expand_tuple(3)
  503. valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64,
  504. np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
  505. np.float32, np.float64, bool, np.bool_)
  506. def check_type(arg_name, arg_value, valid_types):
  507. """Check value type."""
  508. # if input type is Tensor ,get element type
  509. if isinstance(arg_value, type(mstype.tensor)):
  510. arg_value = arg_value.element_type()
  511. # First, check if arg_value has argvalid_types
  512. if isinstance(arg_value, tuple(valid_types)):
  513. return type(arg_value).__name__
  514. # Second, wrap arg_value with numpy array so that it can be checked through numpy api
  515. if isinstance(arg_value, (list, tuple)):
  516. arg_value = np.array(arg_value)
  517. # Thirdly, check the data type by numpy's dtype api
  518. valid = False
  519. if isinstance(arg_value, np.ndarray):
  520. valid = arg_value.dtype in valid_data_types
  521. # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
  522. # `check_type('x', True, [bool, int])` will check pass
  523. if isinstance(arg_value, bool) and bool not in tuple(valid_types):
  524. valid = False
  525. if not valid:
  526. type_names = [t.__name__ for t in valid_types]
  527. if len(valid_types) == 1:
  528. raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
  529. f' but got {type(arg_value).__name__}.')
  530. raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
  531. f' but got {type(arg_value).__name__}.')
  532. return type(arg_value).__name__
  533. def check_typename(arg_name, arg_type, valid_types):
  534. """Check type name."""
  535. def get_typename(t):
  536. return t.__name__ if hasattr(t, '__name__') else str(t)
  537. if isinstance(arg_type, type(mstype.tensor)):
  538. arg_type = arg_type.element_type()
  539. if arg_type in valid_types:
  540. return arg_type
  541. if isinstance(arg_type, tuple(valid_types)):
  542. return arg_type
  543. type_names = [get_typename(t) for t in valid_types]
  544. if len(valid_types) == 1:
  545. raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
  546. f' but got {get_typename(arg_type)}.')
  547. raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
  548. f' but got {get_typename(arg_type)}.')
  549. def check_shape(arg_name, arg_value):
  550. """Check shape."""
  551. # First, check if shape is a tuple
  552. if not isinstance(arg_value, tuple):
  553. raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},'
  554. f' but got {type(arg_value).__name__}.')
  555. # Second, wrap arg_value with numpy array so that it can be checked through numpy api
  556. arg_value = np.array(arg_value)
  557. # shape can not be ()
  558. if arg_value.size == 0:
  559. raise ValueError('Shape can not be empty.')
  560. # shape's dimension should be 1
  561. if arg_value.ndim != 1:
  562. raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(arg_value.ndim))
  563. # Thirdly, check each element's type of the shape
  564. valid_types = (int, np.int8, np.int16, np.int32, np.int64,
  565. np.uint8, np.uint16, np.uint32, np.uint64)
  566. for dim_size in arg_value:
  567. if not isinstance(dim_size, valid_types) or dim_size <= 0:
  568. raise ValueError('Every dimension size of the tensor shape should be a positive integer,'
  569. ' but got {}.'.format(dim_size))
  570. def _check_str_by_regular(target, reg=None, flag=re.ASCII):
  571. if reg is None:
  572. # Named string regular expression
  573. reg = r"^\w+[0-9a-zA-Z\_\.]*$"
  574. if re.match(reg, target, flag) is None:
  575. raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
  576. return True
  577. def args_type_check(*type_args, **type_kwargs):
  578. """Check whether input data type is correct."""
  579. def type_check(func):
  580. sig = inspect.signature(func)
  581. bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
  582. @wraps(func)
  583. def wrapper(*args, **kwargs):
  584. nonlocal bound_types
  585. bound_values = sig.bind(*args, **kwargs)
  586. argument_dict = bound_values.arguments
  587. if "kwargs" in bound_types:
  588. bound_types = bound_types["kwargs"]
  589. if "kwargs" in argument_dict:
  590. argument_dict = argument_dict["kwargs"]
  591. for name, value in argument_dict.items():
  592. if name in bound_types:
  593. if value is not None and not isinstance(value, bound_types[name]):
  594. raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
  595. return func(*args, **kwargs)
  596. return wrapper
  597. return type_check