You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. # Copyright 2019-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Validators for TensorOps.
  16. """
  17. from functools import wraps
  18. import inspect
  19. import numpy as np
  20. from mindspore._c_expression import typing
  21. from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive, \
  22. check_tensor_op, type_check_list
  23. # POS_INT_MIN is used to limit values from starting from 0
  24. POS_INT_MIN = 1
  25. UINT8_MAX = 255
  26. UINT8_MIN = 0
  27. UINT32_MAX = 4294967295
  28. UINT32_MIN = 0
  29. UINT64_MAX = 18446744073709551615
  30. UINT64_MIN = 0
  31. INT32_MAX = 2147483647
  32. INT32_MIN = -2147483648
  33. INT64_MAX = 9223372036854775807
  34. INT64_MIN = -9223372036854775808
  35. FLOAT_MAX_INTEGER = 16777216
  36. FLOAT_MIN_INTEGER = -16777216
  37. DOUBLE_MAX_INTEGER = 9007199254740992
  38. DOUBLE_MIN_INTEGER = -9007199254740992
  39. def check_fill_value(method):
  40. """Wrapper method to check the parameters of fill_value."""
  41. @wraps(method)
  42. def new_method(self, *args, **kwargs):
  43. [fill_value], _ = parse_user_args(method, *args, **kwargs)
  44. type_check(fill_value, (str, float, bool, int, bytes), "fill_value")
  45. return method(self, *args, **kwargs)
  46. return new_method
  47. def check_one_hot_op(method):
  48. """Wrapper method to check the parameters of one_hot_op."""
  49. @wraps(method)
  50. def new_method(self, *args, **kwargs):
  51. [num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs)
  52. type_check(smoothing_rate, (int, float), "smoothing_rate")
  53. type_check(num_classes, (int,), "num_classes")
  54. check_positive(num_classes)
  55. if smoothing_rate is not None:
  56. check_value(smoothing_rate, [0., 1.], "smoothing_rate")
  57. return method(self, *args, **kwargs)
  58. return new_method
  59. def check_num_classes(method):
  60. """Wrapper method to check the parameters of number of classes."""
  61. @wraps(method)
  62. def new_method(self, *args, **kwargs):
  63. [num_classes], _ = parse_user_args(method, *args, **kwargs)
  64. type_check(num_classes, (int,), "num_classes")
  65. check_positive(num_classes)
  66. return method(self, *args, **kwargs)
  67. return new_method
  68. def check_ms_type(method):
  69. """Wrapper method to check the parameters of data type."""
  70. @wraps(method)
  71. def new_method(self, *args, **kwargs):
  72. [data_type], _ = parse_user_args(method, *args, **kwargs)
  73. type_check(data_type, (typing.Type,), "data_type")
  74. return method(self, *args, **kwargs)
  75. return new_method
  76. def check_slice_option(method):
  77. """Wrapper method to check the parameters of SliceOption."""
  78. @wraps(method)
  79. def new_method(self, *args, **kwargs):
  80. [slice_option], _ = parse_user_args(method, *args, **kwargs)
  81. from .c_transforms import _SliceOption
  82. if slice_option is not None:
  83. type_check(slice_option, (int, list, slice, bool, type(Ellipsis), _SliceOption), "slice_option")
  84. if isinstance(slice_option, list):
  85. type_check_list(slice_option, (int,), "slice_option")
  86. return method(self, *args, **kwargs)
  87. return new_method
  88. def check_slice_op(method):
  89. """Wrapper method to check the parameters of slice."""
  90. @wraps(method)
  91. def new_method(self, *args, **kwargs):
  92. [slice_op], _ = parse_user_args(method, *args, **kwargs)
  93. for s in slice_op:
  94. from .c_transforms import _SliceOption
  95. if s is not None:
  96. type_check(s, (int, list, slice, bool, type(Ellipsis), _SliceOption), "slice")
  97. if isinstance(s, list) and s:
  98. if isinstance(s[0], int):
  99. type_check_list(s, (int,), "slice")
  100. return method(self, *args, **kwargs)
  101. return new_method
  102. def check_mask_op(method):
  103. """Wrapper method to check the parameters of mask."""
  104. @wraps(method)
  105. def new_method(self, *args, **kwargs):
  106. [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs)
  107. from .c_transforms import Relational
  108. type_check(operator, (Relational,), "operator")
  109. type_check(constant, (str, float, bool, int, bytes), "constant")
  110. type_check(dtype, (typing.Type,), "dtype")
  111. return method(self, *args, **kwargs)
  112. return new_method
  113. def check_pad_end(method):
  114. """Wrapper method to check the parameters of PadEnd."""
  115. @wraps(method)
  116. def new_method(self, *args, **kwargs):
  117. [pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs)
  118. if pad_value is not None:
  119. type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
  120. type_check(pad_shape, (list,), "pad_shape")
  121. for dim in pad_shape:
  122. if dim is not None:
  123. if isinstance(dim, int):
  124. check_pos_int64(dim)
  125. else:
  126. raise TypeError("a value in the list is not an integer.")
  127. return method(self, *args, **kwargs)
  128. return new_method
  129. def check_concat_type(method):
  130. """Wrapper method to check the parameters of concatenation op."""
  131. @wraps(method)
  132. def new_method(self, *args, **kwargs):
  133. [axis, prepend, append], _ = parse_user_args(method, *args, **kwargs)
  134. if axis is not None:
  135. type_check(axis, (int,), "axis")
  136. if axis not in (0, -1):
  137. raise ValueError("only 1D concatenation supported.")
  138. if prepend is not None:
  139. type_check(prepend, (np.ndarray,), "prepend")
  140. if len(prepend.shape) != 1:
  141. raise ValueError("can only prepend 1D arrays.")
  142. if append is not None:
  143. type_check(append, (np.ndarray,), "append")
  144. if len(append.shape) != 1:
  145. raise ValueError("can only append 1D arrays.")
  146. return method(self, *args, **kwargs)
  147. return new_method
  148. def check_random_transform_ops(method):
  149. """Wrapper method to check the parameters of RandomChoice, RandomApply and Compose."""
  150. @wraps(method)
  151. def new_method(self, *args, **kwargs):
  152. arg_list, _ = parse_user_args(method, *args, **kwargs)
  153. type_check(arg_list[0], (list,), "op_list")
  154. if not arg_list[0]:
  155. raise ValueError("op_list can not be empty.")
  156. for ind, op in enumerate(arg_list[0]):
  157. check_tensor_op(op, "op_list[{0}]".format(ind))
  158. if len(arg_list) == 2: # random apply takes an additional arg
  159. type_check(arg_list[1], (float, int), "prob")
  160. check_value(arg_list[1], (0, 1), "prob")
  161. return method(self, *args, **kwargs)
  162. return new_method
  163. def check_compose_list(method):
  164. """Wrapper method to check the transform list of Python Compose."""
  165. @wraps(method)
  166. def new_method(self, *args, **kwargs):
  167. [transforms], _ = parse_user_args(method, *args, **kwargs)
  168. type_check(transforms, (list,), transforms)
  169. if not transforms:
  170. raise ValueError("transforms list is empty.")
  171. for i, transform in enumerate(transforms):
  172. if not callable(transform):
  173. raise ValueError("transforms[{}] is not callable.".format(i))
  174. return method(self, *args, **kwargs)
  175. return new_method
  176. def check_compose_call(method):
  177. """Wrapper method to check the transform list of Compose."""
  178. @wraps(method)
  179. def new_method(self, *args, **kwargs):
  180. sig = inspect.signature(method)
  181. ba = sig.bind_partial(method, *args, **kwargs)
  182. img = ba.arguments.get("args")
  183. if img is None:
  184. raise TypeError(
  185. "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])()).")
  186. return method(self, *args, **kwargs)
  187. return new_method
  188. def check_random_apply(method):
  189. """Wrapper method to check the parameters of random apply."""
  190. @wraps(method)
  191. def new_method(self, *args, **kwargs):
  192. [transforms, prob], _ = parse_user_args(method, *args, **kwargs)
  193. type_check(transforms, (list,), "transforms")
  194. for i, transform in enumerate(transforms):
  195. if str(transform).find("c_transform") >= 0:
  196. raise ValueError(
  197. "transforms[{}] is not a py transforms. Should not use a c transform in py transform" \
  198. .format(i))
  199. if prob is not None:
  200. type_check(prob, (float, int,), "prob")
  201. check_value(prob, [0., 1.], "prob")
  202. return method(self, *args, **kwargs)
  203. return new_method
  204. def check_transforms_list(method):
  205. """Wrapper method to check the parameters of transform list."""
  206. @wraps(method)
  207. def new_method(self, *args, **kwargs):
  208. [transforms], _ = parse_user_args(method, *args, **kwargs)
  209. type_check(transforms, (list,), "transforms")
  210. for i, transform in enumerate(transforms):
  211. if str(transform).find("c_transform") >= 0:
  212. raise ValueError(
  213. "transforms[{}] is not a py transforms. Should not use a c transform in py transform" \
  214. .format(i))
  215. return method(self, *args, **kwargs)
  216. return new_method
  217. def check_plugin(method):
  218. """Wrapper method to check the parameters of plugin."""
  219. @wraps(method)
  220. def new_method(self, *args, **kwargs):
  221. [lib_path, func_name, user_args], _ = parse_user_args(method, *args, **kwargs)
  222. type_check(lib_path, (str,), "lib_path")
  223. type_check(func_name, (str,), "func_name")
  224. if user_args is not None:
  225. type_check(user_args, (str,), "user_args")
  226. return method(self, *args, **kwargs)
  227. return new_method