
!167 optimize redundant function code

Merge pull request !167 from liangyongxiong/redundant-codes
tags/v0.5.0-beta
mindspore-ci-bot (Gitee) · 5 years ago
parent · commit 4f3d07937d
2 changed files with 39 additions and 56 deletions:

  1. mindinsight/datavisual/data_transform/ms_data_loader.py (+26, -33)
  2. mindinsight/datavisual/data_transform/summary_watcher.py (+13, -23)

mindinsight/datavisual/data_transform/ms_data_loader.py (+26, -33)

@@ -386,48 +386,40 @@ class _SummaryParser(_Parser):
         Args:
             event (Event): Message event in summary proto, data read from file handler.
         """
+        plugins = {
+            'scalar_value': PluginNameEnum.SCALAR,
+            'image': PluginNameEnum.IMAGE,
+            'histogram': PluginNameEnum.HISTOGRAM,
+        }
+
         if event.HasField('summary'):
             for value in event.summary.value:
-                if value.HasField('scalar_value'):
-                    tag = '{}/{}'.format(value.tag, PluginNameEnum.SCALAR.value)
-                    tensor_event = TensorEvent(wall_time=event.wall_time,
-                                               step=event.step,
-                                               tag=tag,
-                                               plugin_name=PluginNameEnum.SCALAR.value,
-                                               value=value.scalar_value,
-                                               filename=self._latest_filename)
-                    self._events_data.add_tensor_event(tensor_event)
+                for plugin in plugins:
+                    if not value.HasField(plugin):
+                        continue
+                    plugin_name_enum = plugins[plugin]
+                    tensor_event_value = getattr(value, plugin)
 
-                if value.HasField('image'):
-                    tag = '{}/{}'.format(value.tag, PluginNameEnum.IMAGE.value)
+                    if plugin == 'histogram':
+                        tensor_event_value = HistogramContainer(tensor_event_value)
+                        # Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT
+                        # to avoid time-consuming re-sample process.
+                        if tensor_event_value.original_buckets_count > HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT:
+                            logger.warning('original_buckets_count exceeds '
+                                           'HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT')
+                            continue
+
                     tensor_event = TensorEvent(wall_time=event.wall_time,
                                                step=event.step,
-                                               tag=tag,
-                                               plugin_name=PluginNameEnum.IMAGE.value,
-                                               value=value.image,
+                                               tag='{}/{}'.format(value.tag, plugin_name_enum.value),
+                                               plugin_name=plugin_name_enum.value,
+                                               value=tensor_event_value,
                                                filename=self._latest_filename)
                     self._events_data.add_tensor_event(tensor_event)
 
-                if value.HasField('histogram'):
-                    histogram_msg = HistogramContainer(value.histogram)
-                    # Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT
-                    # to avoid time-consuming re-sample process.
-                    if histogram_msg.original_buckets_count > HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT:
-                        logger.warning('original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT')
-                    else:
-                        tag = '{}/{}'.format(value.tag, PluginNameEnum.HISTOGRAM.value)
-                        tensor_event = TensorEvent(wall_time=event.wall_time,
-                                                   step=event.step,
-                                                   tag=tag,
-                                                   plugin_name=PluginNameEnum.HISTOGRAM.value,
-                                                   value=histogram_msg,
-                                                   filename=self._latest_filename)
-                        self._events_data.add_tensor_event(tensor_event)
-
-        if event.HasField('graph_def'):
-            graph_proto = event.graph_def
+        elif event.HasField('graph_def'):
             graph = MSGraph()
-            graph.build_graph(graph_proto)
+            graph.build_graph(event.graph_def)
             tensor_event = TensorEvent(wall_time=event.wall_time,
                                        step=event.step,
                                        tag=self._latest_filename,

@@ -439,6 +431,7 @@ class _SummaryParser(_Parser):
             graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value)
         except KeyError:
             graph_tags = []
+
         summary_tags = self.filter_files(graph_tags)
         for tag in summary_tags:
             self._events_data.delete_tensor_event(tag)

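The ms_data_loader.py hunk above replaces three near-identical per-plugin branches with a field-name-to-enum table and a single TensorEvent construction path. Below is a minimal, self-contained sketch of that dispatch-table pattern; the Plugin enum, Value record, and build_events helper are illustrative stand-ins for this sketch, not the MindInsight API.

from enum import Enum
from typing import Any, List, NamedTuple


class Plugin(Enum):
    SCALAR = 'scalar'
    IMAGE = 'image'
    HISTOGRAM = 'histogram'


class Value(NamedTuple):
    """Stand-in for a summary proto value with optional per-plugin fields."""
    tag: str
    scalar_value: Any = None
    image: Any = None
    histogram: Any = None

    def has_field(self, field):
        return getattr(self, field) is not None


# One table replaces three near-identical if-blocks.
PLUGINS = {
    'scalar_value': Plugin.SCALAR,
    'image': Plugin.IMAGE,
    'histogram': Plugin.HISTOGRAM,
}


def build_events(values: List[Value]) -> List[dict]:
    """Build one event per populated field through a single construction path."""
    events = []
    for value in values:
        for field, plugin in PLUGINS.items():
            if not value.has_field(field):
                continue
            events.append({
                'tag': '{}/{}'.format(value.tag, plugin.value),
                'plugin_name': plugin.value,
                'value': getattr(value, field),
            })
    return events


if __name__ == '__main__':
    print(build_events([Value(tag='loss', scalar_value=0.25)]))
    # -> [{'tag': 'loss/scalar', 'plugin_name': 'scalar', 'value': 0.25}]

Supporting a new plugin type then needs only one new table entry instead of another copy of the construction block, which is the point of the refactor.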

mindinsight/datavisual/data_transform/summary_watcher.py (+13, -23)

@@ -64,20 +64,12 @@ class SummaryWatcher:
         if self._contains_null_byte(summary_base_dir=summary_base_dir):
             return []
 
-        if not os.path.exists(summary_base_dir):
-            logger.warning('Path of summary base directory not exists.')
-            return []
-
-        if not os.path.isdir(summary_base_dir):
-            logger.warning('Path of summary base directory is not a valid directory.')
+        relative_path = os.path.join('.', '')
+        if not self._is_valid_summary_directory(summary_base_dir, relative_path):
             return []
 
         summary_dict = {}
-
-        if not overall:
-            counter = Counter(max_count=self.MAX_SCAN_COUNT)
-        else:
-            counter = Counter()
+        counter = Counter(max_count=None if overall else self.MAX_SCAN_COUNT)
 
         try:
             entries = os.scandir(summary_base_dir)

@@ -94,19 +86,13 @@
                     logger.info('Stop further scanning due to overall is False and '
                                 'number of scanned files exceeds upper limit.')
                     break
-            relative_path = os.path.join('.', '')
             if entry.is_symlink():
                 pass
             elif entry.is_file():
                 self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry)
             elif entry.is_dir():
-                full_path = os.path.realpath(os.path.join(summary_base_dir, entry.name))
-                try:
-                    subdir_entries = os.scandir(full_path)
-                except PermissionError:
-                    logger.warning('Path of %s under summary base directory is not accessible.', entry.name)
-                    continue
-                self._scan_subdir_entries(summary_dict, summary_base_dir, subdir_entries, entry.name, counter)
+                entry_path = os.path.realpath(os.path.join(summary_base_dir, entry.name))
+                self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry.name, counter)
 
         directories = []
         for key, value in summary_dict.items():

@@ -130,18 +116,24 @@
         return directories
 
-    def _scan_subdir_entries(self, summary_dict, summary_base_dir, subdir_entries, entry_name, counter):
+    def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry_name, counter):
         """
         Scan subdir entries.
 
         Args:
             summary_dict (dict): Temporary data structure to hold summary directory info.
             summary_base_dir (str): Path of summary base directory.
+            entry_path (str): Path of entry.
             entry_name (str): Name of entry.
-            subdir_entries (DirEntry): Directory entry instance.
             counter (Counter): An instance of CountLimiter.
         """
+        try:
+            subdir_entries = os.scandir(entry_path)
+        except PermissionError:
+            logger.warning('Path of %s under summary base directory is not accessible.', entry_name)
+            return
+
         for subdir_entry in subdir_entries:
             if len(summary_dict) == self.MAX_SUMMARY_DIR_COUNT:
                 break

@@ -189,8 +181,6 @@
         """
         summary_base_dir = os.path.realpath(summary_base_dir)
         summary_directory = os.path.realpath(os.path.join(summary_base_dir, relative_path))
-        if summary_base_dir == summary_directory:
-            return True
 
         if not os.path.exists(summary_directory):
             logger.warning('Path of summary directory not exists.')

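The summary_watcher.py change moves os.scandir and its PermissionError guard inside _scan_subdir_entries, so callers only resolve a path, and it collapses the two Counter branches into one conditional expression. Below is a rough, self-contained sketch of that shape, assuming illustrative names (list_subentries, scan_directory, MAX_SCAN_COUNT) rather than the MindInsight API.

import logging
import os

logger = logging.getLogger(__name__)

MAX_SCAN_COUNT = 3000  # assumed scan limit, for illustration only


def list_subentries(entry_path, entry_name):
    """Yield entries of a subdirectory, skipping it when it is not accessible."""
    try:
        entries = os.scandir(entry_path)
    except PermissionError:
        logger.warning('Path of %s is not accessible.', entry_name)
        return
    with entries:
        for sub_entry in entries:
            yield sub_entry


def scan_directory(base_dir, overall=False):
    """Collect entry names below base_dir; the scan is unlimited only when overall is True."""
    max_count = None if overall else MAX_SCAN_COUNT  # same ternary shape as the Counter change
    names = []
    with os.scandir(base_dir) as entries:
        for entry in entries:
            if max_count is not None and len(names) >= max_count:
                break
            if entry.is_dir(follow_symlinks=False):
                entry_path = os.path.realpath(os.path.join(base_dir, entry.name))
                names.extend(sub.name for sub in list_subentries(entry_path, entry.name))
            elif entry.is_file():
                names.append(entry.name)
    return names


if __name__ == '__main__':
    print(scan_directory('.'))

Keeping the scandir call next to its error handling means every caller of the helper gets the permission check for free, which mirrors why the PR moves the try/except out of the scanning loop.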
