You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_checkparam.py 27 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Check parameters."""
import re
from collections.abc import Iterable
from enum import Enum
from functools import reduce
from itertools import repeat

import numpy as np

from mindspore import log as logger
from .common import dtype as mstype
  24. # Named string regular expression
  25. _name_re = r"^\w+[0-9a-zA-Z\_\.]*$"
  26. class Rel(Enum):
  27. """Numerical relationship between variables, logical relationship enumeration definition of range."""
  28. # scalar compare
  29. EQ = 1 # ==
  30. NE = 2 # !=
  31. LT = 3 # <
  32. LE = 4 # <=
  33. GT = 5 # >
  34. GE = 6 # >=
  35. # scalar range check
  36. INC_NEITHER = 7 # (), include neither
  37. INC_LEFT = 8 # [), include left
  38. INC_RIGHT = 9 # (], include right
  39. INC_BOTH = 10 # [], include both
  40. # collection in, not in
  41. IN = 11
  42. NOT_IN = 12
  43. @staticmethod
  44. def get_strs(rel):
  45. """Get value from rel_strs."""
  46. return rel_strs.get(rel, "")
  47. @staticmethod
  48. def get_fns(rel):
  49. """Get value from rel_fns."""
  50. return rel_fns.get(rel, lambda *args: False)
  51. rel_fns = {
  52. # scalar compare
  53. Rel.EQ: lambda x, y: x == y,
  54. Rel.NE: lambda x, y: x != y,
  55. Rel.LT: lambda x, y: x < y,
  56. Rel.LE: lambda x, y: x <= y,
  57. Rel.GT: lambda x, y: x > y,
  58. Rel.GE: lambda x, y: x >= y,
  59. # scalar range check
  60. Rel.INC_NEITHER: lambda x, lower, upper: (lower < x < upper),
  61. Rel.INC_LEFT: lambda x, lower, upper: (lower <= x < upper),
  62. Rel.INC_RIGHT: lambda x, lower, upper: (lower < x <= upper),
  63. Rel.INC_BOTH: lambda x, lower, upper: (lower <= x <= upper),
  64. # collection in, not in
  65. Rel.IN: lambda x, y: x in y,
  66. Rel.NOT_IN: lambda x, y: x not in y,
  67. }
  68. rel_strs = {
  69. # scalar compare
  70. Rel.EQ: "equal to {}",
  71. Rel.NE: "not equal to {}",
  72. Rel.LT: "less than {}",
  73. Rel.LE: "less or equal to {}",
  74. Rel.GT: "greater than {}",
  75. Rel.GE: "greater or equal to {}",
  76. # scalar range check
  77. Rel.INC_NEITHER: "({}, {})",
  78. Rel.INC_LEFT: "[{}, {})",
  79. Rel.INC_RIGHT: "({}, {}]",
  80. Rel.INC_BOTH: "[{}, {}]",
  81. # collection in, not in
  82. Rel.IN: "in {}",
  83. Rel.NOT_IN: "not in {}",
  84. }
  85. class Validator:
  86. """validator for checking input parameters"""
  87. @staticmethod
  88. def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
  89. """
  90. Method for judging relation between two int values or list/tuple made up of ints.
  91. This method is not suitable for judging relation between floats, since it does not consider float error.
  92. """
  93. rel_fn = Rel.get_fns(rel)
  94. if not rel_fn(arg_value, value):
  95. rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
  96. msg_prefix = f'For {prim_name} the' if prim_name else "The"
  97. raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
  98. @staticmethod
  99. def check_integer(arg_name, arg_value, value, rel, prim_name):
  100. """Integer value judgment."""
  101. rel_fn = Rel.get_fns(rel)
  102. type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
  103. if type_mismatch or not rel_fn(arg_value, value):
  104. rel_str = Rel.get_strs(rel).format(value)
  105. raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
  106. f' but got {arg_value}.')
  107. return arg_value
  108. @staticmethod
  109. def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
  110. """Method for checking whether an int value is in some range."""
  111. rel_fn = Rel.get_fns(rel)
  112. type_mismatch = not isinstance(arg_value, int)
  113. if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
  114. rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
  115. raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
  116. f' but got {arg_value}.')
  117. return arg_value
  118. @staticmethod
  119. def check_subclass(arg_name, type_, template_type, prim_name):
  120. """Check whether some type is sublcass of another type"""
  121. if not isinstance(template_type, Iterable):
  122. template_type = (template_type,)
  123. if not any([mstype.issubclass_(type_, x) for x in template_type]):
  124. type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
  125. raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
  126. f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
  127. @staticmethod
  128. def check_tensor_type_same(args, valid_values, prim_name):
  129. """check whether the element types of input tensors are the same."""
  130. def _check_tensor_type(arg):
  131. arg_key, arg_val = arg
  132. Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
  133. elem_type = arg_val.element_type()
  134. if not elem_type in valid_values:
  135. raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
  136. f' but `{arg_key}` is {elem_type}.')
  137. return (arg_key, elem_type)
  138. def _check_types_same(arg1, arg2):
  139. arg1_name, arg1_type = arg1
  140. arg2_name, arg2_type = arg2
  141. if arg1_type != arg2_type:
  142. raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,'
  143. f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
  144. return arg1
  145. elem_types = map(_check_tensor_type, args.items())
  146. reduce(_check_types_same, elem_types)
  147. @staticmethod
  148. def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
  149. """check whether the types of inputs are the same. if the input args are tensors, check their element types"""
  150. def _check_argument_type(arg):
  151. arg_key, arg_val = arg
  152. if isinstance(arg_val, type(mstype.tensor)):
  153. arg_val = arg_val.element_type()
  154. if not arg_val in valid_values:
  155. raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
  156. f' but `{arg_key}` is {arg_val}.')
  157. return arg
  158. def _check_types_same(arg1, arg2):
  159. arg1_name, arg1_type = arg1
  160. arg2_name, arg2_type = arg2
  161. excp_flag = False
  162. if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
  163. arg1_type = arg1_type.element_type()
  164. arg2_type = arg2_type.element_type()
  165. elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
  166. pass
  167. else:
  168. excp_flag = True
  169. if excp_flag or arg1_type != arg2_type:
  170. raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
  171. f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
  172. return arg1
  173. reduce(_check_types_same, map(_check_argument_type, args.items()))
  174. @staticmethod
  175. def check_value_type(arg_name, arg_value, valid_types, prim_name):
  176. """Check whether a values is instance of some types."""
  177. def raise_error_msg():
  178. """func for raising error message when check failed"""
  179. type_names = [t.__name__ for t in valid_types]
  180. num_types = len(valid_types)
  181. raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
  182. f'{"one of " if num_types > 1 else ""}'
  183. f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
  184. # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
  185. # `check_value_type('x', True, [bool, int])` will check pass
  186. if isinstance(arg_value, bool) and bool not in tuple(valid_types):
  187. raise_error_msg()
  188. if isinstance(arg_value, tuple(valid_types)):
  189. return arg_value
  190. raise_error_msg()
  191. class ParamValidator:
  192. """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
  193. @staticmethod
  194. def equal(arg_name, arg_value, cond_str, cond):
  195. """Judging valid value."""
  196. if not cond:
  197. raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
  198. @staticmethod
  199. def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
  200. """This method is only used for check int values, since when compare float values,
  201. we need consider float error."""
  202. rel_fn = Rel.get_fns(rel)
  203. if not rel_fn(arg_value, value):
  204. rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
  205. raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
  206. @staticmethod
  207. def check_integer(arg_name, arg_value, value, rel):
  208. """Integer value judgment."""
  209. rel_fn = Rel.get_fns(rel)
  210. type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
  211. if type_mismatch or not rel_fn(arg_value, value):
  212. rel_str = Rel.get_strs(rel).format(value)
  213. raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
  214. return arg_value
  215. @staticmethod
  216. def check_shape_length(arg_name, arg_value, value, rel):
  217. """Shape length judgment."""
  218. rel_fn = Rel.get_fns(rel)
  219. type_mismatch = not isinstance(arg_value, int)
  220. if type_mismatch or not rel_fn(arg_value, value):
  221. rel_str = Rel.get_strs(rel).format(value)
  222. raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
  223. return arg_value
  224. @staticmethod
  225. def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
  226. """This method is only used for check int values,
  227. since when compare float values, we need consider float error."""
  228. rel_fn = Rel.get_fns(rel)
  229. type_mismatch = not isinstance(arg_value, int)
  230. if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
  231. rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
  232. raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
  233. return arg_value
  234. @staticmethod
  235. def check_isinstance(arg_name, arg_value, classes):
  236. """Check arg isintance of classes"""
  237. if not isinstance(arg_value, classes):
  238. raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.')
  239. return arg_value
  240. @staticmethod
  241. def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
  242. """Is it necessary to consider error when comparing float values."""
  243. rel_fn = Rel.get_fns(rel)
  244. if not rel_fn(arg_value, lower_limit, upper_limit):
  245. rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
  246. raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
  247. return arg_value
  248. @staticmethod
  249. def check_subclass(arg_name, type_, template_type, with_type_of=True):
  250. """Check whether some type is sublcass of another type"""
  251. if not isinstance(template_type, Iterable):
  252. template_type = (template_type,)
  253. if not any([mstype.issubclass_(type_, x) for x in template_type]):
  254. type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
  255. raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
  256. f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
  257. @staticmethod
  258. def check_args_tensor(args):
  259. """Check whether args are all tensor."""
  260. if not isinstance(args, dict):
  261. raise TypeError("The args should be a dict.")
  262. for arg, value in args.items():
  263. ParamValidator.check_subclass(arg, value, mstype.tensor)
  264. @staticmethod
  265. def check_bool(arg_name, arg_value):
  266. """Check arg isintance of bool"""
  267. if not isinstance(arg_value, bool):
  268. raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.')
  269. return arg_value
  270. @staticmethod
  271. def check_type(arg_name, arg_value, valid_types):
  272. """Type checking."""
  273. def raise_error_msg():
  274. """func for raising error message when check failed"""
  275. type_names = [t.__name__ for t in valid_types]
  276. num_types = len(valid_types)
  277. raise ValueError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
  278. f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
  279. if isinstance(arg_value, type(mstype.tensor)):
  280. arg_value = arg_value.element_type()
  281. # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
  282. # `check_type('x', True, [bool, int])` will check pass
  283. if isinstance(arg_value, bool) and bool not in tuple(valid_types):
  284. raise_error_msg()
  285. if isinstance(arg_value, tuple(valid_types)):
  286. return arg_value
  287. raise_error_msg()
  288. @staticmethod
  289. def check_typename(arg_name, arg_type, valid_types):
  290. """Does it contain the _name_ attribute."""
  291. def get_typename(t):
  292. return t.__name__ if hasattr(t, '__name__') else str(t)
  293. if isinstance(arg_type, type(mstype.tensor)):
  294. arg_type = arg_type.element_type()
  295. if arg_type in valid_types:
  296. return arg_type
  297. type_names = [get_typename(t) for t in valid_types]
  298. if len(valid_types) == 1:
  299. raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
  300. f' but got {get_typename(arg_type)}.')
  301. raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
  302. f' but got {get_typename(arg_type)}.')
  303. @staticmethod
  304. def check_string(arg_name, arg_value, valid_values):
  305. """String type judgment."""
  306. if isinstance(arg_value, str) and arg_value in valid_values:
  307. return arg_value
  308. if len(valid_values) == 1:
  309. raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
  310. f' but got {arg_value}.')
  311. raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
  312. f' but got {arg_value}.')
  313. @staticmethod
  314. def check_type_same(args, valid_values):
  315. """Determine whether the types are the same."""
  316. name = list(args.keys())[0]
  317. value = list(args.values())[0]
  318. if isinstance(value, type(mstype.tensor)):
  319. value = value.element_type()
  320. for arg_name, arg_value in args.items():
  321. if isinstance(arg_value, type(mstype.tensor)):
  322. arg_value = arg_value.element_type()
  323. if arg_value not in valid_values:
  324. raise TypeError(f'The `{arg_name}` should be in {valid_values},'
  325. f' but `{arg_name}` is {arg_value}.')
  326. if arg_value != value:
  327. raise TypeError(f'`{arg_name}` should be same as `{name}`,'
  328. f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
  329. @staticmethod
  330. def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
  331. """Determine whether the types of two variables are the same."""
  332. if arg1_type != arg2_type:
  333. raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
  334. @staticmethod
  335. def check_value_on_integer(arg_name, arg_value, value, rel):
  336. """Judging integer type."""
  337. rel_fn = Rel.get_fns(rel)
  338. type_match = isinstance(arg_value, int)
  339. if type_match and (not rel_fn(arg_value, value)):
  340. rel_str = Rel.get_strs(rel).format(value)
  341. raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
  342. return arg_value
  343. @staticmethod
  344. def check_param_equal(param1_name, param1_value, param2_name, param2_value):
  345. """Judging the equality of parameters."""
  346. if param1_value != param2_value:
  347. raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
  348. f" but got `{param1_name}` = {param1_value},"
  349. f" `{param2_name}` = {param2_value}.")
  350. @staticmethod
  351. def check_const_input(arg_name, arg_value):
  352. """Check valid value."""
  353. if arg_value is None:
  354. raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
  355. @staticmethod
  356. def check_float_positive(arg_name, arg_value):
  357. """Float type judgment."""
  358. if isinstance(arg_value, float):
  359. if arg_value > 0:
  360. return arg_value
  361. raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
  362. raise TypeError(f"`{arg_name}` must be float!")
  363. @staticmethod
  364. def check_pad_value_by_mode(op_name, pad_mode, padding):
  365. """Validate value of padding according to pad_mode"""
  366. if pad_mode != 'pad' and padding != 0:
  367. raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
  368. return padding
  369. @staticmethod
  370. def check_empty_shape_input(arg_name, arg_value):
  371. """Check zeros value."""
  372. if 0 in arg_value:
  373. raise ValueError(f"Input `{arg_name}` cannot be empty.")
  374. @staticmethod
  375. def check_scalar_shape_input(arg_name, arg_value):
  376. """Check scalar shape input."""
  377. if arg_value != []:
  378. raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
  379. def check_int(input_param):
  380. """Int type judgment."""
  381. if isinstance(input_param, int) and not isinstance(input_param, bool):
  382. return input_param
  383. raise TypeError("Input type must be int!")
  384. def check_int_positive(input_param):
  385. """Int type judgment."""
  386. if isinstance(input_param, bool):
  387. raise TypeError("Input type must be int cannot be bool!")
  388. if isinstance(input_param, int):
  389. if input_param > 0:
  390. return input_param
  391. raise ValueError("The input_param must be positive, but got input_param {}.".format(input_param))
  392. raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
  393. def check_int_non_negative(input_param):
  394. """Non_negative type judgment."""
  395. if isinstance(input_param, bool):
  396. raise TypeError("Input type must be int cannot be bool!")
  397. if isinstance(input_param, int):
  398. if input_param >= 0:
  399. return input_param
  400. raise ValueError("The input_param must be non_negative, but got input_param {}.".format(input_param))
  401. raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
  402. def check_int_zero_one(input_param):
  403. """Judge whether it is 0 or 1."""
  404. if input_param in (0, 1):
  405. return input_param
  406. raise ValueError("The data must be 0 or 1.")
  407. def check_bool(input_param):
  408. """Bool type judgment."""
  409. if isinstance(input_param, bool):
  410. return input_param
  411. raise TypeError("Input type must be bool!")
  412. def check_input_format(input_param):
  413. """Judge input format."""
  414. if input_param == "NCHW":
  415. return input_param
  416. raise ValueError("The data format must be NCHW.")
  417. def check_padding(padding):
  418. """Check padding."""
  419. if padding >= 0:
  420. return padding
  421. raise ValueError("The padding must be at least 0,"" but got padding {}.".format(padding))
  422. def check_padmode(mode):
  423. """Check padmode."""
  424. if mode in ("same", "valid", "pad"):
  425. return mode
  426. raise ValueError("The pad mode must be same or valid or pad,"" but got mode {}.".format(mode))
  427. def check_tensor_supported_type(dtype):
  428. """Check tensor dtype."""
  429. if dtype in (mstype.int32, mstype.float32):
  430. return dtype
  431. raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype))
  432. def _expand_tuple(n_dimensions):
  433. """To expand a number to tuple."""
  434. def convert(m):
  435. if not isinstance(m, tuple):
  436. if isinstance(m, int):
  437. return tuple(repeat(m, n_dimensions))
  438. raise TypeError("Input type must be int or tuple.")
  439. if not len(m) is n_dimensions:
  440. raise TypeError("Input dimension is incorrect.")
  441. for i in m:
  442. if not isinstance(i, int):
  443. raise TypeError("Incorrect type inside of a tuple!")
  444. return m
  445. return convert
  446. def check_input_data(*data, data_class):
  447. """Input data check."""
  448. for item in data:
  449. if isinstance(item, (list, tuple)):
  450. for v in item:
  451. check_input_data(v, data_class=data_class)
  452. else:
  453. if not isinstance(item, data_class):
  454. raise ValueError(f'Please provide as model inputs'
  455. f' either a single'
  456. f' or a list of {data_class.__name__},'
  457. f' but got part data type is {str(type(item))}.')
  458. if item.size() == 0:
  459. msg = "Please provide non-empty data."
  460. logger.error(msg)
  461. raise ValueError(msg)
  462. def check_output_data(data):
  463. """Output data check."""
  464. if not data:
  465. raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
  466. def check_axis_type_int(axis):
  467. """Check axis type."""
  468. if not isinstance(axis, int):
  469. raise TypeError('Wrong type for axis, should be int.')
  470. def check_axis_range(axis, rank):
  471. """Check axis range."""
  472. if not -rank <= axis < rank:
  473. raise ValueError('The axis should be in range [{}, {}),'' but got {}.'.format(-rank, rank, axis))
  474. def check_attr_int(attr_name, attr):
  475. """Check int type."""
  476. if not isinstance(attr, int):
  477. raise TypeError("The attr {} should be int, but got {}.".format(attr_name, type(attr)))
  478. def check_t_in_range(t):
  479. """Check input range."""
  480. if t not in (mstype.float16, mstype.float32, mstype.float64, mstype.int32, mstype.int64):
  481. raise ValueError("The param T should be (float16, float32, float64, int32, int64).")
  482. once = _expand_tuple(1)
  483. twice = _expand_tuple(2)
  484. triple = _expand_tuple(3)
  485. valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64,
  486. np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
  487. np.float32, np.float64, bool, np.bool_)
  488. def check_type(arg_name, arg_value, valid_types):
  489. """Check value type."""
  490. # if input type is Tensor ,get element type
  491. if isinstance(arg_value, type(mstype.tensor)):
  492. arg_value = arg_value.element_type()
  493. # First, check if arg_value has argvalid_types
  494. if isinstance(arg_value, tuple(valid_types)):
  495. return type(arg_value).__name__
  496. # Second, wrap arg_value with numpy array so that it can be checked through numpy api
  497. if isinstance(arg_value, (list, tuple)):
  498. arg_value = np.array(arg_value)
  499. # Thirdly, check the data type by numpy's dtype api
  500. valid = False
  501. if isinstance(arg_value, np.ndarray):
  502. valid = arg_value.dtype in valid_data_types
  503. # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
  504. # `check_type('x', True, [bool, int])` will check pass
  505. if isinstance(arg_value, bool) and bool not in tuple(valid_types):
  506. valid = False
  507. if not valid:
  508. type_names = [t.__name__ for t in valid_types]
  509. if len(valid_types) == 1:
  510. raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
  511. f' but got {type(arg_value).__name__}.')
  512. raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
  513. f' but got {type(arg_value).__name__}.')
  514. return type(arg_value).__name__
  515. def check_typename(arg_name, arg_type, valid_types):
  516. """Check type name."""
  517. def get_typename(t):
  518. return t.__name__ if hasattr(t, '__name__') else str(t)
  519. if isinstance(arg_type, type(mstype.tensor)):
  520. arg_type = arg_type.element_type()
  521. if arg_type in valid_types:
  522. return arg_type
  523. if isinstance(arg_type, tuple(valid_types)):
  524. return arg_type
  525. type_names = [get_typename(t) for t in valid_types]
  526. if len(valid_types) == 1:
  527. raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
  528. f' but got {get_typename(arg_type)}.')
  529. raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
  530. f' but got {get_typename(arg_type)}.')
  531. def check_shape(arg_name, arg_value):
  532. """Check shape."""
  533. # First, check if shape is a tuple
  534. if not isinstance(arg_value, tuple):
  535. raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},'
  536. f' but got {type(arg_value).__name__}.')
  537. # Second, wrap arg_value with numpy array so that it can be checked through numpy api
  538. arg_value = np.array(arg_value)
  539. # shape can not be ()
  540. if arg_value.size == 0:
  541. raise ValueError('Shape can not be empty.')
  542. # shape's dimension should be 1
  543. if arg_value.ndim != 1:
  544. raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(arg_value.ndim))
  545. # Thirdly, check each element's type of the shape
  546. valid_types = (int, np.int8, np.int16, np.int32, np.int64,
  547. np.uint8, np.uint16, np.uint32, np.uint64)
  548. for dim_size in arg_value:
  549. if not isinstance(dim_size, valid_types) or dim_size <= 0:
  550. raise ValueError('Every dimension size of the tensor shape should be a positive integer,'
  551. ' but got {}.'.format(dim_size))
  552. def _check_str_by_regular(target, reg=None, flag=re.ASCII):
  553. if reg is None:
  554. reg = _name_re
  555. if re.match(reg, target, flag) is None:
  556. raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
  557. return True