From cd868aea52d6bd984d53266f2fb61eaae4a9fb9e Mon Sep 17 00:00:00 2001 From: ougongchang Date: Tue, 30 Jun 2020 19:44:28 +0800 Subject: [PATCH] fix get loss error and NoneType error caused by _process_specified_data fix get loss error when it is not a scalar and fix process specified data failed when the action is False, and the collect_specified_data parameter is not None --- .../train/callback/_summary_collector.py | 15 +++++---- .../train/summary/test_summary_collector.py | 32 ++++++++++++++++--- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index a94eb2fc55..c77a41fedc 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -239,7 +239,8 @@ class SummaryCollector(Callback): unexpected_params = set(specified_data) - set(self._DEFAULT_SPECIFIED_DATA) if unexpected_params: - raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported.') + raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported, ' + f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}') if 'histogram_regular' in specified_data: check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None))) @@ -250,7 +251,8 @@ class SummaryCollector(Callback): check_value_type(item, specified_data.get(item), bool) if action: - result = dict(self._DEFAULT_SPECIFIED_DATA).update(specified_data) + result = dict(self._DEFAULT_SPECIFIED_DATA) + result.update(specified_data) else: result = specified_data return result @@ -444,15 +446,12 @@ class SummaryCollector(Callback): self._is_parse_loss_success = False return None - if isinstance(output, (int, float)): + if isinstance(output, (int, float, Tensor)): loss = output - elif isinstance(output, (list, tuple)): + elif isinstance(output, (list, tuple)) and output: # If the output is a list, 
since the default network returns loss first, # we assume that the first one is loss. loss = output[0] - elif isinstance(output, Tensor) and (not output.shape or output.shape == (1,)): - loss_numpy = output.asnumpy() - loss = float(np.atleast_1d(loss_numpy)[0]) else: logger.warning("The output type could not be identified, so no loss was recorded in SummaryCollector.") self._is_parse_loss_success = False @@ -461,6 +460,8 @@ class SummaryCollector(Callback): if not isinstance(loss, Tensor): loss = Tensor(loss) + precision = 4 + loss = Tensor(round(np.mean(loss.asnumpy()), precision)) return loss def _get_optimizer(self, cb_params): diff --git a/tests/ut/python/train/summary/test_summary_collector.py b/tests/ut/python/train/summary/test_summary_collector.py index 1390d29bc1..31552e44bd 100644 --- a/tests/ut/python/train/summary/test_summary_collector.py +++ b/tests/ut/python/train/summary/test_summary_collector.py @@ -50,6 +50,9 @@ def get_value(): _VALUE_CACHE = list() return value +_SPECIFIED_DATA = SummaryCollector._DEFAULT_SPECIFIED_DATA +_SPECIFIED_DATA['collect_metric'] = False + class CustomNet(Cell): """Define custom netwrok.""" @@ -190,8 +193,8 @@ class TestSummaryCollector: data = {'unexpected_key': True} with pytest.raises(ValueError) as exc: SummaryCollector(summary_dir, collect_specified_data=data) - expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported." 
- assert expected_msg == str(exc.value) + expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported" + assert expected_msg in str(exc.value) @pytest.mark.parametrize("custom_lineage_data", [ 123, @@ -273,12 +276,16 @@ class TestSummaryCollector: assert name == 'train_dataset' @pytest.mark.parametrize("net_output, expected_loss", [ + (None, None), (1, Tensor(1)), + (1.5, Tensor(1.5)), + (Tensor(1), Tensor(1)), ([1], Tensor(1)), ([Tensor(1)], Tensor(1)), - (Tensor([1]), Tensor(1)), + ({}, None), + (Tensor([[1, 2], [3, 4]]), Tensor(2.5)), + ([Tensor([[3, 4, 3]]), Tensor([3, 4])], Tensor(3.33333)), (tuple([1]), Tensor(1)), - (None, None) ]) def test_get_loss(self, net_output, expected_loss): """Test get loss success and failed.""" @@ -375,3 +382,20 @@ class TestSummaryCollector: assert PluginEnum.HISTOGRAM.value == result[0][0] assert expected_names == [data[1] for data in result] assert expected_values == [data[2] for data in result] + + @pytest.mark.parametrize("specified_data, action, expected_result", [ + (None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA), + (None, False, {}), + ({}, True, SummaryCollector._DEFAULT_SPECIFIED_DATA), + ({}, False, {}), + ({'collect_metric': False}, True, _SPECIFIED_DATA), + ({'collect_metric': True}, False, {'collect_metric': True}) + ]) + def test_process_specified_data(self, specified_data, action, expected_result): + """Test process specified data.""" + summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) + summary_collector = SummaryCollector(summary_dir, + collect_specified_data=specified_data, + keep_default_action=action) + + assert summary_collector._collect_specified_data == expected_result