
compare scalars within multiple train jobs

tags/v0.3.0-alpha
liangyongxiong committed 6 years ago · commit 8893236417
6 changed files with 269 additions and 49 deletions
  1. mindinsight/backend/datavisual/task_manager_api.py (+18, -10)
  2. mindinsight/backend/datavisual/train_visual_api.py (+11, -0)
  3. mindinsight/datavisual/data_transform/data_manager.py (+34, -6)
  4. mindinsight/datavisual/data_transform/summary_watcher.py (+81, -32)
  5. mindinsight/datavisual/processors/scalars_processor.py (+48, -1)
  6. mindinsight/datavisual/processors/train_task_manager.py (+77, -0)

mindinsight/backend/datavisual/task_manager_api.py (+18, -10)

@@ -25,10 +25,11 @@ from flask import request
 from flask import jsonify
 
 from mindinsight.conf import settings
-from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
+from mindinsight.datavisual.common.validation import Validation
 from mindinsight.datavisual.utils.tools import str_to_bool
 from mindinsight.datavisual.utils.tools import get_train_id
 from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
+from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
 from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER


@@ -65,16 +66,11 @@ def query_train_jobs():
     offset = request.args.get("offset", default=0)
     limit = request.args.get("limit", default=10)
 
-    summary_watcher = SummaryWatcher()
-    total, directories = summary_watcher.list_summary_directories_by_pagination(
-        settings.SUMMARY_BASE_DIR, offset, limit)
+    offset = Validation.check_offset(offset=offset)
+    limit = Validation.check_limit(limit, min_value=1, max_value=SummaryWatcher.MAX_SUMMARY_DIR_COUNT)
 
-    train_jobs = [{
-        'train_id': directory['relative_path'],
-        'relative_path': directory['relative_path'],
-        'create_time': directory['create_time'].strftime('%Y-%m-%d %H:%M:%S'),
-        'update_time': directory['update_time'].strftime('%Y-%m-%d %H:%M:%S'),
-    } for directory in directories]
+    processor = TrainTaskManager(DATA_MANAGER)
+    total, train_jobs = processor.query_train_jobs(offset, limit)
 
     return jsonify({
         'name': os.path.basename(os.path.realpath(settings.SUMMARY_BASE_DIR)),
@@ -83,6 +79,18 @@ def query_train_jobs():
     })
 
 
+@BLUEPRINT.route("/datavisual/train-job-caches", methods=["POST"])
+def cache_train_jobs():
+    """ Cache train jobs."""
+    data = request.get_json(silent=True)
+    train_ids = data.get('train_ids', [])
+
+    processor = TrainTaskManager(DATA_MANAGER)
+    cache_result = processor.cache_train_jobs(train_ids)
+
+    return jsonify({'cache_result': cache_result})
+
+
 def init_module(app):
     """
     Init module entry.
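
Note: a minimal sketch of driving the new POST endpoint from Python; the host, port and the /v1/mindinsight URL prefix are assumptions for illustration, not part of this commit.

    import json
    from urllib import request

    payload = json.dumps({'train_ids': ['./run1', './run2']}).encode('utf-8')
    req = request.Request(
        'http://127.0.0.1:8080/v1/mindinsight/datavisual/train-job-caches',
        data=payload,
        headers={'Content-Type': 'application/json'},
        method='POST',
    )
    with request.urlopen(req) as response:
        # Expected shape: {'cache_result': [{'train_id': ..., 'cache_status': ...}, ...]}
        print(json.loads(response.read()))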


mindinsight/backend/datavisual/train_visual_api.py (+11, -0)

@@ -162,6 +162,17 @@ def histogram():
     return jsonify(response)
 
 
+@BLUEPRINT.route("/datavisual/scalars", methods=["GET"])
+def get_scalars():
+    """Get scalar data for given train_ids and tags."""
+    train_ids = request.args.getlist('train_id')
+    tags = request.args.getlist('tag')
+
+    processor = ScalarsProcessor(DATA_MANAGER)
+    scalars = processor.get_scalars(train_ids, tags)
+    return jsonify({'scalars': scalars})
+
+
 def init_module(app):
     """
     Init module entry.
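
Note: a sketch of a cross-job comparison request against the new route; host, port and URL prefix are again assumptions. The repeated train_id and tag query parameters are exactly what request.args.getlist() collects on the server side.

    import json
    from urllib import request
    from urllib.parse import urlencode

    # Two runs, one tag: compare the same scalar across train jobs.
    query = urlencode([
        ('train_id', './run1'),
        ('train_id', './run2'),
        ('tag', 'loss'),
    ])
    with request.urlopen('http://127.0.0.1:8080/v1/mindinsight/datavisual/scalars?' + query) as response:
        # Expected shape: {'scalars': [{'train_id': ..., 'tag': ..., 'values': [...]}, ...]}
        print(json.loads(response.read()))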


mindinsight/datavisual/data_transform/data_manager.py (+34, -6)

@@ -45,7 +45,7 @@ from mindinsight.utils.exceptions import ParamValueError
 
 
 @enum.unique
-class _CacheStatus(enum.Enum):
+class CacheStatus(enum.Enum):
     """Train job cache status."""
     NOT_IN_CACHE = "NOT_IN_CACHE"
     CACHING = "CACHING"
@@ -63,13 +63,15 @@ class _BasicTrainJob:
         abs_summary_dir (str): The canonical path of summary directory. It should be the return value of realpath().
         create_time (DateTime): The create time of summary directory.
         update_time (DateTime): The latest modify time of summary files directly in the summary directory.
+        profiler_dir (str): The relative path of profiler directory.
     """
-    def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time):
+    def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time, profiler_dir):
         self._train_id = train_id
         self._abs_summary_base_dir = abs_summary_base_dir
         self._abs_summary_dir = abs_summary_dir
         self._create_time = create_time
         self._update_time = update_time
+        self._profiler_dir = profiler_dir
 
     @property
     def abs_summary_dir(self):
@@ -86,6 +88,16 @@ class _BasicTrainJob:
         """Get train id."""
         return self._train_id
 
+    @property
+    def profiler_dir(self):
+        """Get profiler directory path."""
+        return self._profiler_dir
+
+    @property
+    def create_time(self):
+        """Get create time."""
+        return self._create_time
+
     @property
     def update_time(self):
         """Get update time."""
@@ -108,7 +120,7 @@ class CachedTrainJob:
         # Other cached content is stored here.
         self._content = {}
 
-        self._cache_status = _CacheStatus.NOT_IN_CACHE
+        self._cache_status = CacheStatus.NOT_IN_CACHE
         self._key_locks = {}
 
     @property
@@ -203,7 +215,7 @@ class TrainJob:
         self._brief = brief_train_job
         self._detail = detail_train_job
         if self._detail is None:
-            self._cache_status = _CacheStatus.NOT_IN_CACHE
+            self._cache_status = CacheStatus.NOT_IN_CACHE
         else:
             self._cache_status = self._detail.cache_status
@@ -241,6 +253,20 @@ class TrainJob:
         """
         return self._brief.get(key)
 
+    def get_basic_info(self):
+        """
+        Get basic info.
+
+        Returns:
+            basic_info (_BasicTrainJob): Basic info about the train job.
+        """
+        return self._brief.basic_info
+
+    @property
+    def cache_status(self):
+        """Get cache status."""
+        return self._cache_status
+
 
 class BaseCacheItemUpdater(abc.ABC):
     """Abstract base class for other modules to update cache content."""
@@ -686,7 +712,7 @@ class _DetailCacheManager(_BaseCacheManager):
         train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job)
 
         # Will assign real value in future.
-        train_job_obj.cache_status = _CacheStatus.CACHED
+        train_job_obj.cache_status = CacheStatus.CACHED
 
         return train_job_obj

@@ -863,6 +889,7 @@ class DataManager:
 
         basic_train_jobs = []
         for info in summaries_info:
+            profiler = info['profiler']
             basic_train_jobs.append(_BasicTrainJob(
                 train_id=info['relative_path'],
                 abs_summary_base_dir=self._summary_base_dir,
@@ -871,7 +898,8 @@ class DataManager:
                     info['relative_path']
                 )),
                 create_time=info['create_time'],
-                update_time=info['update_time']
+                update_time=info['update_time'],
+                profiler_dir=None if profiler is None else profiler['directory'],
             ))
 
         self._brief_cache.update_cache(basic_train_jobs)
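
Note: dropping the leading underscore from _CacheStatus makes the enum part of the module's public surface, which train_task_manager.py below relies on to compare statuses by value. A quick illustration:

    from mindinsight.datavisual.data_transform.data_manager import CacheStatus

    # The three states shown in this diff: NOT_IN_CACHE, CACHING, CACHED.
    status = CacheStatus.NOT_IN_CACHE
    print(status.value == CacheStatus.NOT_IN_CACHE.value)  # True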


mindinsight/datavisual/data_transform/summary_watcher.py (+81, -32)

@@ -31,6 +31,7 @@ class SummaryWatcher:
     SUMMARY_FILENAME_REGEX = r'summary\.(?P<timestamp>\d+)'
     PB_FILENAME_REGEX = r'\.pb$'
+    PROFILER_DIRECTORY_REGEX = r'^profiler$'
     MAX_SUMMARY_DIR_COUNT = 999
     # scan at most 20000 files/directories (approximately 1 seconds)
@@ -52,6 +53,8 @@ class SummaryWatcher:
             starting with "./".
             - create_time (datetime): Creation time of summary file.
             - update_time (datetime): Modification time of summary file.
+            - profiler (dict): profiler info, including profiler subdirectory path, profiler creation time and
+                profiler modification time.
 
         Examples:
             >>> from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
@@ -95,7 +98,7 @@ class SummaryWatcher:
             if entry.is_symlink():
                 pass
             elif entry.is_file():
-                self._update_summary_dict(summary_dict, relative_path, entry)
+                self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry)
             elif entry.is_dir():
                 full_path = os.path.realpath(os.path.join(summary_base_dir, entry.name))
                 try:
@@ -103,27 +106,39 @@ class SummaryWatcher:
                 except PermissionError:
                     logger.warning('Path of %s under summary base directory is not accessible.', entry.name)
                     continue
-                self._scan_subdir_entries(summary_dict, subdir_entries, entry.name, counter)
+                self._scan_subdir_entries(summary_dict, summary_base_dir, subdir_entries, entry.name, counter)
 
-        directories = [{
-            'relative_path': key,
-            'create_time': value['ctime'],
-            'update_time': value['mtime'],
-        } for key, value in summary_dict.items()]
+        directories = []
+        for key, value in summary_dict.items():
+            directory = {
+                'relative_path': key,
+                'profiler': None,
+                'create_time': value['ctime'],
+                'update_time': value['mtime'],
+            }
+            profiler = value.get('profiler')
+            if profiler is not None:
+                directory['profiler'] = {
+                    'directory': profiler['directory'],
+                    'create_time': profiler['ctime'],
+                    'update_time': profiler['mtime'],
+                }
+            directories.append(directory)
 
         # sort by update time in descending order and relative path in ascending order
         directories.sort(key=lambda x: (-int(x['update_time'].timestamp()), x['relative_path']))
 
         return directories
 
-    def _scan_subdir_entries(self, summary_dict, subdir_entries, entry_name, counter):
+    def _scan_subdir_entries(self, summary_dict, summary_base_dir, subdir_entries, entry_name, counter):
         """
         Scan subdir entries.
 
         Args:
             summary_dict (dict): Temporary data structure to hold summary directory info.
-            subdir_entries(DirEntry): Directory entry instance.
+            summary_base_dir (str): Path of summary base directory.
             entry_name (str): Name of entry.
+            subdir_entries(DirEntry): Directory entry instance.
             counter (Counter): An instance of CountLimiter.
         """
@@ -139,8 +154,7 @@ class SummaryWatcher:
             subdir_relative_path = os.path.join('.', entry_name)
             if subdir_entry.is_symlink():
                 pass
-            elif subdir_entry.is_file():
-                self._update_summary_dict(summary_dict, subdir_relative_path, subdir_entry)
+            self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry)
 
     def _contains_null_byte(self, **kwargs):
         """
@@ -194,40 +208,62 @@ class SummaryWatcher:
         return True
 
-    def _update_summary_dict(self, summary_dict, relative_path, entry):
+    def _update_summary_dict(self, summary_dict, summary_base_dir, relative_path, entry):
         """
         Update summary_dict with ctime and mtime.
 
         Args:
             summary_dict (dict): Temporary data structure to hold summary directory info.
+            summary_base_dir (str): Path of summary base directory.
             relative_path (str): Relative path of summary directory, referring to summary base directory,
                                 starting with "./" .
             entry (DirEntry): Directory entry instance needed to check with regular expression.
         """
-        summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name)
-        pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name)
-        if summary_pattern is None and pb_pattern is None:
-            return
-        ctime = datetime.datetime.fromtimestamp(entry.stat().st_ctime).astimezone()
-        mtime = datetime.datetime.fromtimestamp(entry.stat().st_mtime).astimezone()
-        if summary_pattern is not None:
-            timestamp = int(summary_pattern.groupdict().get('timestamp'))
-            try:
-                # extract created time from filename
-                ctime = datetime.datetime.fromtimestamp(timestamp).astimezone()
-            except OverflowError:
-                return
-        if relative_path not in summary_dict or summary_dict[relative_path]['ctime'] < ctime:
-            summary_dict[relative_path] = {
-                'ctime': ctime,
-                'mtime': mtime,
-            }
+        if entry.is_file():
+            summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name)
+            pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name)
+            if summary_pattern is None and pb_pattern is None:
+                return
+            if summary_pattern is not None:
+                timestamp = int(summary_pattern.groupdict().get('timestamp'))
+                try:
+                    # extract created time from filename
+                    ctime = datetime.datetime.fromtimestamp(timestamp).astimezone()
+                except OverflowError:
+                    return
+            else:
+                ctime = datetime.datetime.fromtimestamp(entry.stat().st_ctime).astimezone()
+            # extract modified time from filesystem
+            mtime = datetime.datetime.fromtimestamp(entry.stat().st_mtime).astimezone()
+            if relative_path not in summary_dict:
+                summary_dict[relative_path] = {
+                    'ctime': ctime,
+                    'mtime': mtime,
+                    'profiler': None,
+                }
+            elif summary_dict[relative_path]['ctime'] < ctime:
+                summary_dict[relative_path].update({
+                    'ctime': ctime,
+                    'mtime': mtime,
+                })
+        elif entry.is_dir():
+            profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name)
+            full_dir_path = os.path.join(summary_base_dir, relative_path, entry.name)
+            if profiler_pattern is None or self._is_empty_directory(full_dir_path):
+                return
+            ctime = datetime.datetime.fromtimestamp(entry.stat().st_ctime).astimezone()
+            mtime = datetime.datetime.fromtimestamp(entry.stat().st_mtime).astimezone()
+            profiler = {
+                'directory': os.path.join('.', entry.name),
+                'ctime': ctime,
+                'mtime': mtime,
+            }
+            if relative_path not in summary_dict:
+                summary_dict[relative_path] = {
+                    'ctime': ctime,
+                    'mtime': mtime,
+                    'profiler': profiler,
+                }
 
     def is_summary_directory(self, summary_base_dir, relative_path):
         """
         Check if the given summary directory is valid.
@@ -259,15 +295,28 @@ class SummaryWatcher:
             raise FileSystemPermissionError('Path of summary base directory is not accessible.')
 
         for entry in entries:
-            if entry.is_symlink() or not entry.is_file():
+            if entry.is_symlink():
                 continue
 
             summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name)
+            if summary_pattern is not None and entry.is_file():
+                return True
             pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name)
-            if summary_pattern or pb_pattern:
+            if pb_pattern is not None and entry.is_file():
                 return True
+            profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name)
+            if profiler_pattern is not None and entry.is_dir():
+                full_path = os.path.realpath(os.path.join(summary_directory, entry.name))
+                if not self._is_empty_directory(full_path):
+                    return True
 
         return False
 
+    def _is_empty_directory(self, directory):
+        return not bool(os.listdir(directory))
+
     def list_summary_directories_by_pagination(self, summary_base_dir, offset=0, limit=10):
         """
         List summary directories within base directory.
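
Note: after this change each entry returned by list_summary_directories() carries a profiler key. A sketch of the resulting shape, with illustrative values only:

    import datetime

    now = datetime.datetime.now().astimezone()
    directory = {
        'relative_path': './run1',
        'create_time': now,
        'update_time': now,
        # None when the run has no non-empty 'profiler' subdirectory.
        'profiler': {
            'directory': './profiler',
            'create_time': now,
            'update_time': now,
        },
    }
    print(directory['profiler']['directory'])  # ./profiler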


mindinsight/datavisual/processors/scalars_processor.py (+48, -1)

@@ -13,7 +13,10 @@
 # limitations under the License.
 # ============================================================================
 """Scalar Processor APIs."""
-from mindinsight.utils.exceptions import ParamValueError
+from urllib.parse import unquote
+
+from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError
+from mindinsight.datavisual.utils.tools import if_nan_inf_to_none
 from mindinsight.datavisual.common.exceptions import ScalarNotExistError
 from mindinsight.datavisual.common.validation import Validation
 from mindinsight.datavisual.processors.base_processor import BaseProcessor
@@ -46,3 +49,47 @@ class ScalarsProcessor(BaseProcessor):
                 'step': tensor.step,
                 'value': tensor.value})
         return dict(metadatas=job_response)
+
+    def get_scalars(self, train_ids, tags):
+        """
+        Get scalar data for given train_ids and tags.
+
+        Args:
+            train_ids (list): Specify list of train job ID.
+            tags (list): Specify list of tags.
+
+        Returns:
+            list[dict], a list of dictionaries containing the `wall_time`, `step`, `value` for each scalar.
+        """
+        for index, train_id in enumerate(train_ids):
+            try:
+                train_id = unquote(train_id, errors='strict')
+            except UnicodeDecodeError:
+                raise UrlDecodeError('Unquote train id error with strict mode')
+            else:
+                train_ids[index] = train_id
+
+        scalars = []
+        for train_id in train_ids:
+            for tag in tags:
+                try:
+                    tensors = self._data_manager.list_tensors(train_id, tag)
+                except ParamValueError:
+                    continue
+
+                scalar = {
+                    'train_id': train_id,
+                    'tag': tag,
+                    'values': [],
+                }
+
+                for tensor in tensors:
+                    scalar['values'].append({
+                        'wall_time': tensor.wall_time,
+                        'step': tensor.step,
+                        'value': if_nan_inf_to_none('scalar_value', tensor.value),
+                    })
+
+                scalars.append(scalar)
+
+        return scalars
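
Note: if_nan_inf_to_none is applied here because NaN and +/-Inf scalar values are not representable in strict JSON. The sketch below re-implements the idea for illustration only; it is an assumption, not the library function itself.

    import math

    def nan_inf_to_none_sketch(value):
        # Map NaN/Inf to None so the value serializes as JSON null.
        if isinstance(value, float) and (math.isnan(value) or math.isinf(value)):
            return None
        return value

    print(nan_inf_to_none_sketch(float('inf')))  # None
    print(nan_inf_to_none_sketch(0.125))         # 0.125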

mindinsight/datavisual/processors/train_task_manager.py (+77, -0)

@@ -14,11 +14,13 @@
 # ============================================================================
 """Train task manager."""
 
+from mindinsight.datavisual.common.log import logger
 from mindinsight.datavisual.common import exceptions
 from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.validation import Validation
 from mindinsight.datavisual.processors.base_processor import BaseProcessor
 from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
+from mindinsight.datavisual.data_transform.data_manager import CacheStatus
 
 
 class TrainTaskManager(BaseProcessor):
@@ -75,3 +77,78 @@ class TrainTaskManager(BaseProcessor):
         return dict(
             plugins=plugins
         )
+
+    def query_train_jobs(self, offset=0, limit=10):
+        """
+        Query train jobs.
+
+        Args:
+            offset (int): Specify page number. Default is 0.
+            limit (int): Specify page size. Default is 10.
+
+        Returns:
+            tuple, return quantity of total train jobs and list of train jobs specified by offset and limit.
+        """
+        brief_cache = self._data_manager.get_brief_cache()
+        brief_train_jobs = list(brief_cache.get_train_jobs().values())
+        brief_train_jobs.sort(key=lambda x: x.basic_info.update_time, reverse=True)
+        total = len(brief_train_jobs)
+
+        start = offset * limit
+        end = (offset + 1) * limit
+        train_jobs = []
+
+        train_ids = [train_job.basic_info.train_id for train_job in brief_train_jobs[start:end]]
+
+        for train_id in train_ids:
+            try:
+                train_job = self._data_manager.get_train_job(train_id)
+            except exceptions.TrainJobNotExistError:
+                logger.warning('Train job %s not existed', train_id)
+                continue
+
+            basic_info = train_job.get_basic_info()
+            train_job_item = dict(
+                train_id=basic_info.train_id,
+                relative_path=basic_info.train_id,
+                create_time=basic_info.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+                update_time=basic_info.update_time.strftime('%Y-%m-%d %H:%M:%S'),
+                profiler_dir=basic_info.profiler_dir,
+                cache_status=train_job.cache_status.value,
+            )
+            plugins = self.get_plugins(train_id)
+            train_job_item.update(plugins)
+            train_jobs.append(train_job_item)
+
+        return total, train_jobs
+
+    def cache_train_jobs(self, train_ids):
+        """
+        Cache train jobs.
+
+        Args:
+            train_ids (list): Specify list of train_ids to be cached.
+
+        Returns:
+            dict, indicates train job ID and its current cache status.
+        """
+        brief_cache = self._data_manager.get_brief_cache()
+        brief_train_jobs = brief_cache.get_train_jobs()
+
+        for train_id in train_ids:
+            brief_train_job = brief_train_jobs.get(train_id)
+            if brief_train_job is None:
+                raise exceptions.TrainJobNotExistError(f'Train id {train_id} not exists')
+
+        cache_result = []
+        for train_id in train_ids:
+            brief_train_job = brief_train_jobs.get(train_id)
+            if brief_train_job.cache_status.value == CacheStatus.NOT_IN_CACHE.value:
+                self._data_manager.cache_train_job(train_id)
+
+            cache_result.append({
+                'train_id': train_id,
+                'cache_status': brief_train_job.cache_status.value,
+            })
+
+        return cache_result
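
Note: query_train_jobs treats offset as a page index rather than an item index, so page 2 with limit 10 covers items 20..29:

    offset, limit = 2, 10
    start, end = offset * limit, (offset + 1) * limit
    print(start, end)  # 20 30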
