|
|
|
@@ -14,6 +14,7 @@ |
|
|
|
# ============================================================================ |
|
|
|
"""Train task manager.""" |
|
|
|
|
|
|
|
from mindinsight.utils.exceptions import ParamTypeError |
|
|
|
from mindinsight.datavisual.common.log import logger |
|
|
|
from mindinsight.datavisual.common import exceptions |
|
|
|
from mindinsight.datavisual.common.enums import PluginNameEnum |
|
|
|
@@ -141,9 +142,20 @@ class TrainTaskManager(BaseProcessor): |
|
|
|
|
|
|
|
Returns: |
|
|
|
dict, indicates train job ID and its current cache status. |
|
|
|
|
|
|
|
Raises: |
|
|
|
ParamTypeError, if the given train_ids parameter is not in valid type. |
|
|
|
""" |
|
|
|
if not isinstance(train_ids, list): |
|
|
|
logger.error("train_ids must be list.") |
|
|
|
raise ParamTypeError('train_ids', list) |
|
|
|
|
|
|
|
cache_result = [] |
|
|
|
for train_id in train_ids: |
|
|
|
if not isinstance(train_id, str): |
|
|
|
logger.error("train_id must be str.") |
|
|
|
raise ParamTypeError('train_id', str) |
|
|
|
|
|
|
|
try: |
|
|
|
train_job = self._data_manager.get_train_job(train_id) |
|
|
|
except exceptions.TrainJobNotExistError: |
|
|
|
|