You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

condition.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Management of all conditions.
  17. This module is used to register all conditions, as well as their parameters.
  18. This module also provides the available conditions to the condition_collections API.
  19. """
  20. from enum import Enum
  21. from mindinsight.conditionmgr.log import logger
class ConditionIdEnum(Enum):
    """
    Condition ids.

    Each member's value is the string id of one condition. Members are grouped
    below by the kind of check their names describe.
    """
    # Weight checks.
    WEIGHT_INITIALIZATION = "weight_initialization"
    WEIGHT_OVERFLOW = "weight_overflow"
    WEIGHT_TOO_LARGE = "weight_too_large"
    WEIGHT_TOO_SMALL = "weight_too_small"
    # Gradient checks.
    GRADIENT_VANISHING = "gradient_vanishing"
    GRADIENT_TOO_LARGE = "gradient_too_large"
    GRADIENT_EXPLODING = "gradient_exploding"
    # Overflow / special-value checks.
    TENSOR_OVERFLOW = "tensor_overflow"
    OPERATOR_OVERFLOW = "operator_overflow"
    NAN = "nan"
    # NOTE(review): unlike the other members, the name does not match the value
    # ("overflow"); the name suggests an Ascend-chip-specific overflow check -- confirm.
    OVERFLOW_ASCEND_CHIP = "overflow"
    INF = "inf"
    # Threshold checks on tensor statistics; "gt"/"lt" presumably mean
    # greater-than / less-than -- confirm against the parameter definitions.
    MAX_GT = "max_gt"
    MAX_LT = "max_lt"
    MIN_GT = "min_gt"
    MIN_LT = "min_lt"
    MAX_MIN_GT = "max_min_gt"
    MAX_MIN_LT = "max_min_lt"
    MEAN_GT = "mean_gt"
    MEAN_LT = "mean_lt"
    # Generic tensor checks.
    TENSOR_INITIALIZATION = "tensor_initialization"
    TENSOR_TOO_LARGE = "tensor_too_large"
    TENSOR_TOO_SMALL = "tensor_too_small"
    TENSOR_ALL_ZERO = "tensor_all_zero"
    # Change-tracking checks (names describe how much a weight/tensor changed).
    WEIGHT_NOT_CHANGED = "weight_not_changed"
    WEIGHT_CHANGE_TOO_LARGE = "weight_change_too_large"
    WEIGHT_CHANGE_TOO_SMALL = "weight_change_too_small"
    TENSOR_CHANGE_TOO_LARGE = "tensor_change_too_large"
    TENSOR_CHANGE_TOO_SMALL = "tensor_change_too_small"
    TENSOR_NOT_CHANGED = "tensor_not_changed"
class OptimizePhaseEnum(Enum):
    """
    Optimize phases.

    Used as the ``optimize_phase`` of a Condition.
    """
    # NOTE(review): declaration order differs from numeric order (OPERATOR_CHECK=100
    # is the smallest value). Iterating the enum follows declaration order; if callers
    # rank phases by value instead, confirm that this ordering is intended.
    TENSOR_CHECK = 400
    OPERATOR_CHECK = 100
    LOSS_CHECK = 300
    INPUT_DATA_CHECK = 200
class ValueTypeEnum(Enum):
    """
    Value types.

    Declares the type of a ConditionParameter's value (its ``value_type``).
    """
    FLOAT64 = 1
    INT64 = 2
    BOOL = 3
class PlatformEnum(Enum):
    """
    Platform types.

    Member values are the backend strings that Condition.is_available compares
    (exactly, so case matters) with ``ConditionContext.backend``.
    """
    GPU = "GPU"
    ASCEND = "Ascend"
class TargetTypeEnum(Enum):
    """
    Target types.

    The kinds of target a Condition can declare via ``supported_target_type``.
    """
    TENSOR = 'tensor'
    WEIGHT = 'weight'
    ACTIVATION = 'activation'
    GRADIENT = 'gradient'
  75. class ConditionContext:
  76. """
  77. The class for condition context.
  78. Args:
  79. backend (str): parameter name.
  80. step (int): the type of value.
  81. debugger_capability (tuple): whether the param support no assignment.
  82. """
  83. def __init__(self, backend, step=0, debugger_capability=(1, 0)):
  84. self._backend = backend
  85. self._step = step
  86. self._debugger_capability = debugger_capability
  87. @property
  88. def backend(self):
  89. """Get backend."""
  90. return self._backend
  91. @property
  92. def step(self):
  93. """Get _step."""
  94. return self._step
  95. @property
  96. def debugger_capability(self):
  97. """Get debugger_capability."""
  98. return self._debugger_capability
  99. class ConditionParameter:
  100. """
  101. The class for parameters of conditions.
  102. Args:
  103. name (str): parameter name.
  104. value_type (ValueTypeEnum): the type of value.
  105. support_disable (bool): whether the param support no assignment.
  106. default_value (float): default value.
  107. visible_on_ui (bool): whether the param visible on ui.
  108. """
  109. def __init__(self, name, value_type: ValueTypeEnum, support_disable=True, default_value=None, visible_on_ui=True):
  110. self._name = name
  111. self._type = value_type
  112. self._support_disable = support_disable
  113. self._default_value = default_value
  114. self._visible_on_ui = visible_on_ui
  115. @property
  116. def name(self):
  117. """Get name of parameter."""
  118. return self._name
  119. @property
  120. def type(self):
  121. """Get type of parameter."""
  122. return self._type
  123. @property
  124. def support_disable(self):
  125. """Get support_disable of parameter."""
  126. return self._support_disable
  127. @property
  128. def default_value(self):
  129. """Get default_value of parameter."""
  130. return self._default_value
  131. @property
  132. def visible_on_ui(self):
  133. """Get visible_on_ui of parameter."""
  134. return self._visible_on_ui
  135. class Condition:
  136. """
  137. The class for parameters of conditions.
  138. Args:
  139. condition_id (str): condition id.
  140. abbr (str): the abbreviation of condition id.
  141. optimize_phase (OptimizePhaseEnum): optimize phase.
  142. parameters (List[ConditionParameter]): parameters.
  143. supported_target_type (TargetTypeEnum): the supported target type.
  144. supported_platforms (tuple[PlatformEnum, PlatformEnum]): the supported platforms.
  145. minimum_debugger_capability (tuple): the minimum debugger capability required.
  146. available_test_func (func): the function used to test whether the condition is available
  147. """
  148. def __init__(self, condition_id, abbr, optimize_phase, parameters, supported_target_type, supported_platforms,
  149. minimum_debugger_capability, available_test_func=None):
  150. self.id = condition_id
  151. self._abbr = abbr
  152. self.optimize_phase = optimize_phase
  153. self._parameters = {
  154. parameter.name: parameter for parameter in parameters
  155. }
  156. self._supported_target_type = supported_target_type
  157. self.supported_platforms = supported_platforms
  158. self.minimum_debugger_capability = minimum_debugger_capability
  159. self.available_test_func = available_test_func
  160. def get_parameter_definition(self, name):
  161. """Return parameter definition by the name"""
  162. return self._parameters[name]
  163. def is_available(self, condition_context):
  164. """Check is the condition available."""
  165. backend = condition_context.backend
  166. debugger_capability = condition_context.debugger_capability
  167. if debugger_capability < self.minimum_debugger_capability:
  168. logger.debug("The debugger capability is lower than the minimum debugger capability.")
  169. return False
  170. if backend not in [platform.value for platform in self.supported_platforms]:
  171. logger.debug("The condition %s is not supported on the platform.", self.id)
  172. return False
  173. if self.available_test_func is None:
  174. return True
  175. return self.available_test_func(condition_context)
  176. @property
  177. def abbr(self):
  178. """The abbreviation of condition"""
  179. return self._abbr
  180. @property
  181. def names(self):
  182. """The name of condition"""
  183. return self._parameters.keys()
  184. @property
  185. def parameters(self):
  186. """The parameters of condition"""
  187. return self._parameters.values()
  188. @property
  189. def supported_target_type(self):
  190. """The supported target type of condition"""
  191. return self._supported_target_type
  192. def check_initialization_available(condition_context):
  193. """Check if initialization is available at this step"""
  194. if condition_context.step == 0:
  195. return True
  196. return False