|
|
|
@@ -31,6 +31,8 @@ from mindspore.train.callback import RunContext, ModelCheckpoint, SummaryStep |
|
|
|
from mindspore.train.summary import SummaryRecord |
|
|
|
|
|
|
|
|
|
|
|
@mock.patch('builtins.open') |
|
|
|
@mock.patch('os.makedirs') |
|
|
|
class TestModelLineage(TestCase): |
|
|
|
"""Test TrainLineage and EvalLineage class in model_lineage.py.""" |
|
|
|
|
|
|
|
@@ -51,9 +53,9 @@ class TestModelLineage(TestCase): |
|
|
|
cls.summary_log_path = '/path/to/summary_log' |
|
|
|
|
|
|
|
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') |
|
|
|
def test_summary_record_exception(self, mock_validate_summary): |
|
|
|
def test_summary_record_exception(self, *args): |
|
|
|
"""Test SummaryRecord with exception.""" |
|
|
|
mock_validate_summary.return_value = None |
|
|
|
args[0].return_value = None |
|
|
|
summary_record = self.my_summary_record(self.summary_log_path) |
|
|
|
with self.assertRaises(MindInsightException) as context: |
|
|
|
self.my_train_module(summary_record=summary_record, raise_exception=1) |
|
|
|
@@ -150,9 +152,9 @@ class TestModelLineage(TestCase): |
|
|
|
args[6].assert_called() |
|
|
|
|
|
|
|
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') |
|
|
|
def test_train_end_exception(self, mock_validate_summary): |
|
|
|
def test_train_end_exception(self, *args): |
|
|
|
"""Test TrainLineage.end method when exception.""" |
|
|
|
mock_validate_summary.return_value = True |
|
|
|
args[0].return_value = True |
|
|
|
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True) |
|
|
|
with self.assertRaises(Exception) as context: |
|
|
|
train_lineage.end(self.run_context) |
|
|
|
@@ -218,9 +220,9 @@ class TestModelLineage(TestCase): |
|
|
|
self.assertTrue('End error in TrainLineage:' in str(context.exception)) |
|
|
|
|
|
|
|
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') |
|
|
|
def test_eval_exception_train_id_none(self, mock_validate_summary): |
|
|
|
def test_eval_exception_train_id_none(self, *args): |
|
|
|
"""Test EvalLineage.end method with initialization error.""" |
|
|
|
mock_validate_summary.return_value = True |
|
|
|
args[0].return_value = True |
|
|
|
with self.assertRaises(MindInsightException) as context: |
|
|
|
self.my_eval_module(self.my_summary_record(self.summary_log_path), raise_exception=2) |
|
|
|
self.assertTrue('Invalid value for raise_exception.' in str(context.exception)) |
|
|
|
@@ -242,9 +244,9 @@ class TestModelLineage(TestCase): |
|
|
|
args[0].assert_called() |
|
|
|
|
|
|
|
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') |
|
|
|
def test_eval_end_except_run_context(self, mock_validate_summary): |
|
|
|
def test_eval_end_except_run_context(self, *args): |
|
|
|
"""Test EvalLineage.end method when run_context is invalid..""" |
|
|
|
mock_validate_summary.return_value = True |
|
|
|
args[0].return_value = True |
|
|
|
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True) |
|
|
|
with self.assertRaises(Exception) as context: |
|
|
|
eval_lineage.end(self.run_context) |
|
|
|
@@ -284,8 +286,9 @@ class TestModelLineage(TestCase): |
|
|
|
eval_lineage.end(self.my_run_context(self.run_context)) |
|
|
|
self.assertTrue('End error in EvalLineage' in str(context.exception)) |
|
|
|
|
|
|
|
def test_epoch_is_zero(self): |
|
|
|
def test_epoch_is_zero(self, *args): |
|
|
|
"""Test TrainLineage.end method.""" |
|
|
|
args[0].return_value = None |
|
|
|
run_context = self.run_context |
|
|
|
run_context['epoch_num'] = 0 |
|
|
|
with self.assertRaises(MindInsightException): |
|
|
|
@@ -345,7 +348,7 @@ class TestAnalyzer(TestCase): |
|
|
|
) |
|
|
|
res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train') |
|
|
|
res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid') |
|
|
|
assert res1 == {'step_num': 10, |
|
|
|
assert res1 == {'step_num': 10, |
|
|
|
'train_dataset_path': '/path/to', |
|
|
|
'train_dataset_size': 50, |
|
|
|
'epoch': 2} |
|
|
|
|