You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

query_model.py 13 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define lineage info model."""
  16. import json
  17. from collections import namedtuple
  18. from google.protobuf.json_format import MessageToDict
  19. from mindinsight.lineagemgr.common.exceptions.exceptions import \
  20. LineageEventFieldNotExistException, LineageEventNotExistException
  21. from mindinsight.lineagemgr.summary._summary_adapter import organize_graph
  22. Field = namedtuple('Field', ['base_name', 'sub_name'])
  23. FIELD_MAPPING = {
  24. "summary_dir": Field('summary_dir', None),
  25. "loss_function": Field("hyper_parameters", 'loss_function'),
  26. "train_dataset_path": Field('train_dataset', 'train_dataset_path'),
  27. "train_dataset_count": Field("train_dataset", 'train_dataset_size'),
  28. "test_dataset_path": Field('valid_dataset', 'valid_dataset_path'),
  29. "test_dataset_count": Field('valid_dataset', 'valid_dataset_size'),
  30. "network": Field('algorithm', 'network'),
  31. "optimizer": Field('hyper_parameters', 'optimizer'),
  32. "learning_rate": Field('hyper_parameters', 'learning_rate'),
  33. "epoch": Field('hyper_parameters', 'epoch'),
  34. "batch_size": Field('hyper_parameters', 'batch_size'),
  35. "device_num": Field('hyper_parameters', 'device_num'),
  36. "loss": Field('algorithm', 'loss'),
  37. "model_size": Field('model', 'size'),
  38. "dataset_mark": Field('dataset_mark', None)
  39. }
  40. class LineageObj:
  41. """
  42. Lineage information class.
  43. An instance of the class hold lineage information for a training session.
  44. Args:
  45. summary_dir (str): Summary log dir.
  46. kwargs (dict): Params to init the instance.
  47. - train_lineage (Event): Train lineage object.
  48. - evaluation_lineage (Event): Evaluation lineage object.
  49. - dataset_graph (Event): Dataset graph object.
  50. - user_defined_info (Event): User defined info object.
  51. Raises:
  52. LineageEventNotExistException: If train and evaluation event not exist.
  53. LineageEventFieldNotExistException: If the special event field not exist.
  54. """
  55. _name_train_lineage = 'train_lineage'
  56. _name_evaluation_lineage = 'evaluation_lineage'
  57. _name_summary_dir = 'summary_dir'
  58. _name_metric = 'metric'
  59. _name_hyper_parameters = 'hyper_parameters'
  60. _name_algorithm = 'algorithm'
  61. _name_train_dataset = 'train_dataset'
  62. _name_model = 'model'
  63. _name_valid_dataset = 'valid_dataset'
  64. _name_dataset_graph = 'dataset_graph'
  65. _name_dataset_mark = 'dataset_mark'
  66. _name_user_defined = 'user_defined'
  67. _name_model_lineage = 'model_lineage'
  68. def __init__(self, summary_dir, **kwargs):
  69. self._lineage_info = {
  70. self._name_summary_dir: summary_dir
  71. }
  72. self._filtration_result = None
  73. self._init_lineage()
  74. self.parse_and_update_lineage(**kwargs)
  75. def _init_lineage(self):
  76. """Init lineage info."""
  77. # train
  78. self._lineage_info[self._name_model] = {}
  79. self._lineage_info[self._name_algorithm] = {}
  80. self._lineage_info[self._name_hyper_parameters] = {}
  81. self._lineage_info[self._name_train_dataset] = {}
  82. # eval
  83. self._lineage_info[self._name_metric] = {}
  84. self._lineage_info[self._name_valid_dataset] = {}
  85. # dataset graph
  86. self._lineage_info[self._name_dataset_graph] = {}
  87. # user defined
  88. self._lineage_info[self._name_user_defined] = {}
  89. def parse_and_update_lineage(self, **kwargs):
  90. """Parse and update lineage."""
  91. user_defined_info_list = kwargs.get('user_defined_info', [])
  92. train_lineage = kwargs.get('train_lineage')
  93. evaluation_lineage = kwargs.get('evaluation_lineage')
  94. dataset_graph = kwargs.get('dataset_graph')
  95. if not any([train_lineage, evaluation_lineage, dataset_graph]):
  96. raise LineageEventNotExistException()
  97. # If new train lineage, will clean the lineage saved before.
  98. if train_lineage is not None or dataset_graph is not None:
  99. self._init_lineage()
  100. self._parse_user_defined_info(user_defined_info_list)
  101. self._parse_train_lineage(train_lineage)
  102. self._parse_evaluation_lineage(evaluation_lineage)
  103. self._parse_dataset_graph(dataset_graph)
  104. self._filtration_result = self._organize_filtration_result()
  105. @property
  106. def summary_dir(self):
  107. """
  108. Get summary log dir.
  109. Returns:
  110. str, the summary log dir.
  111. """
  112. return self._lineage_info.get(self._name_summary_dir)
  113. @property
  114. def metric(self):
  115. """
  116. Get metric information.
  117. Returns:
  118. dict, the metric information.
  119. """
  120. return self._lineage_info.get(self._name_metric)
  121. @property
  122. def user_defined(self):
  123. """
  124. Get user defined information.
  125. Returns:
  126. dict, the user defined information.
  127. """
  128. return self._lineage_info.get(self._name_user_defined)
  129. @property
  130. def hyper_parameters(self):
  131. """
  132. Get hyperparameters.
  133. Returns:
  134. dict, the hyperparameters.
  135. """
  136. return self._lineage_info.get(self._name_hyper_parameters)
  137. @property
  138. def algorithm(self):
  139. """
  140. Get algorithm.
  141. Returns:
  142. dict, the algorithm.
  143. """
  144. return self._lineage_info.get(self._name_algorithm)
  145. @property
  146. def train_dataset(self):
  147. """
  148. Get train dataset information.
  149. Returns:
  150. dict, the train dataset information.
  151. """
  152. return self._lineage_info.get(self._name_train_dataset)
  153. @property
  154. def model(self):
  155. """
  156. Get model information.
  157. Returns:
  158. dict, the model information.
  159. """
  160. return self._lineage_info.get(self._name_model)
  161. @property
  162. def valid_dataset(self):
  163. """
  164. Get valid dataset information.
  165. Returns:
  166. dict, the valid dataset information.
  167. """
  168. return self._lineage_info.get(self._name_valid_dataset)
  169. @property
  170. def dataset_graph(self):
  171. """
  172. Get dataset_graph.
  173. Returns:
  174. dict, the dataset graph information.
  175. """
  176. return self._lineage_info.get(self._name_dataset_graph)
  177. @property
  178. def dataset_mark(self):
  179. """
  180. Get dataset_mark.
  181. Returns:
  182. dict, the dataset mark information.
  183. """
  184. return self._lineage_info.get(self._name_dataset_mark)
  185. @dataset_mark.setter
  186. def dataset_mark(self, dataset_mark):
  187. """
  188. Set dataset mark.
  189. Args:
  190. dataset_mark (int): Dataset mark.
  191. """
  192. self._lineage_info[self._name_dataset_mark] = dataset_mark
  193. # update dataset_mark into filtration result
  194. self._filtration_result[self._name_dataset_mark] = dataset_mark
  195. def get_summary_info(self, filter_keys: list):
  196. """
  197. Get the summary lineage information.
  198. Returns the content corresponding to the specified field in the filter
  199. key. The contents of the filter key include `metric`, `hyper_parameters`,
  200. `algorithm`, `train_dataset`, `valid_dataset` and `model`. You can
  201. specify multiple filter keys in the `filter_keys`
  202. Args:
  203. filter_keys (list): Filter keys.
  204. Returns:
  205. dict, the summary lineage information.
  206. """
  207. result = {
  208. self._name_summary_dir: self.summary_dir,
  209. }
  210. for key in filter_keys:
  211. result[key] = getattr(self, key)
  212. return result
  213. def to_dataset_lineage_dict(self):
  214. """
  215. Returns the dataset part lineage information.
  216. Returns:
  217. dict, the dataset lineage information.
  218. """
  219. dataset_lineage = {
  220. key: self._filtration_result.get(key)
  221. for key in [self._name_summary_dir, self._name_dataset_graph]
  222. }
  223. return dataset_lineage
  224. def to_model_lineage_dict(self):
  225. """
  226. Returns the model part lineage information.
  227. Returns:
  228. dict, the model lineage information.
  229. """
  230. filtration_result = dict(self._filtration_result)
  231. filtration_result.pop(self._name_dataset_graph)
  232. model_lineage = dict()
  233. model_lineage.update({self._name_summary_dir: filtration_result.pop(self._name_summary_dir)})
  234. model_lineage.update({self._name_model_lineage: filtration_result})
  235. return model_lineage
  236. def get_value_by_key(self, key):
  237. """
  238. Get the value based on the key in `FIELD_MAPPING` or
  239. the key prefixed with `metric/` or `user_defined/`.
  240. Args:
  241. key (str): The key in `FIELD_MAPPING`
  242. or prefixed with `metric/` or `user_defined/`.
  243. Returns:
  244. object, the value.
  245. """
  246. if key.startswith(('metric/', 'user_defined/')):
  247. key_name, sub_key = key.split('/', 1)
  248. sub_value_name = self._name_metric if key_name == 'metric' else self._name_user_defined
  249. sub_value = self._filtration_result.get(sub_value_name)
  250. if sub_value:
  251. return sub_value.get(sub_key)
  252. return self._filtration_result.get(key)
  253. def _organize_filtration_result(self):
  254. """
  255. Organize filtration result.
  256. Returns:
  257. dict, the filtration result.
  258. """
  259. result = {}
  260. for key, field in FIELD_MAPPING.items():
  261. if field.base_name is not None:
  262. base_attr = getattr(self, field.base_name)
  263. result[key] = base_attr.get(field.sub_name) \
  264. if field.sub_name else base_attr
  265. # add metric into filtration result
  266. result[self._name_metric] = self.metric
  267. result[self._name_user_defined] = self.user_defined
  268. # add dataset_graph into filtration result
  269. result[self._name_dataset_graph] = getattr(self, self._name_dataset_graph)
  270. return result
  271. def _parse_train_lineage(self, train_lineage):
  272. """
  273. Parse train lineage.
  274. Args:
  275. train_lineage (Event): Train lineage.
  276. """
  277. if train_lineage is None:
  278. return
  279. event_dict = MessageToDict(
  280. train_lineage, preserving_proto_field_name=True
  281. )
  282. train_dict = event_dict.get(self._name_train_lineage)
  283. if train_dict is None:
  284. raise LineageEventFieldNotExistException(
  285. self._name_train_lineage
  286. )
  287. # when MessageToDict is converted to dict, int64 type is converted
  288. # to string, so we convert it to an int in python
  289. if train_dict.get(self._name_model):
  290. model_size = train_dict.get(self._name_model).get('size')
  291. if model_size:
  292. train_dict[self._name_model]['size'] = int(model_size)
  293. self._lineage_info.update(**train_dict)
  294. def _parse_evaluation_lineage(self, evaluation_lineage):
  295. """
  296. Parse evaluation lineage.
  297. Args:
  298. evaluation_lineage (Event): Evaluation lineage.
  299. """
  300. if evaluation_lineage is None:
  301. return
  302. event_dict = MessageToDict(
  303. evaluation_lineage, preserving_proto_field_name=True
  304. )
  305. evaluation_dict = event_dict.get(self._name_evaluation_lineage)
  306. if evaluation_dict is None:
  307. raise LineageEventFieldNotExistException(
  308. self._name_evaluation_lineage
  309. )
  310. self._lineage_info.update(**evaluation_dict)
  311. metric = self._lineage_info.get(self._name_metric)
  312. self._lineage_info[self._name_metric] = json.loads(metric) if metric else {}
  313. def _parse_dataset_graph(self, dataset_graph):
  314. """
  315. Parse dataset graph.
  316. Args:
  317. dataset_graph (Event): Dataset graph.
  318. """
  319. if dataset_graph is not None:
  320. # convert message to dict
  321. event_dict = organize_graph(dataset_graph.dataset_graph)
  322. if event_dict is None:
  323. raise LineageEventFieldNotExistException(self._name_evaluation_lineage)
  324. self._lineage_info[self._name_dataset_graph] = event_dict if event_dict else {}
  325. def _parse_user_defined_info(self, user_defined_info_list):
  326. """
  327. Parse user defined info.
  328. Args:
  329. user_defined_info_list (list): user defined info list.
  330. """
  331. if not user_defined_info_list:
  332. return
  333. user_defined_infos = dict()
  334. for user_defined_info in user_defined_info_list:
  335. user_defined_infos.update(user_defined_info)
  336. self._lineage_info[self._name_user_defined] = user_defined_infos