| @@ -17,6 +17,7 @@ | |||||
| import os | import os | ||||
| import re | import re | ||||
| import json | import json | ||||
| from json.decoder import JSONDecodeError | |||||
| from importlib import import_module | from importlib import import_module | ||||
| @@ -34,6 +35,9 @@ from mindspore.nn.optim.optimizer import Optimizer | |||||
| from mindspore.nn.loss.loss import _Loss | from mindspore.nn.loss.loss import _Loss | ||||
| from mindspore.train._utils import check_value_type | from mindspore.train._utils import check_value_type | ||||
| HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG" | |||||
| HYPER_CONFIG_LEN_LIMIT = 100000 | |||||
| class LineageMetadata: | class LineageMetadata: | ||||
| """Initialize parameters used in model lineage management.""" | """Initialize parameters used in model lineage management.""" | ||||
| @@ -188,8 +192,7 @@ class SummaryCollector(Callback): | |||||
| msg = f"For 'collect_specified_data' the value after processing is: {self._collect_specified_data}." | msg = f"For 'collect_specified_data' the value after processing is: {self._collect_specified_data}." | ||||
| logger.info(msg) | logger.info(msg) | ||||
| self._check_custom_lineage_data(custom_lineage_data) | |||||
| self._custom_lineage_data = custom_lineage_data | |||||
| self._custom_lineage_data = self._process_custom_lineage_data(custom_lineage_data) | |||||
| self._temp_optimizer = None | self._temp_optimizer = None | ||||
| self._has_saved_graph = False | self._has_saved_graph = False | ||||
| @@ -232,8 +235,7 @@ class SummaryCollector(Callback): | |||||
| if value <= 0: | if value <= 0: | ||||
| raise ValueError(f'For `{name}` the value should be greater than 0, but got `{value}`.') | raise ValueError(f'For `{name}` the value should be greater than 0, but got `{value}`.') | ||||
| @staticmethod | |||||
| def _check_custom_lineage_data(custom_lineage_data): | |||||
| def _process_custom_lineage_data(self, custom_lineage_data): | |||||
| """ | """ | ||||
| Check user custom lineage data. | Check user custom lineage data. | ||||
| @@ -244,12 +246,50 @@ class SummaryCollector(Callback): | |||||
| TypeError: If the type of parameters is invalid. | TypeError: If the type of parameters is invalid. | ||||
| """ | """ | ||||
| if custom_lineage_data is None: | if custom_lineage_data is None: | ||||
| return | |||||
| custom_lineage_data = {} | |||||
| self._check_custom_lineage_type('custom_lineage_data', custom_lineage_data) | |||||
| auto_custom_lineage_data = self._collect_optimizer_custom_lineage_data() | |||||
| self._check_custom_lineage_type('auto_custom_lineage_data', auto_custom_lineage_data) | |||||
| # the priority of user defined info is higher than auto collected info | |||||
| auto_custom_lineage_data.update(custom_lineage_data) | |||||
| custom_lineage_data = auto_custom_lineage_data | |||||
| return custom_lineage_data | |||||
| def _check_custom_lineage_type(self, param_name, custom_lineage): | |||||
| """Check custom lineage type.""" | |||||
| check_value_type(param_name, custom_lineage, [dict, type(None)]) | |||||
| for key, value in custom_lineage.items(): | |||||
| check_value_type(f'{param_name} -> {key}', key, str) | |||||
| check_value_type(f'the value of {param_name} -> {key}', value, (int, str, float)) | |||||
| def _collect_optimizer_custom_lineage_data(self): | |||||
| """Collect custom lineage data if mindoptimizer has set the hyper config""" | |||||
| auto_custom_lineage_data = {} | |||||
| hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME) | |||||
| if hyper_config is None: | |||||
| logger.debug("Hyper config is not in system environment.") | |||||
| return auto_custom_lineage_data | |||||
| if len(hyper_config) > HYPER_CONFIG_LEN_LIMIT: | |||||
| logger.warning("Hyper config is too long. The length limit is %s, the length of " | |||||
| "hyper_config is %s." % (HYPER_CONFIG_LEN_LIMIT, len(hyper_config))) | |||||
| return auto_custom_lineage_data | |||||
| check_value_type('custom_lineage_data', custom_lineage_data, [dict, type(None)]) | |||||
| for key, value in custom_lineage_data.items(): | |||||
| check_value_type(f'custom_lineage_data -> {key}', key, str) | |||||
| check_value_type(f'the value of custom_lineage_data -> {key}', value, (int, str, float)) | |||||
| try: | |||||
| hyper_config = json.loads(hyper_config) | |||||
| except (TypeError, JSONDecodeError) as exc: | |||||
| logger.warning("Hyper config decode error. Detail: %s." % str(exc)) | |||||
| return auto_custom_lineage_data | |||||
| custom_lineage_data = hyper_config.get("custom_lineage_data") | |||||
| if custom_lineage_data is None: | |||||
| logger.info("No custom lineage data in hyper config. Please check the custom lineage data " | |||||
| "if custom parameters exist in the configuration file.") | |||||
| auto_custom_lineage_data = custom_lineage_data if custom_lineage_data is not None else {} | |||||
| return auto_custom_lineage_data | |||||
| @staticmethod | @staticmethod | ||||
| def _check_action(action): | def _check_action(action): | ||||