# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from mindspore import context
from mindspore.nn import Cell
from mindspore.common.api import _cell_graph_executor


class ParallelValidator:
    """
    Validator for distribute operator.

    Args:
        net (Cell): `auto_parallel_mode` = True for networks where compile has been executed.
        phase (str): The phase returned by `_cell_graph_executor.compile` for `net`; it is used to
            look up the parallel graph info.

    Examples:
        >>> from mindspore.common.api import _cell_graph_executor
        >>> from parallel.util.utils import ParallelValidator
        >>> net = Net()
        >>> net.set_auto_parallel()
        >>> net.set_train()
        >>> phase, _ = _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=True)
        >>> validator = ParallelValidator(net, phase)  # Init validator by net and phase
        >>> assert validator.check_parameter_shape("x", [8, 3, 256, 256])  # Check parameter slice shape
        >>> # expect_layout: (device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group)
        >>> expect_layout = ([4, 2], [1, -1, -1, -1], [8, 3, 256, 256], 0, True, '')
        >>> assert validator.check_parameter_layout("x", expect_layout)
        >>> # check attrs for "ROIAlign-0" from graph_1
        >>> expect_attrs = {'pooled_height': POOLED_HEIGHT, 'pooled_width': POOLED_WIDTH}
        >>> assert validator.check_node_attrs("ROIAlign-0", expect_attrs, graph_id=1)
        >>> # check node inputs for "ROIAlign-0" from graph_0 (default graph_id)
        >>> expect_inputs = ['features', 'TensorScatterUpdate-0']
        >>> assert validator.check_node_inputs('ROIAlign-0', expect_inputs)
        >>> # check sub graph structure from graph_1
        >>> sub_graph = {
        ...     'ROIAlign-0': ['features', 'TensorScatterUpdate-0'],
        ...     'MaskedFill-0': ['ROIAlign-0', 'ExpandDims-2', 0.0],
        ...     'AllReduce-0': ['MaskedFill-0']
        ... }
        >>> assert validator.check_graph_structure(sub_graph, graph_id=1)
    """

    def __init__(self, net, phase):
        self._parameter_layout_dict = net.parameter_layout_dict
        self._graph_info_dict = _cell_graph_executor._graph_executor.get_parallel_graph_info(phase)

    @property
    def parameter_layout_dict(self):
        """Parameter layout dict taken from the net at construction time."""
        return self._parameter_layout_dict

    @property
    def graph_info_dict(self):
        """Parallel graph info fetched from the graph executor for the given phase."""
        return self._graph_info_dict

    @staticmethod
    def _as_nested_list(value):
        """Recursively turn tuples/lists into lists so the sequence type does not affect `==`."""
        if isinstance(value, (tuple, list)):
            return [ParallelValidator._as_nested_list(item) for item in value]
        return value

    def check_parameter_layout(self, param_name: str, layout: [tuple, list]) -> bool:
        """
        Verify parameter layout.

        Tuples and lists are treated as interchangeable (at every nesting level) when comparing.

        Args:
            param_name (str): Name of the parameter to look up.
            layout (Union[tuple, list]): The expected layout.

        Returns:
            bool, False if `param_name` is unknown or its layout differs from `layout`.

        Raises:
            TypeError: If `layout` is neither a list nor a tuple.
        """
        if not isinstance(layout, (tuple, list)):
            raise TypeError("Type of layout must be list or tuple, but got {}".format(type(layout)))

        if param_name not in self._parameter_layout_dict:
            return False
        actual_layout = self._parameter_layout_dict[param_name]
        return self._as_nested_list(actual_layout) == self._as_nested_list(layout)

    def check_parameter_shape(self, param_name: str, shape: [tuple, list]) -> bool:
        """
        Verify parameter shape.

        The shape is read from index 2 of the parameter's layout entry. Tuples and lists are
        treated as interchangeable when comparing.

        Args:
            param_name (str): Name of the parameter to look up.
            shape (Union[tuple, list]): The expected shape.

        Returns:
            bool, False if `param_name` is unknown or its shape differs from `shape`.

        Raises:
            TypeError: If `shape` is neither a list nor a tuple.
        """
        if not isinstance(shape, (tuple, list)):
            raise TypeError("Type of shape must be list or tuple, but got {}".format(type(shape)))

        if param_name not in self._parameter_layout_dict:
            return False
        actual_shape = self._parameter_layout_dict[param_name][2]
        return self._as_nested_list(actual_shape) == self._as_nested_list(shape)

    def check_node_attrs(self, node_name: str, expect_attrs: dict, graph_id=0) -> bool:
        """
        Verify that a node carries every expected attribute with the expected value.

        Attributes of the node that are not listed in `expect_attrs` are ignored.

        Args:
            node_name (str): Name of the node, e.g. "ROIAlign-0".
            expect_attrs (dict): Mapping of attribute name to expected value.
            graph_id (int): Index of the graph to search. Default: 0.

        Returns:
            bool, False if the node is not found, or any expected attribute is missing or different.

        Raises:
            TypeError: If `expect_attrs` is not a dict.
            ValueError: If the graph selected by `graph_id` does not exist.
        """
        if not isinstance(expect_attrs, dict):
            raise TypeError("Type of expect_attrs must be dict, but got {}".format(type(expect_attrs)))

        cnode_info_dict = self._get_graph_cnode_info(graph_id)
        if node_name not in cnode_info_dict:
            return False

        attrs = cnode_info_dict[node_name]['attrs']
        for attr_name, expect_value in expect_attrs.items():
            if attr_name not in attrs or attrs[attr_name] != expect_value:
                return False
        return True

    def check_node_inputs(self, node_name: str, expect_inputs: [tuple, list], graph_id=0) -> bool:
        """
        Verify that a node's inputs match the expected inputs exactly (same items, same order).

        Tuples and lists are treated as interchangeable when comparing.

        Args:
            node_name (str): Name of the node, e.g. "ROIAlign-0".
            expect_inputs (Union[tuple, list]): The expected inputs of the node.
            graph_id (int): Index of the graph to search. Default: 0.

        Returns:
            bool, False if the node is not found or its inputs differ from `expect_inputs`.

        Raises:
            TypeError: If `expect_inputs` is neither a list nor a tuple.
            ValueError: If the graph selected by `graph_id` does not exist.
        """
        if not isinstance(expect_inputs, (tuple, list)):
            raise TypeError("Type of expect_inputs must be list or tuple, but got {}".format(type(expect_inputs)))

        cnode_info_dict = self._get_graph_cnode_info(graph_id)
        if node_name not in cnode_info_dict:
            return False
        inputs = cnode_info_dict[node_name]['inputs']
        return self._as_nested_list(inputs) == self._as_nested_list(expect_inputs)

    def check_graph_structure(self, nodes_dict: dict, graph_id=0) -> bool:
        """
        Verify the inputs of several nodes at once.

        Args:
            nodes_dict (dict): Mapping of node name to its expected inputs (list or tuple).
            graph_id (int): Index of the graph to search. Default: 0.

        Returns:
            bool, True only if `check_node_inputs` succeeds for every entry of `nodes_dict`.

        Raises:
            TypeError: If `nodes_dict` is not a dict, or any expected inputs is not a list or tuple.
            ValueError: If the graph selected by `graph_id` does not exist.
        """
        if not isinstance(nodes_dict, dict):
            raise TypeError("Type of nodes_dict must be dict, but got {}".format(type(nodes_dict)))
        for name, inputs in nodes_dict.items():
            if not self.check_node_inputs(name, inputs, graph_id):
                return False
        return True

    def _get_graph_cnode_info(self, graph_id):
        """Return the node info dict stored under "@graph_<graph_id>", raising ValueError if absent."""
        graph_name = "@graph_" + str(graph_id)
        if graph_name not in self._graph_info_dict:
            raise ValueError("{} is not exist".format(graph_name))
        return self._graph_info_dict[graph_name]


def compile_net(net: Cell, *inputs, auto_parallel_mode=False):
    """
    Put `net` into auto-parallel training mode, compile it and return the compile phase.

    The auto parallel context is reset after compilation completes.

    Args:
        net (Cell): The network to compile.
        inputs: Positional inputs forwarded to `_cell_graph_executor.compile`.
        auto_parallel_mode (bool): Forwarded to `_cell_graph_executor.compile`. Default: False.

    Returns:
        The phase returned by `_cell_graph_executor.compile`.
    """
    net.set_auto_parallel()
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=auto_parallel_mode)
    context.reset_auto_parallel_context()
    return phase