Browse Source

!713 add UT for calc_hyper_param_importance function

Merge pull request !713 from shenghong96/calc_hyper_importance
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
872da81ad5
1 changed files with 26 additions and 0 deletions
  1. +26
    -0
      tests/ut/optimizer/test_utils.py

+ 26
- 0
tests/ut/optimizer/test_utils.py View File

@@ -14,8 +14,11 @@
# ============================================================================ # ============================================================================
"""Test the model module.""" """Test the model module."""
import numpy as np import numpy as np
import pandas as pd
import pytest import pytest


from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
from mindinsight.optimizer.utils.importances import calc_hyper_param_importance
from mindinsight.optimizer.utils.utils import is_simple_numpy_number, calc_histogram from mindinsight.optimizer.utils.utils import is_simple_numpy_number, calc_histogram




@@ -33,3 +36,26 @@ def test_calc_histogram():
assert output[0][1] == pytest.approx(0.6, 1e-6) assert output[0][1] == pytest.approx(0.6, 1e-6)
assert output[1][1] == pytest.approx(0.6, 1e-6) assert output[1][1] == pytest.approx(0.6, 1e-6)
assert output[0][2] == pytest.approx(2.0, 1e-6) assert output[0][2] == pytest.approx(2.0, 1e-6)


def test_calc_hyper_param_importance_exception_1():
"""Test calc_hyper_param_importance function when number of samples is less or equal than 2"""
flattened_lineage = {'epoch': [10, 10], 'accuracy': [32, 32]}
with pytest.raises(SamplesNotEnoughError) as info:
calc_hyper_param_importance(pd.DataFrame(flattened_lineage), 'epoch', 'accuracy')
assert "Number of samples is less or equal than 2." in str(info.value)


def test_calc_hyper_param_importance_exception_2():
"""Test calc_hyper_param_importance function when correlation equals to NaN"""
flattened_lineage = {'epoch': [10, 10, 10], 'accuracy': [0.6432, 0.6281, 0.6692]}
with pytest.raises(CorrelationNanError) as info:
calc_hyper_param_importance(pd.DataFrame(flattened_lineage), 'epoch', 'accuracy')
assert "Correlation is nan!" in str(info.value)


def test_calc_hyper_param_importance():
"""Test calc_hyper_param_importance function"""
flattened_lineage = {'epoch': [10, 20, 30], 'accuracy': [30, 40, 50]}
result = calc_hyper_param_importance(pd.DataFrame(flattened_lineage), 'epoch', 'accuracy')
assert result == pytest.approx(1.0, 1e-6)

Loading…
Cancel
Save