You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 9.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Validators for TensorOps.
  16. """
  17. from functools import wraps
  18. import inspect
  19. import numpy as np
  20. from mindspore._c_expression import typing
  21. from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive, \
  22. check_tensor_op, type_check_list
  23. # POS_INT_MIN is used to limit values from starting from 0
  24. POS_INT_MIN = 1
  25. UINT8_MAX = 255
  26. UINT8_MIN = 0
  27. UINT32_MAX = 4294967295
  28. UINT32_MIN = 0
  29. UINT64_MAX = 18446744073709551615
  30. UINT64_MIN = 0
  31. INT32_MAX = 2147483647
  32. INT32_MIN = -2147483648
  33. INT64_MAX = 9223372036854775807
  34. INT64_MIN = -9223372036854775808
  35. FLOAT_MAX_INTEGER = 16777216
  36. FLOAT_MIN_INTEGER = -16777216
  37. DOUBLE_MAX_INTEGER = 9007199254740992
  38. DOUBLE_MIN_INTEGER = -9007199254740992
  39. def check_fill_value(method):
  40. """Wrapper method to check the parameters of fill_value."""
  41. @wraps(method)
  42. def new_method(self, *args, **kwargs):
  43. [fill_value], _ = parse_user_args(method, *args, **kwargs)
  44. type_check(fill_value, (str, float, bool, int, bytes), "fill_value")
  45. return method(self, *args, **kwargs)
  46. return new_method
  47. def check_one_hot_op(method):
  48. """Wrapper method to check the parameters of one_hot_op."""
  49. @wraps(method)
  50. def new_method(self, *args, **kwargs):
  51. [num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs)
  52. type_check(num_classes, (int,), "num_classes")
  53. check_positive(num_classes)
  54. if smoothing_rate is not None:
  55. check_value(smoothing_rate, [0., 1.], "smoothing_rate")
  56. return method(self, *args, **kwargs)
  57. return new_method
  58. def check_num_classes(method):
  59. """Wrapper method to check the parameters of number of classes."""
  60. @wraps(method)
  61. def new_method(self, *args, **kwargs):
  62. [num_classes], _ = parse_user_args(method, *args, **kwargs)
  63. type_check(num_classes, (int,), "num_classes")
  64. check_positive(num_classes)
  65. return method(self, *args, **kwargs)
  66. return new_method
  67. def check_de_type(method):
  68. """Wrapper method to check the parameters of data type."""
  69. @wraps(method)
  70. def new_method(self, *args, **kwargs):
  71. [data_type], _ = parse_user_args(method, *args, **kwargs)
  72. type_check(data_type, (typing.Type,), "data_type")
  73. return method(self, *args, **kwargs)
  74. return new_method
  75. def check_slice_option(method):
  76. """Wrapper method to check the parameters of SliceOption."""
  77. @wraps(method)
  78. def new_method(self, *args, **kwargs):
  79. [slice_option], _ = parse_user_args(method, *args, **kwargs)
  80. from .c_transforms import _SliceOption
  81. if slice_option is not None:
  82. type_check(slice_option, (int, list, slice, bool, type(Ellipsis), _SliceOption), "slice_option")
  83. if isinstance(slice_option, list):
  84. type_check_list(slice_option, (int,), "slice_option")
  85. return method(self, *args, **kwargs)
  86. return new_method
  87. def check_slice_op(method):
  88. """Wrapper method to check the parameters of slice."""
  89. @wraps(method)
  90. def new_method(self, *args, **kwargs):
  91. [slice_op], _ = parse_user_args(method, *args, **kwargs)
  92. for s in slice_op:
  93. from .c_transforms import _SliceOption
  94. if s is not None:
  95. type_check(s, (int, list, slice, bool, type(Ellipsis), _SliceOption), "slice")
  96. if isinstance(s, list) and s:
  97. if isinstance(s[0], int):
  98. type_check_list(s, (int,), "slice")
  99. return method(self, *args, **kwargs)
  100. return new_method
  101. def check_mask_op(method):
  102. """Wrapper method to check the parameters of mask."""
  103. @wraps(method)
  104. def new_method(self, *args, **kwargs):
  105. [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs)
  106. from .c_transforms import Relational
  107. type_check(operator, (Relational,), "operator")
  108. type_check(constant, (str, float, bool, int, bytes), "constant")
  109. type_check(dtype, (typing.Type,), "dtype")
  110. return method(self, *args, **kwargs)
  111. return new_method
  112. def check_pad_end(method):
  113. """Wrapper method to check the parameters of PadEnd."""
  114. @wraps(method)
  115. def new_method(self, *args, **kwargs):
  116. [pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs)
  117. if pad_value is not None:
  118. type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
  119. type_check(pad_shape, (list,), "pad_end")
  120. for dim in pad_shape:
  121. if dim is not None:
  122. if isinstance(dim, int):
  123. check_pos_int64(dim)
  124. else:
  125. raise TypeError("a value in the list is not an integer.")
  126. return method(self, *args, **kwargs)
  127. return new_method
  128. def check_concat_type(method):
  129. """Wrapper method to check the parameters of concatenation op."""
  130. @wraps(method)
  131. def new_method(self, *args, **kwargs):
  132. [axis, prepend, append], _ = parse_user_args(method, *args, **kwargs)
  133. if axis is not None:
  134. type_check(axis, (int,), "axis")
  135. if axis not in (0, -1):
  136. raise ValueError("only 1D concatenation supported.")
  137. if prepend is not None:
  138. type_check(prepend, (np.ndarray,), "prepend")
  139. if len(prepend.shape) != 1:
  140. raise ValueError("can only prepend 1D arrays.")
  141. if append is not None:
  142. type_check(append, (np.ndarray,), "append")
  143. if len(append.shape) != 1:
  144. raise ValueError("can only append 1D arrays.")
  145. return method(self, *args, **kwargs)
  146. return new_method
  147. def check_random_transform_ops(method):
  148. """Wrapper method to check the parameters of RandomChoice, RandomApply and Compose."""
  149. @wraps(method)
  150. def new_method(self, *args, **kwargs):
  151. arg_list, _ = parse_user_args(method, *args, **kwargs)
  152. type_check(arg_list[0], (list,), "op_list")
  153. if not arg_list[0]:
  154. raise ValueError("op_list can not be empty.")
  155. for ind, op in enumerate(arg_list[0]):
  156. check_tensor_op(op, "op_list[{0}]".format(ind))
  157. if len(arg_list) == 2: # random apply takes an additional arg
  158. type_check(arg_list[1], (float, int), "prob")
  159. check_value(arg_list[1], (0, 1), "prob")
  160. return method(self, *args, **kwargs)
  161. return new_method
  162. def check_compose_list(method):
  163. """Wrapper method to check the transform list of Python Compose."""
  164. @wraps(method)
  165. def new_method(self, *args, **kwargs):
  166. [transforms], _ = parse_user_args(method, *args, **kwargs)
  167. type_check(transforms, (list,), transforms)
  168. if not transforms:
  169. raise ValueError("transforms list is empty.")
  170. for i, transfrom in enumerate(transforms):
  171. if not callable(transfrom):
  172. raise ValueError("transforms[{}] is not callable.".format(i))
  173. return method(self, *args, **kwargs)
  174. return new_method
  175. def check_compose_call(method):
  176. """Wrapper method to check the transform list of Compose."""
  177. @wraps(method)
  178. def new_method(self, *args, **kwargs):
  179. sig = inspect.signature(method)
  180. ba = sig.bind_partial(method, *args, **kwargs)
  181. img = ba.arguments.get("args")
  182. if img is None:
  183. raise TypeError(
  184. "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])()).")
  185. return method(self, *args, **kwargs)
  186. return new_method
  187. def check_random_apply(method):
  188. """Wrapper method to check the parameters of random apply."""
  189. @wraps(method)
  190. def new_method(self, *args, **kwargs):
  191. [transforms, prob], _ = parse_user_args(method, *args, **kwargs)
  192. type_check(transforms, (list,), "transforms")
  193. for i, transfrom in enumerate(transforms):
  194. if not callable(transfrom):
  195. raise ValueError("transforms[{}] is not callable.".format(i))
  196. if prob is not None:
  197. type_check(prob, (float, int,), "prob")
  198. check_value(prob, [0., 1.], "prob")
  199. return method(self, *args, **kwargs)
  200. return new_method
  201. def check_transforms_list(method):
  202. """Wrapper method to check the parameters of transform list."""
  203. @wraps(method)
  204. def new_method(self, *args, **kwargs):
  205. [transforms], _ = parse_user_args(method, *args, **kwargs)
  206. type_check(transforms, (list,), "transforms")
  207. for i, transfrom in enumerate(transforms):
  208. if not callable(transfrom):
  209. raise ValueError("transforms[{}] is not callable.".format(i))
  210. return method(self, *args, **kwargs)
  211. return new_method