You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

utils.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from mindspore import context
  15. from mindspore.nn import Cell
  16. from mindspore.common.api import _cell_graph_executor
  17. class ParallelValidator:
  18. """
  19. Validator for distribute operator.
  20. Args:
  21. net (Cell): `auto_parallel_mode` = True for networks where compile has been executed.
  22. Examples:
  23. >>> from mindspore.common.api import _cell_graph_executor
  24. >>> from parallel.util.utils import ParallelValidator
  25. >>> net = Net()
  26. >>> net.set_auto_parallel()
  27. >>> net.set_train()
  28. >>> phase, _ = _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=True)
  29. >>> validator = ParallelValidator(net, phase) # Init validator by net and phase
  30. >>> assert validator.check_parameter_shape("x", [8, 3, 256, 256]) # Check parameter slice shape
  31. >>> # expect_layout: (device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group)
  32. >>> expect_layout = ([4, 2], [1, -1, -1, -1], [8, 3, 256, 256], 0, True, '')
  33. >>> assert validator.check_parameter_laytout("x", expect_layout)
  34. >>> # check attrs for "ROIAlign-0" from graph_1
  35. >>> expect_attrs = {'pooled_height': POOLED_HEIGHT, 'pooled_width': POOLED_WIDTH}
  36. >>> assert validator.check_node_attrs("ROIAlign-0", expect_attrs, graph_id=1)
  37. >>> # check node inputs for "ROIAlign-0 from graph_0 (default graph_id)
  38. >>> expect_inputs = ['features', 'TensorScatterUpdate-0']
  39. >>> assert validator.check_node_inputs('ROIAlign-0', 'features', 'TensorScatterUpdate-0')
  40. >>> # check sub graph structure from graph_1
  41. >>> sub_graph = {
  42. ... 'ROIAlign-0': ['features', 'TensorScatterUpdate-0'],
  43. ... 'MaskedFill-0': ['ROIAlign-0', 'ExpandDims-2', 0.0],
  44. ... 'AllReduce-0': ['MaskedFill-0']
  45. ... }
  46. >>> assert validator.check_graph_structure(sub_graph, graph_id=1)
  47. """
  48. def __init__(self, net, phase):
  49. self._parameter_layout_dict = net.parameter_layout_dict
  50. self._graph_info_dict = _cell_graph_executor._graph_executor.get_parallel_graph_info(phase)
  51. @property
  52. def parameter_layout_dict(self):
  53. return self._parameter_layout_dict
  54. @property
  55. def graph_info_dict(self):
  56. return self._graph_info_dict
  57. def check_parameter_layout(self, param_name: str, layout: [tuple, list]) -> bool:
  58. """Verify parameter layout."""
  59. if not isinstance(layout, (tuple, list)):
  60. raise TypeError("Type of expect_inputs must be list or tuple, but got {}".format(type(layout)))
  61. if param_name not in self._parameter_layout_dict.keys():
  62. return False
  63. return self._parameter_layout_dict[param_name] == layout
  64. def check_parameter_shape(self, param_name: str, shape: [tuple, list]) -> bool:
  65. """Verify parameter shape"""
  66. if not isinstance(shape, (tuple, list)):
  67. raise TypeError("Type of expect_inputs must be list or tuple, but got {}".format(type(shape)))
  68. if param_name not in self._parameter_layout_dict.keys():
  69. return False
  70. return self._parameter_layout_dict[param_name][2] == shape
  71. def check_node_attrs(self, node_name: str, expect_attrs: dict, graph_id=0) -> bool:
  72. if not isinstance(expect_attrs, dict):
  73. raise TypeError("Type of expect_attrs must be dict, but got {}".format(type(expect_attrs)))
  74. cnode_info_dict = self._get_graph_cnode_info(graph_id)
  75. if node_name not in cnode_info_dict.keys():
  76. return False
  77. attrs = cnode_info_dict[node_name]['attrs']
  78. for attr_name in expect_attrs.keys():
  79. if attr_name not in attrs.keys() or attrs[attr_name] != expect_attrs[attr_name]:
  80. return False
  81. return True
  82. def check_node_inputs(self, node_name: str, expect_inputs: [tuple, list], graph_id=0) -> bool:
  83. if not isinstance(expect_inputs, (tuple, list)):
  84. raise TypeError("Type of expect_inputs must be list or tuple, but got {}".format(type(expect_inputs)))
  85. cnode_info_dict = self._get_graph_cnode_info(graph_id)
  86. expect_inputs = list(expect_inputs)
  87. if node_name not in cnode_info_dict.keys():
  88. return False
  89. inputs = cnode_info_dict[node_name]['inputs']
  90. return inputs == expect_inputs
  91. def check_graph_structure(self, nodes_dict: dict, graph_id=0) -> bool:
  92. if not isinstance(nodes_dict, dict):
  93. raise TypeError("Type of nodes_dict must be dict, but got {}".format(type(nodes_dict)))
  94. for name, inputs in nodes_dict.items():
  95. if not self.check_node_inputs(name, inputs, graph_id):
  96. return False
  97. return True
  98. def _get_graph_cnode_info(self, graph_id):
  99. graph_name = "@graph_" + str(graph_id)
  100. if graph_name not in self._graph_info_dict.keys():
  101. raise ValueError("{} is not exist".format(graph_name))
  102. return self._graph_info_dict[graph_name]
  103. def compile_net(net: Cell, *inputs, auto_parallel_mode=False):
  104. net.set_auto_parallel()
  105. net.set_train()
  106. phase, _ = _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=auto_parallel_mode)
  107. context.reset_auto_parallel_context()
  108. return phase