You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

mi_validators.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Validator Functions for Offline Debugger APIs.
  17. """
  18. from functools import wraps
  19. import mindspore.offline_debug.dbg_services as cds
  20. from mindspore.offline_debug.mi_validator_helpers import parse_user_args, type_check, type_check_list, check_dir, check_uint32, check_uint64, check_iteration
  21. def check_init(method):
  22. """Wrapper method to check the parameters of DbgServices init."""
  23. @wraps(method)
  24. def new_method(self, *args, **kwargs):
  25. [dump_file_path, verbose], _ = parse_user_args(method, *args, **kwargs)
  26. type_check(dump_file_path, (str,), "dump_file_path")
  27. type_check(verbose, (bool,), "verbose")
  28. check_dir(dump_file_path)
  29. return method(self, *args, **kwargs)
  30. return new_method
  31. def check_initialize(method):
  32. """Wrapper method to check the parameters of DbgServices Initialize method."""
  33. @wraps(method)
  34. def new_method(self, *args, **kwargs):
  35. [net_name, is_sync_mode, max_mem_usage], _ = parse_user_args(method, *args, **kwargs)
  36. type_check(net_name, (str,), "net_name")
  37. type_check(is_sync_mode, (bool,), "is_sync_mode")
  38. check_uint32(max_mem_usage, "max_mem_usage")
  39. return method(self, *args, **kwargs)
  40. return new_method
  41. def check_add_watchpoint(method):
  42. """Wrapper method to check the parameters of DbgServices AddWatchpoint."""
  43. @wraps(method)
  44. def new_method(self, *args, **kwargs):
  45. [id_value, watch_condition, check_node_list, parameter_list], _ = parse_user_args(method, *args, **kwargs)
  46. check_uint32(id_value, "id")
  47. check_uint32(watch_condition, "watch_condition")
  48. type_check(check_node_list, (dict,), "check_node_list")
  49. for node_name, node_info in check_node_list.items():
  50. type_check(node_name, (str,), "node_name")
  51. type_check(node_info, (dict,), "node_info")
  52. for info_name, info_param in node_info.items():
  53. type_check(info_name, (str,), "node parameter name")
  54. if info_name in ["rank_id"]:
  55. if isinstance(info_param, str):
  56. if info_param not in ["*"]:
  57. raise ValueError("Node parameter {} only accepts '*' as string.".format(info_name))
  58. else:
  59. for param in info_param:
  60. check_uint32(param, "rank_id")
  61. elif info_name in ["root_graph_id"]:
  62. if isinstance(info_param, str):
  63. if info_param not in ["*"]:
  64. raise ValueError("Node parameter {} only accepts '*' as string.".format(info_name))
  65. else:
  66. for param in info_param:
  67. check_uint32(param, "root_graph_id")
  68. elif info_name in ["is_output"]:
  69. type_check(info_param, (bool,), "is_output")
  70. else:
  71. raise ValueError("Node parameter {} is not defined.".format(info_name))
  72. param_names = ["param_{0}".format(i) for i in range(len(parameter_list))]
  73. type_check_list(parameter_list, (cds.Parameter,), param_names)
  74. return method(self, *args, **kwargs)
  75. return new_method
  76. def check_remove_watchpoint(method):
  77. """Wrapper method to check the parameters of DbgServices RemoveWatchpoint."""
  78. @wraps(method)
  79. def new_method(self, *args, **kwargs):
  80. [id_value], _ = parse_user_args(method, *args, **kwargs)
  81. check_uint32(id_value, "id")
  82. return method(self, *args, **kwargs)
  83. return new_method
  84. def check_check_watchpoints(method):
  85. """Wrapper method to check the parameters of DbgServices CheckWatchpoint."""
  86. @wraps(method)
  87. def new_method(self, *args, **kwargs):
  88. [iteration], _ = parse_user_args(method, *args, **kwargs)
  89. check_iteration(iteration, "iteration")
  90. return method(self, *args, **kwargs)
  91. return new_method
  92. def check_read_tensor_info(method):
  93. """Wrapper method to check the parameters of DbgServices ReadTensors."""
  94. @wraps(method)
  95. def new_method(self, *args, **kwargs):
  96. [info_list], _ = parse_user_args(method, *args, **kwargs)
  97. info_names = ["info_{0}".format(i) for i in range(len(info_list))]
  98. type_check_list(info_list, (cds.TensorInfo,), info_names)
  99. return method(self, *args, **kwargs)
  100. return new_method
  101. def check_initialize_done(method):
  102. """Wrapper method to check if initlize is done for DbgServices."""
  103. @wraps(method)
  104. def new_method(self, *args, **kwargs):
  105. if not self.initialized:
  106. raise RuntimeError("Inilize should be called before any other methods of DbgServices!")
  107. return method(self, *args, **kwargs)
  108. return new_method
  109. def check_tensor_info_init(method):
  110. """Wrapper method to check the parameters of DbgServices TensorInfo init."""
  111. @wraps(method)
  112. def new_method(self, *args, **kwargs):
  113. [node_name, slot, iteration, rank_id, root_graph_id,
  114. is_output], _ = parse_user_args(method, *args, **kwargs)
  115. type_check(node_name, (str,), "node_name")
  116. check_uint32(slot, "slot")
  117. check_iteration(iteration, "iteration")
  118. check_uint32(rank_id, "rank_id")
  119. check_uint32(root_graph_id, "root_graph_id")
  120. type_check(is_output, (bool,), "is_output")
  121. return method(self, *args, **kwargs)
  122. return new_method
  123. def check_tensor_data_init(method):
  124. """Wrapper method to check the parameters of DbgServices TensorData init."""
  125. @wraps(method)
  126. def new_method(self, *args, **kwargs):
  127. [data_ptr, data_size, dtype, shape], _ = parse_user_args(method, *args, **kwargs)
  128. type_check(data_ptr, (bytes,), "data_ptr")
  129. check_uint64(data_size, "data_size")
  130. type_check(dtype, (int,), "dtype")
  131. shape_names = ["shape_{0}".format(i) for i in range(len(shape))]
  132. type_check_list(shape, (int,), shape_names)
  133. if len(data_ptr) != data_size:
  134. raise ValueError("data_ptr length ({0}) is not equal to data_size ({1}).".format(len(data_ptr), data_size))
  135. return method(self, *args, **kwargs)
  136. return new_method
  137. def check_tensor_base_data_init(method):
  138. """Wrapper method to check the parameters of DbgServices TensorBaseData init."""
  139. @wraps(method)
  140. def new_method(self, *args, **kwargs):
  141. [data_size, dtype, shape], _ = parse_user_args(method, *args, **kwargs)
  142. check_uint64(data_size, "data_size")
  143. type_check(dtype, (int,), "dtype")
  144. shape_names = ["shape_{0}".format(i) for i in range(len(shape))]
  145. type_check_list(shape, (int,), shape_names)
  146. return method(self, *args, **kwargs)
  147. return new_method
  148. def check_tensor_stat_data_init(method):
  149. """Wrapper method to check the parameters of DbgServices TensorBaseData init."""
  150. @wraps(method)
  151. def new_method(self, *args, **kwargs):
  152. [data_size, dtype, shape, is_bool, max_value, min_value,
  153. avg_value, count, neg_zero_count, pos_zero_count,
  154. nan_count, neg_inf_count, pos_inf_count,
  155. zero_count], _ = parse_user_args(method, *args, **kwargs)
  156. check_uint64(data_size, "data_size")
  157. type_check(dtype, (int,), "dtype")
  158. shape_names = ["shape_{0}".format(i) for i in range(len(shape))]
  159. type_check_list(shape, (int,), shape_names)
  160. type_check(is_bool, (bool,), "is_bool")
  161. type_check(max_value, (float,), "max_value")
  162. type_check(min_value, (float,), "min_value")
  163. type_check(avg_value, (float,), "avg_value")
  164. type_check(count, (int,), "count")
  165. type_check(neg_zero_count, (int,), "neg_zero_count")
  166. type_check(pos_zero_count, (int,), "pos_zero_count")
  167. type_check(nan_count, (int,), "nan_count")
  168. type_check(neg_inf_count, (int,), "neg_inf_count")
  169. type_check(pos_inf_count, (int,), "pos_inf_count")
  170. type_check(zero_count, (int,), "zero_count")
  171. return method(self, *args, **kwargs)
  172. return new_method
  173. def check_watchpoint_hit_init(method):
  174. """Wrapper method to check the parameters of DbgServices WatchpointHit init."""
  175. @wraps(method)
  176. def new_method(self, *args, **kwargs):
  177. [name, slot, condition, watchpoint_id,
  178. parameters, error_code, rank_id, root_graph_id], _ = parse_user_args(method, *args, **kwargs)
  179. type_check(name, (str,), "name")
  180. check_uint32(slot, "slot")
  181. type_check(condition, (int,), "condition")
  182. check_uint32(watchpoint_id, "watchpoint_id")
  183. param_names = ["param_{0}".format(i) for i in range(len(parameters))]
  184. type_check_list(parameters, (cds.Parameter,), param_names)
  185. type_check(error_code, (int,), "error_code")
  186. check_uint32(rank_id, "rank_id")
  187. check_uint32(root_graph_id, "root_graph_id")
  188. return method(self, *args, **kwargs)
  189. return new_method
  190. def check_parameter_init(method):
  191. """Wrapper method to check the parameters of DbgServices Parameter init."""
  192. @wraps(method)
  193. def new_method(self, *args, **kwargs):
  194. [name, disabled, value, hit, actual_value], _ = parse_user_args(method, *args, **kwargs)
  195. type_check(name, (str,), "name")
  196. type_check(disabled, (bool,), "disabled")
  197. type_check(value, (float,), "value")
  198. type_check(hit, (bool,), "hit")
  199. type_check(actual_value, (float,), "actual_value")
  200. return method(self, *args, **kwargs)
  201. return new_method