|
|
@@ -17,6 +17,7 @@ |
|
|
import os |
|
|
import os |
|
|
import re |
|
|
import re |
|
|
import json |
|
|
import json |
|
|
|
|
|
from json.decoder import JSONDecodeError |
|
|
|
|
|
|
|
|
from importlib import import_module |
|
|
from importlib import import_module |
|
|
|
|
|
|
|
|
@@ -34,6 +35,9 @@ from mindspore.nn.optim.optimizer import Optimizer |
|
|
from mindspore.nn.loss.loss import _Loss |
|
|
from mindspore.nn.loss.loss import _Loss |
|
|
from mindspore.train._utils import check_value_type |
|
|
from mindspore.train._utils import check_value_type |
|
|
|
|
|
|
|
|
|
|
|
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG" |
|
|
|
|
|
HYPER_CONFIG_LEN_LIMIT = 100000 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LineageMetadata: |
|
|
class LineageMetadata: |
|
|
"""Initialize parameters used in model lineage management.""" |
|
|
"""Initialize parameters used in model lineage management.""" |
|
|
@@ -188,8 +192,7 @@ class SummaryCollector(Callback): |
|
|
msg = f"For 'collect_specified_data' the value after processing is: {self._collect_specified_data}." |
|
|
msg = f"For 'collect_specified_data' the value after processing is: {self._collect_specified_data}." |
|
|
logger.info(msg) |
|
|
logger.info(msg) |
|
|
|
|
|
|
|
|
self._check_custom_lineage_data(custom_lineage_data) |
|
|
|
|
|
self._custom_lineage_data = custom_lineage_data |
|
|
|
|
|
|
|
|
self._custom_lineage_data = self._process_custom_lineage_data(custom_lineage_data) |
|
|
|
|
|
|
|
|
self._temp_optimizer = None |
|
|
self._temp_optimizer = None |
|
|
self._has_saved_graph = False |
|
|
self._has_saved_graph = False |
|
|
@@ -232,8 +235,7 @@ class SummaryCollector(Callback): |
|
|
if value <= 0: |
|
|
if value <= 0: |
|
|
raise ValueError(f'For `{name}` the value should be greater than 0, but got `{value}`.') |
|
|
raise ValueError(f'For `{name}` the value should be greater than 0, but got `{value}`.') |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
def _check_custom_lineage_data(custom_lineage_data): |
|
|
|
|
|
|
|
|
def _process_custom_lineage_data(self, custom_lineage_data): |
|
|
""" |
|
|
""" |
|
|
Check user custom lineage data. |
|
|
Check user custom lineage data. |
|
|
|
|
|
|
|
|
@@ -244,12 +246,50 @@ class SummaryCollector(Callback): |
|
|
TypeError: If the type of parameters is invalid. |
|
|
TypeError: If the type of parameters is invalid. |
|
|
""" |
|
|
""" |
|
|
if custom_lineage_data is None: |
|
|
if custom_lineage_data is None: |
|
|
return |
|
|
|
|
|
|
|
|
custom_lineage_data = {} |
|
|
|
|
|
self._check_custom_lineage_type('custom_lineage_data', custom_lineage_data) |
|
|
|
|
|
|
|
|
|
|
|
auto_custom_lineage_data = self._collect_optimizer_custom_lineage_data() |
|
|
|
|
|
self._check_custom_lineage_type('auto_custom_lineage_data', auto_custom_lineage_data) |
|
|
|
|
|
# the priority of user defined info is higher than auto collected info |
|
|
|
|
|
auto_custom_lineage_data.update(custom_lineage_data) |
|
|
|
|
|
custom_lineage_data = auto_custom_lineage_data |
|
|
|
|
|
|
|
|
|
|
|
return custom_lineage_data |
|
|
|
|
|
|
|
|
|
|
|
def _check_custom_lineage_type(self, param_name, custom_lineage): |
|
|
|
|
|
"""Check custom lineage type.""" |
|
|
|
|
|
check_value_type(param_name, custom_lineage, [dict, type(None)]) |
|
|
|
|
|
for key, value in custom_lineage.items(): |
|
|
|
|
|
check_value_type(f'{param_name} -> {key}', key, str) |
|
|
|
|
|
check_value_type(f'the value of {param_name} -> {key}', value, (int, str, float)) |
|
|
|
|
|
|
|
|
|
|
|
def _collect_optimizer_custom_lineage_data(self): |
|
|
|
|
|
"""Collect custom lineage data if mindoptimizer has set the hyper config""" |
|
|
|
|
|
auto_custom_lineage_data = {} |
|
|
|
|
|
|
|
|
|
|
|
hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME) |
|
|
|
|
|
if hyper_config is None: |
|
|
|
|
|
logger.debug("Hyper config is not in system environment.") |
|
|
|
|
|
return auto_custom_lineage_data |
|
|
|
|
|
if len(hyper_config) > HYPER_CONFIG_LEN_LIMIT: |
|
|
|
|
|
logger.warning("Hyper config is too long. The length limit is %s, the length of " |
|
|
|
|
|
"hyper_config is %s." % (HYPER_CONFIG_LEN_LIMIT, len(hyper_config))) |
|
|
|
|
|
return auto_custom_lineage_data |
|
|
|
|
|
|
|
|
check_value_type('custom_lineage_data', custom_lineage_data, [dict, type(None)]) |
|
|
|
|
|
for key, value in custom_lineage_data.items(): |
|
|
|
|
|
check_value_type(f'custom_lineage_data -> {key}', key, str) |
|
|
|
|
|
check_value_type(f'the value of custom_lineage_data -> {key}', value, (int, str, float)) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
hyper_config = json.loads(hyper_config) |
|
|
|
|
|
except (TypeError, JSONDecodeError) as exc: |
|
|
|
|
|
logger.warning("Hyper config decode error. Detail: %s." % str(exc)) |
|
|
|
|
|
return auto_custom_lineage_data |
|
|
|
|
|
|
|
|
|
|
|
custom_lineage_data = hyper_config.get("custom_lineage_data") |
|
|
|
|
|
if custom_lineage_data is None: |
|
|
|
|
|
logger.info("No custom lineage data in hyper config. Please check the custom lineage data " |
|
|
|
|
|
"if custom parameters exist in the configuration file.") |
|
|
|
|
|
auto_custom_lineage_data = custom_lineage_data if custom_lineage_data is not None else {} |
|
|
|
|
|
|
|
|
|
|
|
return auto_custom_lineage_data |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
@staticmethod |
|
|
def _check_action(action): |
|
|
def _check_action(action): |
|
|
|