Browse Source

!713 add UT for calc_hyper_param_importance function

Merge pull request !713 from shenghong96/calc_hyper_importance
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
872da81ad5
1 changed files with 26 additions and 0 deletions
  1. +26
    -0
      tests/ut/optimizer/test_utils.py

+ 26
- 0
tests/ut/optimizer/test_utils.py View File

@@ -14,8 +14,11 @@
# ============================================================================
"""Test the model module."""
import numpy as np
import pandas as pd
import pytest

from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
from mindinsight.optimizer.utils.importances import calc_hyper_param_importance
from mindinsight.optimizer.utils.utils import is_simple_numpy_number, calc_histogram


@@ -33,3 +36,26 @@ def test_calc_histogram():
assert output[0][1] == pytest.approx(0.6, 1e-6)
assert output[1][1] == pytest.approx(0.6, 1e-6)
assert output[0][2] == pytest.approx(2.0, 1e-6)


def test_calc_hyper_param_importance_exception_1():
"""Test calc_hyper_param_importance function when number of samples is less or equal than 2"""
flattened_lineage = {'epoch': [10, 10], 'accuracy': [32, 32]}
with pytest.raises(SamplesNotEnoughError) as info:
calc_hyper_param_importance(pd.DataFrame(flattened_lineage), 'epoch', 'accuracy')
assert "Number of samples is less or equal than 2." in str(info.value)


def test_calc_hyper_param_importance_exception_2():
"""Test calc_hyper_param_importance function when correlation equals to NaN"""
flattened_lineage = {'epoch': [10, 10, 10], 'accuracy': [0.6432, 0.6281, 0.6692]}
with pytest.raises(CorrelationNanError) as info:
calc_hyper_param_importance(pd.DataFrame(flattened_lineage), 'epoch', 'accuracy')
assert "Correlation is nan!" in str(info.value)


def test_calc_hyper_param_importance():
"""Test calc_hyper_param_importance function"""
flattened_lineage = {'epoch': [10, 20, 30], 'accuracy': [30, 40, 50]}
result = calc_hyper_param_importance(pd.DataFrame(flattened_lineage), 'epoch', 'accuracy')
assert result == pytest.approx(1.0, 1e-6)

Loading…
Cancel
Save