You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Validators for TensorOps.
  16. """
  17. from functools import wraps
  18. import numpy as np
  19. from mindspore._c_expression import typing
  20. # POS_INT_MIN is used to limit values from starting from 0
  21. POS_INT_MIN = 1
  22. UINT8_MAX = 255
  23. UINT8_MIN = 0
  24. UINT32_MAX = 4294967295
  25. UINT32_MIN = 0
  26. UINT64_MAX = 18446744073709551615
  27. UINT64_MIN = 0
  28. INT32_MAX = 2147483647
  29. INT32_MIN = -2147483648
  30. INT64_MAX = 9223372036854775807
  31. INT64_MIN = -9223372036854775808
  32. FLOAT_MAX_INTEGER = 16777216
  33. FLOAT_MIN_INTEGER = -16777216
  34. DOUBLE_MAX_INTEGER = 9007199254740992
  35. DOUBLE_MIN_INTEGER = -9007199254740992
  36. def check_type(value, valid_type):
  37. if not isinstance(value, valid_type):
  38. raise ValueError("Wrong input type")
  39. def check_value(value, valid_range):
  40. if value < valid_range[0] or value > valid_range[1]:
  41. raise ValueError("Input is not within the required range")
  42. def check_range(values, valid_range):
  43. if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
  44. raise ValueError("Input range is not valid")
  45. def check_positive(value):
  46. if value <= 0:
  47. raise ValueError("Input must greater than 0")
  48. def check_positive_float(value, valid_max=None):
  49. if value <= 0 or not isinstance(value, float) or (valid_max is not None and value > valid_max):
  50. raise ValueError("Input need to be a valid positive float.")
  51. def check_bool(value):
  52. if not isinstance(value, bool):
  53. raise ValueError("Value needs to be a boolean.")
  54. def check_2tuple(value):
  55. if not (isinstance(value, tuple) and len(value) == 2):
  56. raise ValueError("Value needs to be a 2-tuple.")
  57. def check_list(value):
  58. if not isinstance(value, list):
  59. raise ValueError("The input needs to be a list.")
  60. def check_uint8(value):
  61. if not isinstance(value, int):
  62. raise ValueError("The input needs to be a integer")
  63. check_value(value, [UINT8_MIN, UINT8_MAX])
  64. def check_uint32(value):
  65. if not isinstance(value, int):
  66. raise ValueError("The input needs to be a integer")
  67. check_value(value, [UINT32_MIN, UINT32_MAX])
  68. def check_pos_int32(value):
  69. """Checks for int values starting from 1"""
  70. if not isinstance(value, int):
  71. raise ValueError("The input needs to be a integer")
  72. check_value(value, [POS_INT_MIN, INT32_MAX])
  73. def check_uint64(value):
  74. if not isinstance(value, int):
  75. raise ValueError("The input needs to be a integer")
  76. check_value(value, [UINT64_MIN, UINT64_MAX])
  77. def check_pos_int64(value):
  78. if not isinstance(value, int):
  79. raise ValueError("The input needs to be a integer")
  80. check_value(value, [UINT64_MIN, INT64_MAX])
  81. def check_pos_float32(value):
  82. check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER])
  83. def check_pos_float64(value):
  84. check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER])
  85. def check_one_hot_op(method):
  86. """Wrapper method to check the parameters of one hot op."""
  87. @wraps(method)
  88. def new_method(self, *args, **kwargs):
  89. args = (list(args) + 2 * [None])[:2]
  90. num_classes, smoothing_rate = args
  91. if "num_classes" in kwargs:
  92. num_classes = kwargs.get("num_classes")
  93. if "smoothing_rate" in kwargs:
  94. smoothing_rate = kwargs.get("smoothing_rate")
  95. if num_classes is None:
  96. raise ValueError("num_classes")
  97. check_pos_int32(num_classes)
  98. kwargs["num_classes"] = num_classes
  99. if smoothing_rate is not None:
  100. check_value(smoothing_rate, [0., 1.])
  101. kwargs["smoothing_rate"] = smoothing_rate
  102. return method(self, **kwargs)
  103. return new_method
  104. def check_num_classes(method):
  105. """Wrapper method to check the parameters of number of classes."""
  106. @wraps(method)
  107. def new_method(self, *args, **kwargs):
  108. num_classes = (list(args) + [None])[0]
  109. if "num_classes" in kwargs:
  110. num_classes = kwargs.get("num_classes")
  111. if num_classes is None:
  112. raise ValueError("num_classes is not provided.")
  113. check_pos_int32(num_classes)
  114. kwargs["num_classes"] = num_classes
  115. return method(self, **kwargs)
  116. return new_method
  117. def check_fill_value(method):
  118. """Wrapper method to check the parameters of fill value."""
  119. @wraps(method)
  120. def new_method(self, *args, **kwargs):
  121. fill_value = (list(args) + [None])[0]
  122. if "fill_value" in kwargs:
  123. fill_value = kwargs.get("fill_value")
  124. if fill_value is None:
  125. raise ValueError("fill_value is not provided.")
  126. if not isinstance(fill_value, (str, float, bool, int, bytes)):
  127. raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int")
  128. kwargs["fill_value"] = fill_value
  129. return method(self, **kwargs)
  130. return new_method
  131. def check_de_type(method):
  132. """Wrapper method to check the parameters of data type."""
  133. @wraps(method)
  134. def new_method(self, *args, **kwargs):
  135. data_type = (list(args) + [None])[0]
  136. if "data_type" in kwargs:
  137. data_type = kwargs.get("data_type")
  138. if data_type is None:
  139. raise ValueError("data_type is not provided.")
  140. if not isinstance(data_type, typing.Type):
  141. raise TypeError("data_type is not a MindSpore data type.")
  142. kwargs["data_type"] = data_type
  143. return method(self, **kwargs)
  144. return new_method
  145. def check_slice_op(method):
  146. """Wrapper method to check the parameters of slice."""
  147. @wraps(method)
  148. def new_method(self, *args):
  149. for i, arg in enumerate(args):
  150. if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)):
  151. raise TypeError("Indexing of dim " + str(i) + "is not of valid type")
  152. if isinstance(arg, list):
  153. for a in arg:
  154. if not isinstance(a, int):
  155. raise TypeError("Index " + a + " is not an int")
  156. return method(self, *args)
  157. return new_method
  158. def check_mask_op(method):
  159. """Wrapper method to check the parameters of mask."""
  160. @wraps(method)
  161. def new_method(self, *args, **kwargs):
  162. operator, constant, dtype = (list(args) + 3 * [None])[:3]
  163. if "operator" in kwargs:
  164. operator = kwargs.get("operator")
  165. if "constant" in kwargs:
  166. constant = kwargs.get("constant")
  167. if "dtype" in kwargs:
  168. dtype = kwargs.get("dtype")
  169. if operator is None:
  170. raise ValueError("operator is not provided.")
  171. if constant is None:
  172. raise ValueError("constant is not provided.")
  173. from .c_transforms import Relational
  174. if not isinstance(operator, Relational):
  175. raise TypeError("operator is not a Relational operator enum.")
  176. if not isinstance(constant, (str, float, bool, int, bytes)):
  177. raise TypeError("constant must be either a primitive python str, float, bool, bytes or int")
  178. if dtype is not None:
  179. if not isinstance(dtype, typing.Type):
  180. raise TypeError("dtype is not a MindSpore data type.")
  181. kwargs["dtype"] = dtype
  182. kwargs["operator"] = operator
  183. kwargs["constant"] = constant
  184. return method(self, **kwargs)
  185. return new_method
  186. def check_pad_end(method):
  187. """Wrapper method to check the parameters of PadEnd."""
  188. @wraps(method)
  189. def new_method(self, *args, **kwargs):
  190. pad_shape, pad_value = (list(args) + 2 * [None])[:2]
  191. if "pad_shape" in kwargs:
  192. pad_shape = kwargs.get("pad_shape")
  193. if "pad_value" in kwargs:
  194. pad_value = kwargs.get("pad_value")
  195. if pad_shape is None:
  196. raise ValueError("pad_shape is not provided.")
  197. if pad_value is not None:
  198. if not isinstance(pad_value, (str, float, bool, int, bytes)):
  199. raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes")
  200. kwargs["pad_value"] = pad_value
  201. if not isinstance(pad_shape, list):
  202. raise TypeError("pad_shape must be a list")
  203. for dim in pad_shape:
  204. if dim is not None:
  205. if isinstance(dim, int):
  206. check_pos_int64(dim)
  207. else:
  208. raise TypeError("a value in the list is not an integer.")
  209. kwargs["pad_shape"] = pad_shape
  210. return method(self, **kwargs)
  211. return new_method
  212. def check_concat_type(method):
  213. """Wrapper method to check the parameters of concatenation op."""
  214. @wraps(method)
  215. def new_method(self, *args, **kwargs):
  216. axis, prepend, append = (list(args) + 3 * [None])[:3]
  217. if "prepend" in kwargs:
  218. prepend = kwargs.get("prepend")
  219. if "append" in kwargs:
  220. append = kwargs.get("append")
  221. if "axis" in kwargs:
  222. axis = kwargs.get("axis")
  223. if axis is not None:
  224. if not isinstance(axis, int):
  225. raise TypeError("axis type is not valid, must be an integer.")
  226. if axis not in (0, -1):
  227. raise ValueError("only 1D concatenation supported.")
  228. kwargs["axis"] = axis
  229. if prepend is not None:
  230. if not isinstance(prepend, (type(None), np.ndarray)):
  231. raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.")
  232. kwargs["prepend"] = prepend
  233. if append is not None:
  234. if not isinstance(append, (type(None), np.ndarray)):
  235. raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.")
  236. kwargs["append"] = append
  237. return method(self, **kwargs)
  238. return new_method