You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_query_model.py 6.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test the query_model module."""
  16. from unittest import TestCase
  17. from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageEventFieldNotExistException,
  18. LineageEventNotExistException)
  19. from mindinsight.lineagemgr.querier.query_model import LineageObj
  20. from . import event_data
  21. from .test_querier import create_filtration_result, create_lineage_info
  22. class TestLineageObj(TestCase):
  23. """Test the class of `LineageObj`."""
  24. def setUp(self):
  25. """Initialization before test case execution."""
  26. lineage_info = create_lineage_info(
  27. event_data.EVENT_TRAIN_DICT_0,
  28. event_data.EVENT_EVAL_DICT_0,
  29. event_data.EVENT_DATASET_DICT_0
  30. )
  31. self.summary_dir = '/path/to/summary0'
  32. self.lineage_obj = LineageObj(
  33. self.summary_dir,
  34. train_lineage=lineage_info.train_lineage,
  35. evaluation_lineage=lineage_info.eval_lineage,
  36. dataset_graph=lineage_info.dataset_graph,
  37. )
  38. lineage_info = create_lineage_info(
  39. event_data.EVENT_TRAIN_DICT_0,
  40. None, None)
  41. self.lineage_obj_no_eval = LineageObj(
  42. self.summary_dir,
  43. train_lineage=lineage_info.train_lineage,
  44. evaluation_lineage=lineage_info.eval_lineage
  45. )
  46. def test_property(self):
  47. """Test the function of getting property."""
  48. self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
  49. self.assertDictEqual(
  50. event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
  51. self.lineage_obj.algorithm
  52. )
  53. self.assertDictEqual(
  54. event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
  55. self.lineage_obj.model
  56. )
  57. self.assertDictEqual(
  58. event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
  59. self.lineage_obj.train_dataset
  60. )
  61. self.assertDictEqual(
  62. event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
  63. self.lineage_obj.hyper_parameters
  64. )
  65. self.assertDictEqual(event_data.METRIC_0, self.lineage_obj.metric)
  66. self.assertDictEqual(
  67. event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
  68. self.lineage_obj.valid_dataset
  69. )
  70. def test_property_eval_not_exist(self):
  71. """Test the function of getting property with no evaluation event."""
  72. self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
  73. self.assertDictEqual(
  74. event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
  75. self.lineage_obj_no_eval.algorithm
  76. )
  77. self.assertDictEqual(
  78. event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
  79. self.lineage_obj_no_eval.model
  80. )
  81. self.assertDictEqual(
  82. event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
  83. self.lineage_obj_no_eval.train_dataset
  84. )
  85. self.assertDictEqual(
  86. event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
  87. self.lineage_obj_no_eval.hyper_parameters
  88. )
  89. self.assertDictEqual({}, self.lineage_obj_no_eval.metric)
  90. self.assertDictEqual({}, self.lineage_obj_no_eval.valid_dataset)
  91. def test_get_summary_info(self):
  92. """Test the function of get_summary_info."""
  93. filter_keys = ['algorithm', 'model']
  94. expected_result = {
  95. 'summary_dir': self.summary_dir,
  96. 'algorithm': event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
  97. 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']
  98. }
  99. result = self.lineage_obj.get_summary_info(filter_keys)
  100. self.assertDictEqual(expected_result, result)
  101. def test_to_model_lineage_dict(self):
  102. """Test the function of to_model_lineage_dict."""
  103. expected_result = create_filtration_result(
  104. self.summary_dir,
  105. event_data.EVENT_TRAIN_DICT_0,
  106. event_data.EVENT_EVAL_DICT_0,
  107. event_data.METRIC_0,
  108. event_data.DATASET_DICT_0
  109. )
  110. expected_result['model_lineage']['dataset_mark'] = None
  111. expected_result.pop('dataset_graph')
  112. result = self.lineage_obj.to_model_lineage_dict()
  113. self.assertDictEqual(expected_result, result)
  114. def test_to_dataset_lineage_dict(self):
  115. """Test the function of to_dataset_lineage_dict."""
  116. expected_result = {
  117. "summary_dir": self.summary_dir,
  118. "dataset_graph": event_data.DATASET_DICT_0
  119. }
  120. result = self.lineage_obj.to_dataset_lineage_dict()
  121. self.assertDictEqual(expected_result, result)
  122. def test_get_value_by_key(self):
  123. """Test the function of get_value_by_key."""
  124. result = self.lineage_obj.get_value_by_key('model_size')
  125. self.assertEqual(
  126. event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']['size'],
  127. result
  128. )
  129. def test_init_fail(self):
  130. """Test the function of init with exception."""
  131. with self.assertRaises(LineageEventNotExistException):
  132. LineageObj(self.summary_dir)
  133. lineage_info = create_lineage_info(
  134. event_data.EVENT_TRAIN_DICT_EXCEPTION, None, None
  135. )
  136. with self.assertRaises(LineageEventFieldNotExistException):
  137. self.lineage_obj = LineageObj(
  138. self.summary_dir,
  139. train_lineage=lineage_info.train_lineage,
  140. evaluation_lineage=lineage_info.eval_lineage
  141. )
  142. lineage_info = create_lineage_info(
  143. event_data.EVENT_TRAIN_DICT_0,
  144. event_data.EVENT_EVAL_DICT_EXCEPTION,
  145. event_data.EVENT_DATASET_DICT_0
  146. )
  147. with self.assertRaises(LineageEventFieldNotExistException):
  148. self.lineage_obj = LineageObj(
  149. self.summary_dir,
  150. train_lineage=lineage_info.train_lineage,
  151. evaluation_lineage=lineage_info.eval_lineage
  152. )

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。