You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_CollectiveBase.py 5.0 kB

first commit Former-commit-id: 08bc23ba02cffbce3cf63962390a65459a132e48 [formerly 0795edd4834b9b7dc66db8d10d4cbaf42bbf82cb] [formerly b5010b42541add7e2ea2578bf2da537efc457757 [formerly a7ca09c2c34c4fc8b3d8e01fcfa08eeeb2cae99d]] [formerly 615058473a2177ca5b89e9edbb797f4c2a59c7e5 [formerly 743d8dfc6843c4c205051a8ab309fbb2116c895e] [formerly bb0ea98b1e14154ef464e2f7a16738705894e54b [formerly 960a69da74b81ef8093820e003f2d6c59a34974c]]] [formerly 2fa3be52c1b44665bc81a7cc7d4cea4bbf0d91d5 [formerly 2054589f0898627e0a17132fd9d4cc78efc91867] [formerly 3b53730e8a895e803dfdd6ca72bc05e17a4164c1 [formerly 8a2fa8ab7baf6686d21af1f322df46fd58c60e69]] [formerly 87d1e3a07a19d03c7d7c94d93ab4fa9f58dada7c [formerly f331916385a5afac1234854ee8d7f160f34b668f] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18 [formerly 386086f05aa9487f65bce2ee54438acbdce57650]]]] Former-commit-id: a00aed8c934a6460c4d9ac902b9a74a3d6864697 [formerly 26fdeca29c2f07916d837883983ca2982056c78e] [formerly 0e3170d41a2f99ecf5c918183d361d4399d793bf [formerly 3c12ad4c88ac5192e0f5606ac0d88dd5bf8602dc]] [formerly d5894f84f2fd2e77a6913efdc5ae388cf1be0495 [formerly ad3e7bc670ff92c992730d29c9d3aa1598d844e8] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18]] Former-commit-id: 3c19c9fae64f6106415fbc948a4dc613b9ee12f8 [formerly 467ddc0549c74bb007e8f01773bb6dc9103b417d] [formerly 5fa518345d958e2760e443b366883295de6d991c [formerly 3530e130b9fdb7280f638dbc2e785d2165ba82aa]] Former-commit-id: 9f5d473d42a435ec0d60149939d09be1acc25d92 [formerly be0b25c4ec2cde052a041baf0e11f774a158105d] Former-commit-id: 9eca71cb73ba9edccd70ac06a3b636b8d4093b04
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. # -*- coding: utf-8 -*-
  2. from __future__ import division
  3. from __future__ import print_function
  4. import os
  5. import sys
  6. import unittest
  7. from sklearn.utils.testing import assert_equal
  8. from sklearn.utils.testing import assert_raises
  9. import numpy as np
  10. # temporary solution for relative imports in case pyod is not installed
  11. # if pyod is installed, no need to use the following line
  12. sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
  13. from detection_algorithm.core.CollectiveBase import CollectiveBaseDetector
  14. from pyod.utils.data import generate_data
  15. # Check sklearn\tests\test_base
  16. # A few test classes
  17. # noinspection PyMissingConstructor,PyPep8Naming
  class MyEstimator(CollectiveBaseDetector):
      """Minimal detector stub used to smoke-test ``repr`` and ``str``.

      The base-class constructor is deliberately not called; only the two
      constructor arguments are stored as attributes.
      """

      def __init__(self, l1=0, empty=None):
          self.l1, self.empty = l1, empty

      def fit(self, X, y=None):
          """No-op fit; returns None."""

      def decision_function(self, X):
          """No-op scorer; returns None."""
  26. # noinspection PyMissingConstructor
  class K(CollectiveBaseDetector):
      """Leaf estimator stub with two parameters, ``c`` and ``d``.

      Used as a nested parameter value inside ``T`` to exercise nested
      ``repr`` output and ``get_params``/``set_params`` with ``__`` keys.
      The base-class constructor is deliberately not called.
      """

      def __init__(self, c=None, d=None):
          self.c, self.d = c, d

      def fit(self, X, y=None):
          """No-op fit; returns None."""

      def decision_function(self, X):
          """No-op scorer; returns None."""
  35. # noinspection PyMissingConstructor
  class T(CollectiveBaseDetector):
      """Container estimator stub with two parameters, ``a`` and ``b``.

      The tests pass other estimators (or long lists) as ``a``/``b`` to check
      nested parameter handling. The base-class constructor is deliberately
      not called.
      """

      def __init__(self, a=None, b=None):
          self.a, self.b = a, b

      def fit(self, X, y=None):
          """No-op fit; returns None."""

      def decision_function(self, X):
          """No-op scorer; returns None."""
  44. # noinspection PyMissingConstructor
  class ModifyInitParams(CollectiveBaseDetector):
      """Fixture reproducing a deprecated estimator pattern.

      The stored parameter compares equal to the constructor argument but is
      a copy of it, so ``est.a is a`` does not hold.
      """

      def __init__(self, a=np.array([0])):
          # The shared default array is never mutated: only a copy is kept.
          copied = a.copy()
          self.a = copied

      def fit(self, X, y=None):
          """No-op fit; returns None."""

      def decision_function(self, X):
          """No-op scorer; returns None."""
  56. # noinspection PyMissingConstructor
  class VargEstimator(CollectiveBaseDetector):
      """Negative fixture: an estimator whose constructor takes ``*vargs``.

      scikit-learn-style estimators are expected to declare every parameter
      explicitly, so this class violates that convention on purpose.
      """

      def __init__(self, *vargs):
          """Accept and ignore any positional arguments."""

      def fit(self, X, y=None):
          """No-op fit; returns None."""

      def decision_function(self, X):
          """No-op scorer; returns None."""
  class Dummy1(CollectiveBaseDetector):
      """Stub that forwards ``contamination`` to the base constructor.

      Used by ``test_init`` to check how the base class stores and validates
      ``contamination``.
      """

      def __init__(self, contamination=0.1):
          super(Dummy1, self).__init__(contamination=contamination)

      def fit(self, X, y=None):
          """No-op fit; returns None."""

      def decision_function(self, X):
          """No-op scorer; returns None."""
  class Dummy2(CollectiveBaseDetector):
      """Stub whose ``fit`` echoes its input back (used by ``test_fit``)."""

      def __init__(self, contamination=0.1):
          super(Dummy2, self).__init__(contamination=contamination)

      def fit(self, X, y=None):
          """Return ``X`` unchanged."""
          return X

      def decision_function(self, X):
          """No-op scorer; returns None."""
  class Dummy3(CollectiveBaseDetector):
      """Stub whose ``fit`` stores its input as ``labels_``.

      Used by ``test_fit_predict``. Note that ``fit`` returns None rather
      than ``self``.
      """

      def __init__(self, contamination=0.1):
          super(Dummy3, self).__init__(contamination=contamination)

      def fit(self, X, y=None):
          """Record ``X`` verbatim as ``labels_``."""
          self.labels_ = X

      def decision_function(self, X):
          """No-op scorer; returns None."""
  class TestBASE(unittest.TestCase):
      """Unit tests for the shared behaviour of ``CollectiveBaseDetector``.

      Assertions use ``unittest.TestCase`` methods rather than the helpers in
      ``sklearn.utils.testing``, which scikit-learn deprecated in 0.22 and
      removed in 0.24.
      """

      def setUp(self):
          # Synthetic data from pyod's generate_data. None of the current tests
          # read these attributes; they are kept for the pending TODO cases.
          self.n_train = 100
          self.n_test = 50
          self.contamination = 0.1
          self.roc_floor = 0.6
          self.X_train, self.y_train, self.X_test, self.y_test = generate_data(
              n_train=self.n_train, n_test=self.n_test,
              contamination=self.contamination)

      def test_init(self):
          """The base constructor stores ``contamination`` and rejects bad values."""
          self.assertEqual(Dummy1().contamination, 0.1)
          self.assertEqual(Dummy1(contamination=0.2).contamination, 0.2)
          # Above 0.5, zero and negative values must all raise ValueError.
          with self.assertRaises(ValueError):
              Dummy1(contamination=0.51)
          with self.assertRaises(ValueError):
              Dummy1(contamination=0)
          with self.assertRaises(ValueError):
              Dummy1(contamination=-0.5)

      def test_fit(self):
          # Dummy2.fit returns its argument unchanged.
          self.assertEqual(Dummy2().fit(0), 0)

      def test_fit_predict(self):
          # TODO: add more testcases
          # Dummy3.fit stores X as labels_. fit_predict is expected to hand
          # that value back (assumed to be the base implementation's contract).
          self.assertEqual(Dummy3().fit_predict(0), 0)

      @unittest.skip("TODO: create uniform testcases")
      def test_predict_proba(self):
          # Skipped rather than silently passing until real cases exist.
          pass

      @unittest.skip("TODO: create uniform testcases")
      def test_rank(self):
          # Skipped rather than silently passing until real cases exist.
          pass

      def test_repr(self):
          # Smoke test: repr of a flat estimator must not raise.
          repr(MyEstimator())
          nested = T(K(), K())
          self.assertEqual(
              repr(nested),
              "T(a=K(c=None, d=None), b=K(c=None, d=None))"
          )
          # 1000 copies of 'long_params' would need far more than 415
          # characters, so this pins the base repr's truncation behaviour.
          long_est = T(a=["long_params"] * 1000)
          self.assertEqual(len(repr(long_est)), 415)

      def test_str(self):
          # Smoke test: str of the base estimator must not raise.
          str(MyEstimator())

      def test_get_params(self):
          est = T(K(), K())
          # Nested parameters appear only when deep=True.
          self.assertIn('a__d', est.get_params(deep=True))
          self.assertNotIn('a__d', est.get_params(deep=False))
          # "__" keys address parameters of nested estimators.
          est.set_params(a__d=2)
          self.assertEqual(est.a.d, 2)
          # K has no parameter "a", so this must be rejected.
          self.assertRaises(ValueError, est.set_params, a__a=2)
  if __name__ == "__main__":
      # Run every test case in this module when executed as a script.
      unittest.main()

全栈的自动化机器学习系统,主要针对多变量时间序列数据的异常检测。TODS提供了详尽的用于构建基于机器学习的异常检测系统的模块,它们包括:数据处理(data processing),时间序列处理(time series processing),特征分析(feature analysis),检测算法(detection algorithms),和强化模块(reinforcement module)。这些模块所提供的功能包括常见的数据预处理、时间序列数据的平滑或变换,从时域或频域中抽取特征、多种多样的检测算