You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 6.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Validators for TensorOps.
  16. """
  17. from functools import wraps
  18. import numpy as np
  19. from mindspore._c_expression import typing
  20. from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive, \
  21. check_tensor_op
  22. # POS_INT_MIN is used to limit values from starting from 0
  23. POS_INT_MIN = 1
  24. UINT8_MAX = 255
  25. UINT8_MIN = 0
  26. UINT32_MAX = 4294967295
  27. UINT32_MIN = 0
  28. UINT64_MAX = 18446744073709551615
  29. UINT64_MIN = 0
  30. INT32_MAX = 2147483647
  31. INT32_MIN = -2147483648
  32. INT64_MAX = 9223372036854775807
  33. INT64_MIN = -9223372036854775808
  34. FLOAT_MAX_INTEGER = 16777216
  35. FLOAT_MIN_INTEGER = -16777216
  36. DOUBLE_MAX_INTEGER = 9007199254740992
  37. DOUBLE_MIN_INTEGER = -9007199254740992
  38. def check_fill_value(method):
  39. """Wrapper method to check the parameters of fill_value."""
  40. @wraps(method)
  41. def new_method(self, *args, **kwargs):
  42. [fill_value], _ = parse_user_args(method, *args, **kwargs)
  43. type_check(fill_value, (str, float, bool, int, bytes), "fill_value")
  44. return method(self, *args, **kwargs)
  45. return new_method
  46. def check_one_hot_op(method):
  47. """Wrapper method to check the parameters of one_hot_op."""
  48. @wraps(method)
  49. def new_method(self, *args, **kwargs):
  50. [num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs)
  51. type_check(num_classes, (int,), "num_classes")
  52. check_positive(num_classes)
  53. if smoothing_rate is not None:
  54. check_value(smoothing_rate, [0., 1.], "smoothing_rate")
  55. return method(self, *args, **kwargs)
  56. return new_method
  57. def check_num_classes(method):
  58. """Wrapper method to check the parameters of number of classes."""
  59. @wraps(method)
  60. def new_method(self, *args, **kwargs):
  61. [num_classes], _ = parse_user_args(method, *args, **kwargs)
  62. type_check(num_classes, (int,), "num_classes")
  63. check_positive(num_classes)
  64. return method(self, *args, **kwargs)
  65. return new_method
  66. def check_de_type(method):
  67. """Wrapper method to check the parameters of data type."""
  68. @wraps(method)
  69. def new_method(self, *args, **kwargs):
  70. [data_type], _ = parse_user_args(method, *args, **kwargs)
  71. type_check(data_type, (typing.Type,), "data_type")
  72. return method(self, *args, **kwargs)
  73. return new_method
  74. def check_slice_op(method):
  75. """Wrapper method to check the parameters of slice."""
  76. @wraps(method)
  77. def new_method(self, *args):
  78. for _, arg in enumerate(args):
  79. type_check(arg, (int, slice, list, type(None), type(Ellipsis)), "arg")
  80. if isinstance(arg, list):
  81. for a in arg:
  82. type_check(a, (int,), "a")
  83. return method(self, *args)
  84. return new_method
  85. def check_mask_op(method):
  86. """Wrapper method to check the parameters of mask."""
  87. @wraps(method)
  88. def new_method(self, *args, **kwargs):
  89. [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs)
  90. from .c_transforms import Relational
  91. type_check(operator, (Relational,), "operator")
  92. type_check(constant, (str, float, bool, int, bytes), "constant")
  93. type_check(dtype, (typing.Type,), "dtype")
  94. return method(self, *args, **kwargs)
  95. return new_method
  96. def check_pad_end(method):
  97. """Wrapper method to check the parameters of PadEnd."""
  98. @wraps(method)
  99. def new_method(self, *args, **kwargs):
  100. [pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs)
  101. if pad_value is not None:
  102. type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
  103. type_check(pad_shape, (list,), "pad_end")
  104. for dim in pad_shape:
  105. if dim is not None:
  106. if isinstance(dim, int):
  107. check_pos_int64(dim)
  108. else:
  109. raise TypeError("a value in the list is not an integer.")
  110. return method(self, *args, **kwargs)
  111. return new_method
  112. def check_concat_type(method):
  113. """Wrapper method to check the parameters of concatenation op."""
  114. @wraps(method)
  115. def new_method(self, *args, **kwargs):
  116. [axis, prepend, append], _ = parse_user_args(method, *args, **kwargs)
  117. if axis is not None:
  118. type_check(axis, (int,), "axis")
  119. if axis not in (0, -1):
  120. raise ValueError("only 1D concatenation supported.")
  121. if prepend is not None:
  122. type_check(prepend, (np.ndarray,), "prepend")
  123. if len(prepend.shape) != 1:
  124. raise ValueError("can only prepend 1D arrays.")
  125. if append is not None:
  126. type_check(append, (np.ndarray,), "append")
  127. if len(append.shape) != 1:
  128. raise ValueError("can only append 1D arrays.")
  129. return method(self, *args, **kwargs)
  130. return new_method
  131. def check_random_transform_ops(method):
  132. """Wrapper method to check the parameters of RandomChoice, RandomApply and Compose."""
  133. @wraps(method)
  134. def new_method(self, *args, **kwargs):
  135. arg_list, _ = parse_user_args(method, *args, **kwargs)
  136. type_check(arg_list[0], (list,), "op_list")
  137. if not arg_list[0]:
  138. raise ValueError("op_list can not be empty.")
  139. for ind, op in enumerate(arg_list[0]):
  140. check_tensor_op(op, "op_list[{0}]".format(ind))
  141. if len(arg_list) == 2: # random apply takes an additional arg
  142. type_check(arg_list[1], (float, int), "prob")
  143. check_value(arg_list[1], (0, 1), "prob")
  144. return method(self, *args, **kwargs)
  145. return new_method