You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Validators for TensorOps.
  16. """
  17. from functools import wraps
  18. from mindspore._c_expression import typing
  # Numeric bounds used by the range validators in this module.
# Lower bound for "positive" integers: starts at 1, so 0 is excluded.
POS_INT_MIN = 1
# Unsigned integer ranges.
UINT8_MIN = 0
UINT8_MAX = 2**8 - 1
UINT32_MIN = 0
UINT32_MAX = 2**32 - 1
UINT64_MIN = 0
UINT64_MAX = 2**64 - 1
# Signed (two's complement) integer ranges.
INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1
INT64_MIN = -(2**63)
INT64_MAX = 2**63 - 1
# 2**24: the magnitude up to which every integer is exactly representable in
# IEEE-754 single precision (24-bit significand).
FLOAT_MIN_INTEGER = -(2**24)
FLOAT_MAX_INTEGER = 2**24
# 2**53: the same limit for IEEE-754 double precision (53-bit significand).
DOUBLE_MIN_INTEGER = -(2**53)
DOUBLE_MAX_INTEGER = 2**53
  # Generic isinstance-based type check.
def check_type(value, valid_type):
    """Raise ValueError unless ``value`` is an instance of ``valid_type``.

    Args:
        value: Object to inspect.
        valid_type: A type, or tuple of types, as accepted by ``isinstance``.

    Raises:
        ValueError: If ``value`` is not an instance of ``valid_type``.
            NOTE(review): TypeError would be the conventional exception here;
            ValueError is kept so existing callers are unaffected.
    """
    if isinstance(value, valid_type):
        return
    raise ValueError("Wrong input type")
  # Inclusive range check, delegated to by the integer/float validators.
def check_value(value, valid_range):
    """Raise ValueError unless ``value`` lies inside ``valid_range``.

    Args:
        value: The value to check; must be comparable with the bounds.
        valid_range: Two-item sequence ``[low, high]``; both bounds are
            inclusive.

    Raises:
        ValueError: If ``value`` is below ``low``, above ``high``, or is
            unordered with respect to the bounds (e.g. ``float('nan')``).
    """
    # One positive chained comparison on purpose: every comparison involving
    # NaN is False, so a two-sided negative test such as
    # ``value < low or value > high`` would let NaN through, while
    # ``not low <= value <= high`` rejects it.
    if not valid_range[0] <= value <= valid_range[1]:
        raise ValueError("Input is not within the required range")
  # Ordered sub-range check.
def check_range(values, valid_range):
    """Raise ValueError unless ``values`` is an ordered sub-range of ``valid_range``.

    The condition is
    ``valid_range[0] <= values[0] <= values[1] <= valid_range[1]``, with every
    comparison inclusive. Only the first two items of each sequence are read.

    Raises:
        ValueError: If the condition above does not hold.
    """
    is_nested = valid_range[0] <= values[0] <= values[1] <= valid_range[1]
    if is_nested:
        return
    raise ValueError("Input range is not valid")
  # Strict positivity check.
def check_positive(value):
    """Raise ValueError unless ``value`` is strictly greater than 0.

    Args:
        value: Number to check.

    Raises:
        ValueError: If ``value`` is 0, negative, or NaN.
    """
    # ``not value > 0`` (rather than ``value <= 0``) also rejects NaN, for
    # which both comparisons evaluate to False.
    if not value > 0:
        raise ValueError("Input must greater than 0")
  # Positive float check with an optional inclusive upper bound.
def check_positive_float(value, valid_max=None):
    """Raise ValueError unless ``value`` is a float in ``(0, valid_max]``.

    Args:
        value: Value to check; must be an instance of ``float`` (ints are
            rejected).
        valid_max: Optional inclusive upper bound; ``None`` means no upper
            bound.

    Raises:
        ValueError: If ``value`` is not a float, is not strictly positive
            (NaN included), or exceeds ``valid_max``.
    """
    message = "Input need to be a valid positive float."
    # Type check first so that e.g. a str or None raises the documented
    # ValueError instead of a TypeError from the numeric comparison.
    if not isinstance(value, float):
        raise ValueError(message)
    # Positive comparisons, so NaN (for which every comparison is False)
    # is rejected rather than accepted.
    if not value > 0:
        raise ValueError(message)
    if valid_max is not None and not value <= valid_max:
        raise ValueError(message)
  # Boolean type check.
def check_bool(value):
    """Raise ValueError unless ``value`` is a ``bool``.

    Integers such as 0 and 1 are rejected; only ``True``/``False`` pass.
    """
    if isinstance(value, bool):
        return
    raise ValueError("Value needs to be a boolean.")
  # Pair (2-tuple) check.
def check_2tuple(value):
    """Raise ValueError unless ``value`` is a tuple of exactly two items.

    Lists and other sequences are rejected even when they hold two items.
    """
    is_pair = isinstance(value, tuple) and len(value) == 2
    if not is_pair:
        raise ValueError("Value needs to be a 2-tuple.")
  # List type check.
def check_list(value):
    """Raise ValueError unless ``value`` is a ``list`` (tuples are rejected)."""
    if isinstance(value, list):
        return
    raise ValueError("The input needs to be a list.")
  # Unsigned 8-bit integer check.
def check_uint8(value):
    """Raise ValueError unless ``value`` is an int in [UINT8_MIN, UINT8_MAX].

    ``bool`` is rejected even though it is a subclass of ``int``.

    Raises:
        ValueError: If ``value`` is not a non-bool ``int``, or is out of
            range (via ``check_value``).
    """
    # isinstance(True, int) is True, so bools must be excluded explicitly.
    if not isinstance(value, int) or isinstance(value, bool):
        raise ValueError("The input needs to be a integer")
    check_value(value, [UINT8_MIN, UINT8_MAX])
  # Unsigned 32-bit integer check.
def check_uint32(value):
    """Raise ValueError unless ``value`` is an int in [UINT32_MIN, UINT32_MAX].

    ``bool`` is rejected even though it is a subclass of ``int``.

    Raises:
        ValueError: If ``value`` is not a non-bool ``int``, or is out of
            range (via ``check_value``).
    """
    # isinstance(True, int) is True, so bools must be excluded explicitly.
    if not isinstance(value, int) or isinstance(value, bool):
        raise ValueError("The input needs to be a integer")
    check_value(value, [UINT32_MIN, UINT32_MAX])
  # Positive 32-bit integer check.
def check_pos_int32(value):
    """Check for int values in [POS_INT_MIN, INT32_MAX], i.e. starting from 1.

    ``bool`` is rejected even though it is a subclass of ``int``, so e.g.
    ``True`` is not silently accepted as ``1``.

    Raises:
        ValueError: If ``value`` is not a non-bool ``int``, or is out of
            range (via ``check_value``).
    """
    # isinstance(True, int) is True, so bools must be excluded explicitly.
    if not isinstance(value, int) or isinstance(value, bool):
        raise ValueError("The input needs to be a integer")
    check_value(value, [POS_INT_MIN, INT32_MAX])
  # Unsigned 64-bit integer check.
def check_uint64(value):
    """Raise ValueError unless ``value`` is an int in [UINT64_MIN, UINT64_MAX].

    ``bool`` is rejected even though it is a subclass of ``int``.

    Raises:
        ValueError: If ``value`` is not a non-bool ``int``, or is out of
            range (via ``check_value``).
    """
    # isinstance(True, int) is True, so bools must be excluded explicitly.
    if not isinstance(value, int) or isinstance(value, bool):
        raise ValueError("The input needs to be a integer")
    check_value(value, [UINT64_MIN, UINT64_MAX])
  # Non-negative 64-bit integer check.
def check_pos_int64(value):
    """Raise ValueError unless ``value`` is an int in [UINT64_MIN, INT64_MAX].

    ``bool`` is rejected even though it is a subclass of ``int``.

    NOTE(review): despite the name, the lower bound is UINT64_MIN (0), so 0
    is accepted, unlike check_pos_int32 whose lower bound is POS_INT_MIN (1).
    Kept unchanged for compatibility; confirm whether 0 should be allowed.

    Raises:
        ValueError: If ``value`` is not a non-bool ``int``, or is out of
            range (via ``check_value``).
    """
    # isinstance(True, int) is True, so bools must be excluded explicitly.
    if not isinstance(value, int) or isinstance(value, bool):
        raise ValueError("The input needs to be a integer")
    check_value(value, [UINT64_MIN, INT64_MAX])
  # Range check against FLOAT_MAX_INTEGER.
def check_pos_float32(value):
    """Delegate to ``check_value`` over ``[UINT32_MIN, FLOAT_MAX_INTEGER]``.

    The range is inclusive, 0 to 2**24. No type check is performed.
    Despite the name, the lower bound (UINT32_MIN) admits 0. The upper bound
    is the float32 exact-integer limit, not the largest float32 value.
    """
    lower, upper = UINT32_MIN, FLOAT_MAX_INTEGER
    check_value(value, [lower, upper])
  # Range check against DOUBLE_MAX_INTEGER.
def check_pos_float64(value):
    """Delegate to ``check_value`` over ``[UINT64_MIN, DOUBLE_MAX_INTEGER]``.

    The range is inclusive, 0 to 2**53. No type check is performed.
    Despite the name, the lower bound (UINT64_MIN) admits 0. The upper bound
    is the float64 exact-integer limit, not the largest float64 value.
    """
    lower, upper = UINT64_MIN, DOUBLE_MAX_INTEGER
    check_value(value, [lower, upper])
  # Decorator validating the (num_classes, smoothing_rate) arguments of a one-hot op.
def check_one_hot_op(method):
    """Wrapper method to check the parameters of one hot op.

    The decorated ``method`` may be called with ``num_classes`` and
    ``smoothing_rate`` positionally (in that order) or by keyword. A keyword
    value wins over a positional one. After validation the values are
    forwarded to ``method`` as keyword arguments only.

    Args:
        method: Function to decorate; invoked as ``method(self, **kwargs)``.

    Returns:
        The validating wrapper (metadata copied via ``functools.wraps``).

    Raises (from the returned wrapper):
        ValueError: If ``num_classes`` is missing, or is rejected by
            ``check_pos_int32``. Also raised if ``smoothing_rate`` is not
            None and ``check_value`` rejects it for the range [0.0, 1.0].
    """
    @wraps(method)
    def new_method(self, *args, **kwargs):
        # Pad to exactly two positional slots; positional arguments after the
        # second are ignored.
        num_classes, smoothing_rate = (list(args) + [None, None])[:2]
        num_classes = kwargs.get("num_classes", num_classes)
        smoothing_rate = kwargs.get("smoothing_rate", smoothing_rate)
        if num_classes is None:
            # Same descriptive message as check_num_classes (the bare
            # parameter name alone did not say what was wrong).
            raise ValueError("num_classes is not provided.")
        check_pos_int32(num_classes)
        kwargs["num_classes"] = num_classes
        # A None smoothing_rate is not injected, so method's own default
        # applies; an explicit smoothing_rate=None keyword passes through.
        if smoothing_rate is not None:
            check_value(smoothing_rate, [0., 1.])
            kwargs["smoothing_rate"] = smoothing_rate
        return method(self, **kwargs)
    return new_method
  # Decorator validating a single num_classes argument.
def check_num_classes(method):
    """Wrapper method to check the parameters of number of classes.

    ``num_classes`` may be passed as the first positional argument or by
    keyword; the keyword wins if both are given. Positional arguments are not
    forwarded: ``method`` is called as ``method(self, **kwargs)`` with the
    validated ``num_classes`` placed into ``kwargs``.

    Raises (from the returned wrapper):
        ValueError: If ``num_classes`` is missing/None, or is rejected by
            ``check_pos_int32``.
    """
    @wraps(method)
    def new_method(self, *args, **kwargs):
        if "num_classes" in kwargs:
            n_classes = kwargs["num_classes"]
        elif args:
            n_classes = args[0]
        else:
            n_classes = None

        if n_classes is None:
            raise ValueError("num_classes is not provided.")
        check_pos_int32(n_classes)

        kwargs["num_classes"] = n_classes
        return method(self, **kwargs)
    return new_method
  # Decorator validating a MindSpore data_type argument.
def check_de_type(method):
    """Wrapper method to check the parameters of data type.

    ``data_type`` may be passed as the first positional argument or by
    keyword; the keyword wins if both are given. ``method`` is then called as
    ``method(self, **kwargs)`` with the validated ``data_type`` in ``kwargs``.

    Raises (from the returned wrapper):
        ValueError: If ``data_type`` is missing/None.
        TypeError: If ``data_type`` is not an instance of ``typing.Type``
            (the ``typing`` imported from ``mindspore._c_expression``).
    """
    @wraps(method)
    def new_method(self, *args, **kwargs):
        if "data_type" in kwargs:
            dtype = kwargs["data_type"]
        else:
            dtype = args[0] if args else None

        if dtype is None:
            raise ValueError("data_type is not provided.")
        if not isinstance(dtype, typing.Type):
            raise TypeError("data_type is not a MindSpore data type.")

        kwargs["data_type"] = dtype
        return method(self, **kwargs)
    return new_method