Merge pull request !53 from ougongchang/fixbug_pb_filetags/v0.2.0-alpha
| @@ -96,9 +96,16 @@ class MSGraph(Graph): | |||
| """ | |||
| logger.debug("Start to calc input.") | |||
| for node_def in graph_proto.node: | |||
| if not node_def.name: | |||
| logger.debug("The node name is empty, ignore it.") | |||
| continue | |||
| node_name = leaf_node_id_map_name[node_def.name] | |||
| node = self._leaf_nodes[node_name] | |||
| for input_def in node_def.input: | |||
| if not input_def.name: | |||
| logger.warning("The input node name is empty, ignore it. node name: %s.", node_name) | |||
| continue | |||
| edge_type = EdgeTypeEnum.DATA.value | |||
| if input_def.type == "CONTROL_EDGE": | |||
| edge_type = EdgeTypeEnum.CONTROL.value | |||
| @@ -60,7 +60,7 @@ class MSDataLoader: | |||
| self._latest_summary_filename = '' | |||
| self._latest_summary_file_size = 0 | |||
| self._summary_file_handler = None | |||
| self._latest_pb_file_mtime = 0 | |||
| self._pb_parser = _PbParser(summary_dir) | |||
| def get_events_data(self): | |||
| """Return events data read from log file.""" | |||
| @@ -348,14 +348,58 @@ class MSDataLoader: | |||
| list[str], filename list. | |||
| """ | |||
| pb_filenames = self._filter_pb_files(filenames) | |||
| pb_filenames = sorted(pb_filenames, key=lambda file: FileHandler.file_stat( | |||
| FileHandler.join(self._summary_dir, file)).mtime) | |||
| pb_filenames = self._pb_parser.sort_pb_files(pb_filenames) | |||
| for filename in pb_filenames: | |||
| mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime | |||
| if mtime <= self._latest_pb_file_mtime: | |||
| tensor_event = self._pb_parser.parse_pb_file(filename) | |||
| if tensor_event is None: | |||
| continue | |||
| self._latest_pb_file_mtime = mtime | |||
| self._parse_pb_file(filename) | |||
| self._events_data.add_tensor_event(tensor_event) | |||
| class _PbParser: | |||
| """This class is used to parse pb file.""" | |||
def __init__(self, summary_dir):
    """
    Set up the parser state for one summary directory.

    Args:
        summary_dir (str): Directory that the pb filenames are resolved against.
    """
    self._summary_dir = summary_dir
    # Bookkeeping for the newest file accepted so far; nothing has been accepted yet.
    self._latest_mtime = 0
    self._latest_filename = ''
def parse_pb_file(self, filename):
    """
    Parse one pb file into a tensor event.

    Args:
        filename (str): Name of the pb file; it is handed to `_is_parse_pb_file`
            and `_parse_pb_file` unchanged.

    Returns:
        TensorEvent, the value produced by `_parse_pb_file`. None is returned when
        `_is_parse_pb_file` rejects the file or when parsing raises `UnknownError`.
    """
    if not self._is_parse_pb_file(filename):
        return None

    try:
        event = self._parse_pb_file(filename)
    except UnknownError:
        # A failed parse is reported to the caller as "no event".
        event = None
    return event
def sort_pb_files(self, filenames):
    """
    Return the filenames ordered by ascending mtime, ties broken by ascending filename.

    Args:
        filenames (list[str]): Names of pb files located in the summary directory.

    Returns:
        list[str], a new sorted list; the input is not modified.
    """
    def _sort_key(name):
        stat = FileHandler.file_stat(FileHandler.join(self._summary_dir, name))
        return stat.mtime, name

    return sorted(filenames, key=_sort_key)
def _is_parse_pb_file(self, filename):
    """
    Determine whether the file should be loaded.

    A file is accepted only when its (mtime, filename) pair is strictly greater
    than the pair recorded for the last accepted file. Accepting a file records
    it as the new latest one.

    Args:
        filename (str): Name of the pb file inside the summary directory.

    Returns:
        bool, True if the file should be parsed, False otherwise.
    """
    file_path = FileHandler.join(self._summary_dir, filename)
    mtime = FileHandler.file_stat(file_path).mtime

    candidate = (mtime, filename)
    newest_seen = (self._latest_mtime, self._latest_filename)
    if candidate <= newest_seen:
        return False

    self._latest_mtime, self._latest_filename = candidate
    return True
| def _parse_pb_file(self, filename): | |||
| """ | |||
| @@ -363,6 +407,9 @@ class MSDataLoader: | |||
| Args: | |||
| filename (str): The file path of pb file. | |||
| Returns: | |||
| TensorEvent, the tensor event if the pb file is parsed and the graph is built successfully; None if the file is not a valid pb file. | |||
| """ | |||
| file_path = FileHandler.join(self._summary_dir, filename) | |||
| logger.info("Start to load graph from pb file, file path: %s.", file_path) | |||
| @@ -372,13 +419,24 @@ class MSDataLoader: | |||
| model_proto.ParseFromString(filehandler.read()) | |||
| except ParseError: | |||
| logger.warning("The given file is not a valid pb file, file path: %s.", file_path) | |||
| return | |||
| return None | |||
| graph = MSGraph() | |||
| graph.build_graph(model_proto.graph) | |||
| try: | |||
| graph.build_graph(model_proto.graph) | |||
| except Exception as ex: | |||
| # Normally no exception is raised here; a failure is only expected when a user on the | |||
| # MindSpore side dumps a graph other than the default one. | |||
| logger.error("Build graph failed, file path: %s.", file_path) | |||
| logger.exception(ex) | |||
| raise UnknownError(str(ex)) | |||
| tensor_event = TensorEvent(wall_time=FileHandler.file_stat(file_path), | |||
| step=0, | |||
| tag=filename, | |||
| plugin_name=PluginNameEnum.GRAPH.value, | |||
| value=graph) | |||
| self._events_data.add_tensor_event(tensor_event) | |||
| logger.info("Build graph success, file path: %s.", file_path) | |||
| return tensor_event | |||
| @@ -27,8 +27,12 @@ import pytest | |||
| from mindinsight.datavisual.data_transform import ms_data_loader | |||
| from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader | |||
| from mindinsight.datavisual.data_transform.ms_data_loader import _PbParser | |||
| from mindinsight.datavisual.data_transform.events_data import TensorEvent | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from ..mock import MockLogger | |||
| from ....utils.log_generators.graph_pb_generator import create_graph_pb_file | |||
| # bytes of 3 scalar events | |||
| SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*' | |||
| @@ -69,9 +73,9 @@ class TestMsDataLoader: | |||
| summary_dir = tempfile.mkdtemp() | |||
| ms_loader = MSDataLoader(summary_dir) | |||
| ms_loader._check_files_deleted(new_file_list, old_file_list) | |||
| shutil.rmtree(summary_dir) | |||
| assert MockLogger.log_msg['warning'] == "There are some files has been deleted, " \ | |||
| "we will reload all files in path {}.".format(summary_dir) | |||
| shutil.rmtree(summary_dir) | |||
| @pytest.mark.usefixtures('crc_pass') | |||
| def test_load_success_with_crc_pass(self): | |||
| @@ -96,8 +100,8 @@ class TestMsDataLoader: | |||
| write_file(file2, SCALAR_RECORD) | |||
| ms_loader = MSDataLoader(summary_dir) | |||
| ms_loader.load() | |||
| assert 'Check crc faild and ignore this file' in str(MockLogger.log_msg['warning']) | |||
| shutil.rmtree(summary_dir) | |||
| assert 'Check crc faild and ignore this file' in str(MockLogger.log_msg['warning']) | |||
| def test_filter_event_files(self): | |||
| """Test filter_event_files function ok.""" | |||
| @@ -112,9 +116,83 @@ class TestMsDataLoader: | |||
| ms_loader = MSDataLoader(summary_dir) | |||
| res = ms_loader.filter_valid_files() | |||
| expected = sorted(['aaasummary.5678', 'summary.0012', 'hellosummary.98786', 'mysummary.123abce']) | |||
| shutil.rmtree(summary_dir) | |||
| assert sorted(res) == expected | |||
def test_load_single_pb_file(self):
    """Test that loading a directory with one pb file yields exactly one graph tag."""
    filename = 'ms_output.pb'
    summary_dir = tempfile.mkdtemp()
    try:
        create_graph_pb_file(output_dir=summary_dir, filename=filename)
        ms_loader = MSDataLoader(summary_dir)
        ms_loader.load()
        events_data = ms_loader.get_events_data()
        tags = events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value)
    finally:
        # Remove the temporary directory even when creating or loading the file raises,
        # so a failing run does not leave it behind.
        shutil.rmtree(summary_dir)

    assert len(tags) == 1
    assert tags[0] == filename
class TestPbParser:
    """Unit tests for `_PbParser`."""

    _summary_dir = ''

    def setup_method(self):
        """Create a fresh temporary summary directory before each test."""
        self._summary_dir = tempfile.mkdtemp()

    def teardown_method(self):
        """Delete the temporary summary directory after each test."""
        shutil.rmtree(self._summary_dir)

    def test_parse_pb_file(self):
        """A valid pb file is parsed into a TensorEvent."""
        pb_name = 'ms_output.pb'
        create_graph_pb_file(output_dir=self._summary_dir, filename=pb_name)

        event = _PbParser(self._summary_dir).parse_pb_file(pb_name)

        assert isinstance(event, TensorEvent)

    def test_is_parse_pb_file(self):
        """A file older than the last accepted one is rejected."""
        newer_name = 'ms_output.pb'
        create_graph_pb_file(output_dir=self._summary_dir, filename=newer_name)
        parser = _PbParser(self._summary_dir)
        assert parser._is_parse_pb_file(newer_name)

        older_name = 'ms_output_older.pb'
        older_path = create_graph_pb_file(output_dir=self._summary_dir, filename=older_name)
        # Push atime and mtime back to second 1 so this file is older than the first.
        os.utime(older_path, (1, 1))
        assert not parser._is_parse_pb_file(older_name)

    def test_sort_pb_file_by_mtime(self):
        """Files created in order keep that order after sorting."""
        names = ['abc.pb', 'bbc.pb']
        for name in names:
            create_graph_pb_file(output_dir=self._summary_dir, filename=name)

        result = _PbParser(self._summary_dir).sort_pb_files(names)

        assert names == result

    def test_sort_pb_file_by_filename(self):
        """Files sharing an mtime are ordered by filename."""
        names = ['aaa.pb', 'bbb.pb', 'ccc.pb']
        for name in names:
            create_graph_pb_file(output_dir=self._summary_dir, filename=name)

        # 'aaa.pb' is the newest; 'bbb.pb' and 'ccc.pb' share the same older timestamp.
        timestamps = {'aaa.pb': 3, 'bbb.pb': 1, 'ccc.pb': 1}
        for name, stamp in timestamps.items():
            target = os.path.realpath(os.path.join(self._summary_dir, name))
            os.utime(target, (stamp, stamp))

        result = _PbParser(self._summary_dir).sort_pb_files(names)

        assert ['bbb.pb', 'ccc.pb', 'aaa.pb'] == result
| def write_file(filename, record): | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Log generator for graph pb file.""" | |||
| import os | |||
| import json | |||
| from google.protobuf import json_format | |||
| from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 | |||
def create_graph_pb_file(output_dir='./', filename='ms_output.pb'):
    """
    Create a graph pb file from the bundled ``graph_base.json`` and return its path.

    Args:
        output_dir (str): Directory in which the pb file is written. Default: './'.
        filename (str): Name of the pb file. Default: 'ms_output.pb'.

    Returns:
        str, the real path of the written pb file.
    """
    graph_base = os.path.join(os.path.dirname(__file__), "graph_base.json")
    # Read the fixture as UTF-8 explicitly rather than relying on the platform default encoding.
    with open(graph_base, 'r', encoding='utf-8') as fp:
        data = json.load(fp)

    # Build the message directly from the dict; serializing it back to a JSON string
    # only to parse it again is unnecessary.
    model_proto = json_format.ParseDict(dict(graph=data), anf_ir_pb2.ModelProto())

    output_path = os.path.realpath(os.path.join(output_dir, filename))
    with open(output_path, 'wb') as fp:
        fp.write(model_proto.SerializeToString())

    return output_path


if __name__ == '__main__':
    create_graph_pb_file()