You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from mindspore.common.api import _cell_graph_executor
  class ParallelValidator:
      """
      Validator for distribute operator.

      Comparisons of layouts, shapes and node inputs treat tuples and lists as
      interchangeable, so an expectation written as a tuple matches data stored
      as a list (and vice versa).

      Args:
          net (Cell): Network that has already been compiled with `auto_parallel_mode` = True.
          phase: Phase returned by `_cell_graph_executor.compile` for `net`.

      Examples:
          >>> from mindspore.common.api import _cell_graph_executor
          >>> from parallel.util.utils import ParallelValidator
          >>> net = Net()
          >>> net.set_auto_parallel()
          >>> net.set_train()
          >>> phase, _ = _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=True)
          >>> validator = ParallelValidator(net, phase) # Init validator by net and phase
          >>> assert validator.check_parameter_shape("x", [8, 3, 256, 256]) # Check parameter slice shape
          >>> # expect_layout: (device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group)
          >>> expect_layout = ([4, 2], [1, -1, -1, -1], [8, 3, 256, 256], 0, True, '')
          >>> assert validator.check_parameter_layout("x", expect_layout)
          >>> # check attrs for "ROIAlign-0" from graph_1
          >>> expect_attrs = {'pooled_height': POOLED_HEIGHT, 'pooled_width': POOLED_WIDTH}
          >>> assert validator.check_node_attrs("ROIAlign-0", expect_attrs, graph_id=1)
          >>> # check node inputs for "ROIAlign-0" from graph_0 (default graph_id)
          >>> expect_inputs = ['features', 'TensorScatterUpdate-0']
          >>> assert validator.check_node_inputs('ROIAlign-0', expect_inputs)
          >>> # check sub graph structure from graph_1
          >>> sub_graph = {
          ...     'ROIAlign-0': ['features', 'TensorScatterUpdate-0'],
          ...     'MaskedFill-0': ['ROIAlign-0', 'ExpandDims-2', 0.0],
          ...     'AllReduce-0': ['MaskedFill-0']
          ... }
          >>> assert validator.check_graph_structure(sub_graph, graph_id=1)
      """

      def __init__(self, net, phase):
          self._parameter_layout_dict = net.parameter_layout_dict
          self._graph_info_dict = _cell_graph_executor._graph_executor.get_parallel_graph_info(phase)

      @staticmethod
      def _as_lists(value):
          """Recursively convert tuples/lists to lists so the container type does not affect equality."""
          if isinstance(value, (tuple, list)):
              return [ParallelValidator._as_lists(item) for item in value]
          return value

      @property
      def parameter_layout_dict(self):
          """Parameter layout dict taken from the network at construction time."""
          return self._parameter_layout_dict

      @property
      def graph_info_dict(self):
          """Parallel graph info obtained from the graph executor for the given phase."""
          return self._graph_info_dict

      def check_parameter_layout(self, param_name: str, layout: [tuple, list]) -> bool:
          """
          Verify parameter layout.

          Args:
              param_name (str): Name of the parameter to look up.
              layout (Union[tuple, list]): Expected layout of the parameter.

          Returns:
              bool, True if the parameter exists and its layout equals `layout`.

          Raises:
              TypeError: If `layout` is neither a list nor a tuple.
          """
          if not isinstance(layout, (tuple, list)):
              raise TypeError("Type of layout must be list or tuple, but got {}".format(type(layout)))
          if param_name not in self._parameter_layout_dict:
              return False
          return self._as_lists(self._parameter_layout_dict[param_name]) == self._as_lists(layout)

      def check_parameter_shape(self, param_name: str, shape: [tuple, list]) -> bool:
          """
          Verify parameter shape.

          Args:
              param_name (str): Name of the parameter to look up.
              shape (Union[tuple, list]): Expected shape, compared with element 2 of the stored layout.

          Returns:
              bool, True if the parameter exists and element 2 of its layout equals `shape`.

          Raises:
              TypeError: If `shape` is neither a list nor a tuple.
          """
          if not isinstance(shape, (tuple, list)):
              raise TypeError("Type of shape must be list or tuple, but got {}".format(type(shape)))
          if param_name not in self._parameter_layout_dict:
              return False
          return self._as_lists(self._parameter_layout_dict[param_name][2]) == self._as_lists(shape)

      def check_node_attrs(self, node_name: str, expect_attrs: dict, graph_id=0) -> bool:
          """
          Verify that a node carries every expected attribute with the expected value.

          Attributes of the node that are not listed in `expect_attrs` are ignored.

          Args:
              node_name (str): Name of the node inside the graph.
              expect_attrs (dict): Attribute names mapped to their expected values.
              graph_id (int): Index of the graph to inspect. Default: 0.

          Returns:
              bool, True if the node exists and all expected attributes match.

          Raises:
              TypeError: If `expect_attrs` is not a dict.
              ValueError: If the graph selected by `graph_id` is not present.
          """
          if not isinstance(expect_attrs, dict):
              raise TypeError("Type of expect_attrs must be dict, but got {}".format(type(expect_attrs)))
          cnode_info_dict = self._get_graph_cnode_info(graph_id)
          if node_name not in cnode_info_dict:
              return False
          attrs = cnode_info_dict[node_name]['attrs']
          mismatch = any(name not in attrs or attrs[name] != value for name, value in expect_attrs.items())
          return not mismatch

      def check_node_inputs(self, node_name: str, expect_inputs: [tuple, list], graph_id=0) -> bool:
          """
          Verify that the inputs of a node equal the expected inputs, in order.

          Args:
              node_name (str): Name of the node inside the graph.
              expect_inputs (Union[tuple, list]): Expected inputs of the node.
              graph_id (int): Index of the graph to inspect. Default: 0.

          Returns:
              bool, True if the node exists and its inputs equal `expect_inputs`.

          Raises:
              TypeError: If `expect_inputs` is neither a list nor a tuple.
              ValueError: If the graph selected by `graph_id` is not present.
          """
          if not isinstance(expect_inputs, (tuple, list)):
              raise TypeError("Type of expect_inputs must be list or tuple, but got {}".format(type(expect_inputs)))
          cnode_info_dict = self._get_graph_cnode_info(graph_id)
          if node_name not in cnode_info_dict:
              return False
          inputs = cnode_info_dict[node_name]['inputs']
          return self._as_lists(inputs) == self._as_lists(expect_inputs)

      def check_graph_structure(self, nodes_dict: dict, graph_id=0) -> bool:
          """
          Verify the inputs of several nodes at once.

          Args:
              nodes_dict (dict): Node names mapped to their expected inputs (list or tuple).
              graph_id (int): Index of the graph to inspect. Default: 0.

          Returns:
              bool, True if every listed node exists and has the expected inputs.

          Raises:
              TypeError: If `nodes_dict` is not a dict, or any expected inputs value is not a list or tuple.
              ValueError: If the graph selected by `graph_id` is not present.
          """
          if not isinstance(nodes_dict, dict):
              raise TypeError("Type of nodes_dict must be dict, but got {}".format(type(nodes_dict)))
          return all(self.check_node_inputs(name, inputs, graph_id) for name, inputs in nodes_dict.items())

      def _get_graph_cnode_info(self, graph_id):
          """Return the node info dict stored under "@graph_<graph_id>", raising ValueError if it is absent."""
          graph_name = "@graph_" + str(graph_id)
          if graph_name not in self._graph_info_dict:
              raise ValueError("{} is not exist".format(graph_name))
          return self._graph_info_dict[graph_name]