You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

query_model.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define lineage info model."""
  16. import json
  17. from collections import namedtuple
  18. from google.protobuf.json_format import MessageToDict
  19. from mindinsight.lineagemgr.common.exceptions.exceptions import \
  20. LineageEventFieldNotExistException, LineageEventNotExistException
  21. from mindinsight.lineagemgr.summary._summary_adapter import organize_graph
  22. Field = namedtuple('Field', ['base_name', 'sub_name'])
  23. FIELD_MAPPING = {
  24. "summary_dir": Field('summary_dir', None),
  25. "loss_function": Field("hyper_parameters", 'loss_function'),
  26. "train_dataset_path": Field('train_dataset', 'train_dataset_path'),
  27. "train_dataset_count": Field("train_dataset", 'train_dataset_size'),
  28. "test_dataset_path": Field('valid_dataset', 'valid_dataset_path'),
  29. "test_dataset_count": Field('valid_dataset', 'valid_dataset_size'),
  30. "network": Field('algorithm', 'network'),
  31. "optimizer": Field('hyper_parameters', 'optimizer'),
  32. "learning_rate": Field('hyper_parameters', 'learning_rate'),
  33. "epoch": Field('hyper_parameters', 'epoch'),
  34. "batch_size": Field('hyper_parameters', 'batch_size'),
  35. "loss": Field('algorithm', 'loss'),
  36. "model_size": Field('model', 'size'),
  37. "dataset_mark": Field('dataset_mark', None)
  38. }
  39. class LineageObj:
  40. """
  41. Lineage information class.
  42. An instance of the class hold lineage information for a training session.
  43. Args:
  44. summary_dir (str): Summary log dir.
  45. kwargs (dict): Params to init the instance.
  46. - train_lineage (Event): Train lineage object.
  47. - evaluation_lineage (Event): Evaluation lineage object.
  48. - dataset_graph (Event): Dataset graph object.
  49. - user_defined_info (Event): User defined info object.
  50. Raises:
  51. LineageEventNotExistException: If train and evaluation event not exist.
  52. LineageEventFieldNotExistException: If the special event field not exist.
  53. """
  54. _name_train_lineage = 'train_lineage'
  55. _name_evaluation_lineage = 'evaluation_lineage'
  56. _name_summary_dir = 'summary_dir'
  57. _name_metric = 'metric'
  58. _name_hyper_parameters = 'hyper_parameters'
  59. _name_algorithm = 'algorithm'
  60. _name_train_dataset = 'train_dataset'
  61. _name_model = 'model'
  62. _name_valid_dataset = 'valid_dataset'
  63. _name_dataset_graph = 'dataset_graph'
  64. _name_dataset_mark = 'dataset_mark'
  65. _name_user_defined = 'user_defined'
  66. _name_model_lineage = 'model_lineage'
  67. def __init__(self, summary_dir, **kwargs):
  68. self._lineage_info = {
  69. self._name_summary_dir: summary_dir
  70. }
  71. user_defined_info_list = kwargs.get('user_defined_info', [])
  72. train_lineage = kwargs.get('train_lineage')
  73. evaluation_lineage = kwargs.get('evaluation_lineage')
  74. dataset_graph = kwargs.get('dataset_graph')
  75. if not any([train_lineage, evaluation_lineage, dataset_graph]):
  76. raise LineageEventNotExistException()
  77. self._parse_user_defined_info(user_defined_info_list)
  78. self._parse_train_lineage(train_lineage)
  79. self._parse_evaluation_lineage(evaluation_lineage)
  80. self._parse_dataset_graph(dataset_graph)
  81. self._filtration_result = self._organize_filtration_result()
  82. @property
  83. def summary_dir(self):
  84. """
  85. Get summary log dir.
  86. Returns:
  87. str, the summary log dir.
  88. """
  89. return self._lineage_info.get(self._name_summary_dir)
  90. @property
  91. def metric(self):
  92. """
  93. Get metric information.
  94. Returns:
  95. dict, the metric information.
  96. """
  97. return self._lineage_info.get(self._name_metric)
  98. @property
  99. def user_defined(self):
  100. """
  101. Get user defined information.
  102. Returns:
  103. dict, the user defined information.
  104. """
  105. return self._lineage_info.get(self._name_user_defined)
  106. @property
  107. def hyper_parameters(self):
  108. """
  109. Get hyperparameters.
  110. Returns:
  111. dict, the hyperparameters.
  112. """
  113. return self._lineage_info.get(self._name_hyper_parameters)
  114. @property
  115. def algorithm(self):
  116. """
  117. Get algorithm.
  118. Returns:
  119. dict, the algorithm.
  120. """
  121. return self._lineage_info.get(self._name_algorithm)
  122. @property
  123. def train_dataset(self):
  124. """
  125. Get train dataset information.
  126. Returns:
  127. dict, the train dataset information.
  128. """
  129. return self._lineage_info.get(self._name_train_dataset)
  130. @property
  131. def model(self):
  132. """
  133. Get model information.
  134. Returns:
  135. dict, the model information.
  136. """
  137. return self._lineage_info.get(self._name_model)
  138. @property
  139. def valid_dataset(self):
  140. """
  141. Get valid dataset information.
  142. Returns:
  143. dict, the valid dataset information.
  144. """
  145. return self._lineage_info.get(self._name_valid_dataset)
  146. @property
  147. def dataset_graph(self):
  148. """
  149. Get dataset_graph.
  150. Returns:
  151. dict, the dataset graph information.
  152. """
  153. return self._lineage_info.get(self._name_dataset_graph)
  154. @property
  155. def dataset_mark(self):
  156. """
  157. Get dataset_mark.
  158. Returns:
  159. dict, the dataset mark information.
  160. """
  161. return self._lineage_info.get(self._name_dataset_mark)
  162. @dataset_mark.setter
  163. def dataset_mark(self, dataset_mark):
  164. """
  165. Set dataset mark.
  166. Args:
  167. dataset_mark (int): Dataset mark.
  168. """
  169. self._lineage_info[self._name_dataset_mark] = dataset_mark
  170. # update dataset_mark into filtration result
  171. self._filtration_result[self._name_dataset_mark] = dataset_mark
  172. def get_summary_info(self, filter_keys: list):
  173. """
  174. Get the summary lineage information.
  175. Returns the content corresponding to the specified field in the filter
  176. key. The contents of the filter key include `metric`, `hyper_parameters`,
  177. `algorithm`, `train_dataset`, `valid_dataset` and `model`. You can
  178. specify multiple filter keys in the `filter_keys`
  179. Args:
  180. filter_keys (list): Filter keys.
  181. Returns:
  182. dict, the summary lineage information.
  183. """
  184. result = {
  185. self._name_summary_dir: self.summary_dir,
  186. }
  187. for key in filter_keys:
  188. result[key] = getattr(self, key)
  189. return result
  190. def to_dataset_lineage_dict(self):
  191. """
  192. Returns the dataset part lineage information.
  193. Returns:
  194. dict, the dataset lineage information.
  195. """
  196. dataset_lineage = {
  197. key: self._filtration_result.get(key)
  198. for key in [self._name_summary_dir, self._name_dataset_graph]
  199. }
  200. return dataset_lineage
  201. def to_model_lineage_dict(self):
  202. """
  203. Returns the model part lineage information.
  204. Returns:
  205. dict, the model lineage information.
  206. """
  207. filtration_result = dict(self._filtration_result)
  208. filtration_result.pop(self._name_dataset_graph)
  209. model_lineage = dict()
  210. model_lineage.update({self._name_summary_dir: filtration_result.pop(self._name_summary_dir)})
  211. model_lineage.update({self._name_model_lineage: filtration_result})
  212. return model_lineage
  213. def get_value_by_key(self, key):
  214. """
  215. Get the value based on the key in `FIELD_MAPPING` or
  216. the key prefixed with `metric/` or `user_defined/`.
  217. Args:
  218. key (str): The key in `FIELD_MAPPING`
  219. or prefixed with `metric/` or `user_defined/`.
  220. Returns:
  221. object, the value.
  222. """
  223. if key.startswith(('metric/', 'user_defined/')):
  224. key_name, sub_key = key.split('/', 1)
  225. sub_value_name = self._name_metric if key_name == 'metric' else self._name_user_defined
  226. sub_value = self._filtration_result.get(sub_value_name)
  227. if sub_value:
  228. return sub_value.get(sub_key)
  229. return self._filtration_result.get(key)
  230. def _organize_filtration_result(self):
  231. """
  232. Organize filtration result.
  233. Returns:
  234. dict, the filtration result.
  235. """
  236. result = {}
  237. for key, field in FIELD_MAPPING.items():
  238. if field.base_name is not None:
  239. base_attr = getattr(self, field.base_name)
  240. result[key] = base_attr.get(field.sub_name) \
  241. if field.sub_name else base_attr
  242. # add metric into filtration result
  243. result[self._name_metric] = self.metric
  244. result[self._name_user_defined] = self.user_defined
  245. # add dataset_graph into filtration result
  246. result[self._name_dataset_graph] = getattr(self, self._name_dataset_graph)
  247. return result
  248. def _parse_train_lineage(self, train_lineage):
  249. """
  250. Parse train lineage.
  251. Args:
  252. train_lineage (Event): Train lineage.
  253. """
  254. if train_lineage is None:
  255. self._lineage_info[self._name_model] = {}
  256. self._lineage_info[self._name_algorithm] = {}
  257. self._lineage_info[self._name_hyper_parameters] = {}
  258. self._lineage_info[self._name_train_dataset] = {}
  259. return
  260. event_dict = MessageToDict(
  261. train_lineage, preserving_proto_field_name=True
  262. )
  263. train_dict = event_dict.get(self._name_train_lineage)
  264. if train_dict is None:
  265. raise LineageEventFieldNotExistException(
  266. self._name_train_lineage
  267. )
  268. # when MessageToDict is converted to dict, int64 type is converted
  269. # to string, so we convert it to an int in python
  270. if train_dict.get(self._name_model):
  271. model_size = train_dict.get(self._name_model).get('size')
  272. if model_size:
  273. train_dict[self._name_model]['size'] = int(model_size)
  274. self._lineage_info.update(**train_dict)
  275. def _parse_evaluation_lineage(self, evaluation_lineage):
  276. """
  277. Parse evaluation lineage.
  278. Args:
  279. evaluation_lineage (Event): Evaluation lineage.
  280. """
  281. if evaluation_lineage is None:
  282. self._lineage_info[self._name_metric] = {}
  283. self._lineage_info[self._name_valid_dataset] = {}
  284. return
  285. event_dict = MessageToDict(
  286. evaluation_lineage, preserving_proto_field_name=True
  287. )
  288. evaluation_dict = event_dict.get(self._name_evaluation_lineage)
  289. if evaluation_dict is None:
  290. raise LineageEventFieldNotExistException(
  291. self._name_evaluation_lineage
  292. )
  293. self._lineage_info.update(**evaluation_dict)
  294. metric = self._lineage_info.get(self._name_metric)
  295. self._lineage_info[self._name_metric] = json.loads(metric) if metric else {}
  296. def _parse_dataset_graph(self, dataset_graph):
  297. """
  298. Parse dataset graph.
  299. Args:
  300. dataset_graph (Event): Dataset graph.
  301. """
  302. if dataset_graph is None:
  303. self._lineage_info[self._name_dataset_graph] = {}
  304. else:
  305. # convert message to dict
  306. event_dict = organize_graph(dataset_graph.dataset_graph)
  307. if event_dict is None:
  308. raise LineageEventFieldNotExistException(self._name_evaluation_lineage)
  309. self._lineage_info[self._name_dataset_graph] = event_dict if event_dict else {}
  310. def _parse_user_defined_info(self, user_defined_info_list):
  311. """
  312. Parse user defined info.
  313. Args:
  314. user_defined_info_list (list): user defined info list.
  315. """
  316. user_defined_infos = dict()
  317. for user_defined_info in user_defined_info_list:
  318. user_defined_infos.update(user_defined_info)
  319. self._lineage_info[self._name_user_defined] = user_defined_infos

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中,可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中,通过MindInsight可视化页面进行查看及分析。