|
|
|
@@ -50,6 +50,9 @@ def get_value(): |
|
|
|
_VALUE_CACHE = list() |
|
|
|
return value |
|
|
|
|
|
|
|
_SPECIFIED_DATA = SummaryCollector._DEFAULT_SPECIFIED_DATA |
|
|
|
_SPECIFIED_DATA['collect_metric'] = False |
|
|
|
|
|
|
|
|
|
|
|
class CustomNet(Cell): |
|
|
|
"""Define custom netwrok.""" |
|
|
|
@@ -190,8 +193,8 @@ class TestSummaryCollector: |
|
|
|
data = {'unexpected_key': True} |
|
|
|
with pytest.raises(ValueError) as exc: |
|
|
|
SummaryCollector(summary_dir, collect_specified_data=data) |
|
|
|
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported." |
|
|
|
assert expected_msg == str(exc.value) |
|
|
|
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported" |
|
|
|
assert expected_msg in str(exc.value) |
|
|
|
|
|
|
|
@pytest.mark.parametrize("custom_lineage_data", [ |
|
|
|
123, |
|
|
|
@@ -273,12 +276,16 @@ class TestSummaryCollector: |
|
|
|
assert name == 'train_dataset' |
|
|
|
|
|
|
|
@pytest.mark.parametrize("net_output, expected_loss", [ |
|
|
|
(None, None), |
|
|
|
(1, Tensor(1)), |
|
|
|
(1.5, Tensor(1.5)), |
|
|
|
(Tensor(1), Tensor(1)), |
|
|
|
([1], Tensor(1)), |
|
|
|
([Tensor(1)], Tensor(1)), |
|
|
|
(Tensor([1]), Tensor(1)), |
|
|
|
({}, None), |
|
|
|
(Tensor([[1, 2], [3, 4]]), Tensor(2.5)), |
|
|
|
([Tensor([[3, 4, 3]]), Tensor([3, 4])], Tensor(3.33333)), |
|
|
|
(tuple([1]), Tensor(1)), |
|
|
|
(None, None) |
|
|
|
]) |
|
|
|
def test_get_loss(self, net_output, expected_loss): |
|
|
|
"""Test get loss success and failed.""" |
|
|
|
@@ -375,3 +382,20 @@ class TestSummaryCollector: |
|
|
|
assert PluginEnum.HISTOGRAM.value == result[0][0] |
|
|
|
assert expected_names == [data[1] for data in result] |
|
|
|
assert expected_values == [data[2] for data in result] |
|
|
|
|
|
|
|
@pytest.mark.parametrize("specified_data, action, expected_result", [ |
|
|
|
(None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA), |
|
|
|
(None, False, {}), |
|
|
|
({}, True, SummaryCollector._DEFAULT_SPECIFIED_DATA), |
|
|
|
({}, False, {}), |
|
|
|
({'collect_metric': False}, True, _SPECIFIED_DATA), |
|
|
|
({'collect_metric': True}, False, {'collect_metric': True}) |
|
|
|
]) |
|
|
|
def test_process_specified_data(self, specified_data, action, expected_result): |
|
|
|
"""Test process specified data.""" |
|
|
|
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) |
|
|
|
summary_collector = SummaryCollector(summary_dir, |
|
|
|
collect_specified_data=specified_data, |
|
|
|
keep_default_action=action) |
|
|
|
|
|
|
|
assert summary_collector._collect_specified_data == expected_result |