You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_primitive_validation.py 39 kB

first commit Former-commit-id: 08bc23ba02cffbce3cf63962390a65459a132e48 [formerly 0795edd4834b9b7dc66db8d10d4cbaf42bbf82cb] [formerly b5010b42541add7e2ea2578bf2da537efc457757 [formerly a7ca09c2c34c4fc8b3d8e01fcfa08eeeb2cae99d]] [formerly 615058473a2177ca5b89e9edbb797f4c2a59c7e5 [formerly 743d8dfc6843c4c205051a8ab309fbb2116c895e] [formerly bb0ea98b1e14154ef464e2f7a16738705894e54b [formerly 960a69da74b81ef8093820e003f2d6c59a34974c]]] [formerly 2fa3be52c1b44665bc81a7cc7d4cea4bbf0d91d5 [formerly 2054589f0898627e0a17132fd9d4cc78efc91867] [formerly 3b53730e8a895e803dfdd6ca72bc05e17a4164c1 [formerly 8a2fa8ab7baf6686d21af1f322df46fd58c60e69]] [formerly 87d1e3a07a19d03c7d7c94d93ab4fa9f58dada7c [formerly f331916385a5afac1234854ee8d7f160f34b668f] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18 [formerly 386086f05aa9487f65bce2ee54438acbdce57650]]]] Former-commit-id: a00aed8c934a6460c4d9ac902b9a74a3d6864697 [formerly 26fdeca29c2f07916d837883983ca2982056c78e] [formerly 0e3170d41a2f99ecf5c918183d361d4399d793bf [formerly 3c12ad4c88ac5192e0f5606ac0d88dd5bf8602dc]] [formerly d5894f84f2fd2e77a6913efdc5ae388cf1be0495 [formerly ad3e7bc670ff92c992730d29c9d3aa1598d844e8] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18]] Former-commit-id: 3c19c9fae64f6106415fbc948a4dc613b9ee12f8 [formerly 467ddc0549c74bb007e8f01773bb6dc9103b417d] [formerly 5fa518345d958e2760e443b366883295de6d991c [formerly 3530e130b9fdb7280f638dbc2e785d2165ba82aa]] Former-commit-id: 9f5d473d42a435ec0d60149939d09be1acc25d92 [formerly be0b25c4ec2cde052a041baf0e11f774a158105d] Former-commit-id: 9eca71cb73ba9edccd70ac06a3b636b8d4093b04
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808
  1. import typing
  2. import unittest
  3. import logging
  4. from d3m import container, exceptions, utils
  5. from d3m.metadata import base as metadata_base, hyperparams, params
  6. from d3m.primitive_interfaces import base, transformer, unsupervised_learning
  7. Inputs = container.List
  8. Outputs = container.List
  9. class Hyperparams(hyperparams.Hyperparams):
  10. pass
  11. class TestPrimitiveValidation(unittest.TestCase):
  12. def test_multi_produce_missing_argument(self):
  13. with self.assertRaisesRegex(exceptions.InvalidPrimitiveCodeError, '\'multi_produce\' method arguments have to be an union of all arguments of all produce methods, but it does not accept all expected arguments'):
  14. # Silence any validation warnings.
  15. with utils.silence():
  16. class TestPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
  17. metadata = metadata_base.PrimitiveMetadata({
  18. 'id': '67568a80-dec2-4597-a10f-39afb13d3b9c',
  19. 'version': '0.1.0',
  20. 'name': "Test Primitive",
  21. 'python_path': 'd3m.primitives.test.TestPrimitive',
  22. 'algorithm_types': [
  23. metadata_base.PrimitiveAlgorithmType.NUMERICAL_METHOD,
  24. ],
  25. 'primitive_family': metadata_base.PrimitiveFamily.OPERATOR,
  26. })
  27. def produce(self, *, inputs: Inputs, second_inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
  28. pass
  29. def test_fit_multi_produce_missing_argument(self):
  30. with self.assertRaisesRegex(exceptions.InvalidPrimitiveCodeError, '\'fit_multi_produce\' method arguments have to be an union of all arguments of \'set_training_data\' method and all produce methods, but it does not accept all expected arguments'):
  31. # Silence any validation warnings.
  32. with utils.silence():
  33. class TestPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
  34. metadata = metadata_base.PrimitiveMetadata({
  35. 'id': '67568a80-dec2-4597-a10f-39afb13d3b9c',
  36. 'version': '0.1.0',
  37. 'name': "Test Primitive",
  38. 'python_path': 'd3m.primitives.test.TestPrimitive',
  39. 'algorithm_types': [
  40. metadata_base.PrimitiveAlgorithmType.NUMERICAL_METHOD,
  41. ],
  42. 'primitive_family': metadata_base.PrimitiveFamily.OPERATOR,
  43. })
  44. def produce(self, *, inputs: Inputs, second_inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
  45. pass
  46. def multi_produce(self, *, produce_methods: typing.Sequence[str], inputs: Inputs, second_inputs: Inputs, timeout: float = None, iterations: int = None) -> base.MultiCallResult:
  47. pass
  48. def test_multi_produce_extra_argument(self):
  49. with self.assertRaisesRegex(exceptions.InvalidPrimitiveCodeError, '\'multi_produce\' method arguments have to be an union of all arguments of all produce methods, but it accepts unexpected arguments'):
  50. # Silence any validation warnings.
  51. with utils.silence():
  52. class TestPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
  53. metadata = metadata_base.PrimitiveMetadata({
  54. 'id': '67568a80-dec2-4597-a10f-39afb13d3b9c',
  55. 'version': '0.1.0',
  56. 'name': "Test Primitive",
  57. 'python_path': 'd3m.primitives.test.TestPrimitive',
  58. 'algorithm_types': [
  59. metadata_base.PrimitiveAlgorithmType.NUMERICAL_METHOD,
  60. ],
  61. 'primitive_family': metadata_base.PrimitiveFamily.OPERATOR,
  62. })
  63. def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
  64. pass
  65. def multi_produce(self, *, produce_methods: typing.Sequence[str], inputs: Inputs, second_inputs: Inputs, timeout: float = None, iterations: int = None) -> base.MultiCallResult:
  66. pass
  67. def test_fit_multi_produce_extra_argument(self):
  68. with self.assertRaisesRegex(exceptions.InvalidPrimitiveCodeError, '\'fit_multi_produce\' method arguments have to be an union of all arguments of \'set_training_data\' method and all produce methods, but it accepts unexpected arguments'):
  69. # Silence any validation warnings.
  70. with utils.silence():
  71. class TestPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
  72. metadata = metadata_base.PrimitiveMetadata({
  73. 'id': '67568a80-dec2-4597-a10f-39afb13d3b9c',
  74. 'version': '0.1.0',
  75. 'name': "Test Primitive",
  76. 'python_path': 'd3m.primitives.test.TestPrimitive',
  77. 'algorithm_types': [
  78. metadata_base.PrimitiveAlgorithmType.NUMERICAL_METHOD,
  79. ],
  80. 'primitive_family': metadata_base.PrimitiveFamily.OPERATOR,
  81. })
  82. def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
  83. pass
  84. def fit_multi_produce(self, *, produce_methods: typing.Sequence[str], inputs: Inputs, second_inputs: Inputs, timeout: float = None, iterations: int = None) -> base.MultiCallResult:
  85. pass
  86. def test_produce_using_produce_methods(self):
  87. with self.assertRaisesRegex(exceptions.InvalidPrimitiveCodeError, 'Produce method cannot use \'produce_methods\' argument'):
  88. # Silence any validation warnings.
  89. with utils.silence():
  90. class TestPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
  91. metadata = metadata_base.PrimitiveMetadata({
  92. 'id': '67568a80-dec2-4597-a10f-39afb13d3b9c',
  93. 'version': '0.1.0',
  94. 'name': "Test Primitive",
  95. 'python_path': 'd3m.primitives.test.TestPrimitive',
  96. 'algorithm_types': [
  97. metadata_base.PrimitiveAlgorithmType.NUMERICAL_METHOD,
  98. ],
  99. 'primitive_family': metadata_base.PrimitiveFamily.OPERATOR,
  100. })
  101. def produce(self, *, inputs: Inputs, produce_methods: typing.Sequence[str], timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
  102. pass
  103. def test_hyperparams_to_tune(self):
  104. with self.assertRaisesRegex(exceptions.InvalidMetadataError, 'Hyper-parameter in \'hyperparams_to_tune\' metadata does not exist'):
  105. # Silence any validation warnings.
  106. with utils.silence():
  107. class TestPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
  108. metadata = metadata_base.PrimitiveMetadata({
  109. 'id': '67568a80-dec2-4597-a10f-39afb13d3b9c',
  110. 'version': '0.1.0',
  111. 'name': "Test Primitive",
  112. 'python_path': 'd3m.primitives.test.TestPrimitive',
  113. 'algorithm_types': [
  114. metadata_base.PrimitiveAlgorithmType.NUMERICAL_METHOD,
  115. ],
  116. 'primitive_family': metadata_base.PrimitiveFamily.OPERATOR,
  117. 'hyperparams_to_tune': [
  118. 'foobar',
  119. ]
  120. })
  121. def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
  122. pass
  123. def test_inputs_across_samples(self):
  124. with self.assertRaisesRegex(exceptions.InvalidPrimitiveCodeError, 'Method \'.*\' has an argument \'.*\' set as computing across samples, but it does not exist'):
  125. # Silence any validation warnings.
  126. with utils.silence():
  127. class TestPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
  128. metadata = metadata_base.PrimitiveMetadata({
  129. 'id': '67568a80-dec2-4597-a10f-39afb13d3b9c',
  130. 'version': '0.1.0',
  131. 'name': "Test Primitive",
  132. 'python_path': 'd3m.primitives.test.TestPrimitive',
  133. 'algorithm_types': [
  134. metadata_base.PrimitiveAlgorithmType.NUMERICAL_METHOD,
  135. ],
  136. 'primitive_family': metadata_base.PrimitiveFamily.OPERATOR,
  137. 'hyperparams_to_tune': [
  138. 'foobar',
  139. ]
  140. })
  141. @base.inputs_across_samples('foobar')
  142. def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
  143. pass
  144. with self.assertRaisesRegex(exceptions.InvalidPrimitiveCodeError, 'Method \'.*\' has an argument \'.*\' set as computing across samples, but it is not a PIPELINE argument'):
  145. # Silence any validation warnings.
  146. with utils.silence():
  147. class TestPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
  148. metadata = metadata_base.PrimitiveMetadata({
  149. 'id': '67568a80-dec2-4597-a10f-39afb13d3b9c',
  150. 'version': '0.1.0',
  151. 'name': "Test Primitive",
  152. 'python_path': 'd3m.primitives.test.TestPrimitive',
  153. 'algorithm_types': [
  154. metadata_base.PrimitiveAlgorithmType.NUMERICAL_METHOD,
  155. ],
  156. 'primitive_family': metadata_base.PrimitiveFamily.OPERATOR,
  157. 'hyperparams_to_tune': [
  158. 'foobar',
  159. ]
  160. })
  161. @base.inputs_across_samples('timeout')
  162. def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
  163. pass
  164. def test_can_detect_too_many_package_components(self):
  165. logger = logging.getLogger('d3m.metadata.base')
  166. # Ensure a warning message is generated for too many package components
  167. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  168. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.classification.random_forest.SKLearn.toomany', metadata_base.PrimitiveFamily.CLASSIFICATION)
  169. self.assertEqual(len(cm.records), 1)
  170. self.assertEqual(cm.records[0].msg,
  171. "%(python_path)s: Primitive's Python path does not adhere to d3m.primitives namespace specification. "
  172. "Reason: must have 5 segments.")
  173. # Ensure a warning message is NOT generated for an acceptable number of components
  174. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  175. logger.debug("Dummy log")
  176. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.classification.random_forest.SKLearn', metadata_base.PrimitiveFamily.CLASSIFICATION)
  177. self.assertEqual(len(cm.records), 1)
  178. def test_with_string_instead_of_enum(self):
  179. logger = logging.getLogger(metadata_base.__name__)
  180. # Ensure a warning message is NOT generated for an acceptable number of components
  181. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  182. logger.debug("Dummy log")
  183. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.classification.random_forest.SKLearn', metadata_base.PrimitiveFamily.CLASSIFICATION.name)
  184. self.assertEqual(len(cm.records), 1)
  185. def test_can_detect_too_few_package_components(self):
  186. logger = logging.getLogger(metadata_base.__name__)
  187. # Ensure a warning message is generated for too few package components
  188. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  189. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.classification.too_few', metadata_base.PrimitiveFamily.CLASSIFICATION)
  190. self.assertEqual(len(cm.records), 1)
  191. self.assertEqual(cm.records[0].msg,
  192. "%(python_path)s: Primitive's Python path does not adhere to d3m.primitives namespace specification. "
  193. "Reason: must have 5 segments.")
  194. # Ensure a warning message is NOT generated for an acceptable number of components
  195. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  196. logger.debug("Dummy log")
  197. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.classification.random_forest.SKLearn', metadata_base.PrimitiveFamily.CLASSIFICATION)
  198. self.assertEqual(len(cm.records), 1)
  199. def test_can_detect_bad_primitive_family(self):
  200. logger = logging.getLogger(metadata_base.__name__)
  201. # Ensure a warning message is generated for a bad primitive family
  202. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  203. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.bad_family.random_forest.SKLearn', metadata_base.PrimitiveFamily.CLASSIFICATION)
  204. self.assertEqual(len(cm.records), 1)
  205. self.assertEqual(cm.records[0].msg,
  206. "%(python_path)s: Primitive's Python path does not adhere to d3m.primitives namespace specification."
  207. " Reason: primitive family segment must match primitive's primitive family.")
  208. # Ensure a warning message is NOT generated for an acceptable primitive family
  209. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  210. logger.debug("Dummy log")
  211. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.classification.random_forest.SKLearn', metadata_base.PrimitiveFamily.CLASSIFICATION)
  212. self.assertEqual(len(cm.records), 1)
  213. def test_can_detect_bad_primitive_name(self):
  214. logger = logging.getLogger(metadata_base.__name__)
  215. # Ensure a warning message is generated for a bad primitive name
  216. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  217. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.classification.bad_name.SKLearn', metadata_base.PrimitiveFamily.CLASSIFICATION)
  218. self.assertEqual(len(cm.records), 1)
  219. self.assertEqual(cm.records[0].msg,
  220. "%(python_path)s: Primitive's Python path does not adhere to d3m.primitives namespace specification. "
  221. "Reason: must have a known primitive name segment.")
  222. # Ensure a warning message is NOT generated for an acceptable primitive name
  223. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  224. logger.debug("Dummy log")
  225. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.classification.random_forest.SKLearn', metadata_base.PrimitiveFamily.CLASSIFICATION)
  226. self.assertEqual(len(cm.records), 1)
  227. def test_can_detect_kind_not_capitalized(self):
  228. logger = logging.getLogger(metadata_base.__name__)
  229. # Ensure a warning message is generated for a primitive kind not capitalized properly
  230. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  231. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.classification.random_forest.sklearn', metadata_base.PrimitiveFamily.CLASSIFICATION)
  232. self.assertEqual(len(cm.records), 1)
  233. self.assertEqual(cm.records[0].msg,
  234. "%(python_path)s: Primitive's Python path does not adhere to d3m.primitives namespace specification. "
  235. "Reason: primitive kind segment must start with upper case.")
  236. # Ensure a warning message is NOT generated for an acceptable primitive kind
  237. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  238. logger.debug("Dummy log")
  239. metadata_base.PrimitiveMetadata()._validate_namespace_compliance('d3m.primitives.classification.random_forest.SKLearn', metadata_base.PrimitiveFamily.CLASSIFICATION)
  240. self.assertEqual(len(cm.records), 1)
  241. def test_will_generate_warning_for_missing_contact(self):
  242. logger = logging.getLogger(metadata_base.__name__)
  243. bad_metadata = metadata_base.PrimitiveMetadata({
  244. 'id': 'id',
  245. 'version': '0.1.0',
  246. 'name': "Test Primitive",
  247. 'python_path': 'path',
  248. 'algorithm_types': [
  249. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  250. ],
  251. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  252. 'installation': [{
  253. 'type': metadata_base.PrimitiveInstallationType.PIP,
  254. 'package': 'foobar',
  255. 'version': '0.1.0',
  256. }],
  257. 'source': {
  258. 'name': 'Test author',
  259. # 'contact': 'mailto:test@example.com',
  260. 'uris': 'http://someplace'
  261. }
  262. })
  263. good_metadata = metadata_base.PrimitiveMetadata({
  264. 'id': 'id',
  265. 'version': '0.1.0',
  266. 'name': "Test Primitive",
  267. 'python_path': 'path',
  268. 'algorithm_types': [
  269. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  270. ],
  271. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  272. 'installation': [{
  273. 'type': metadata_base.PrimitiveInstallationType.PIP,
  274. 'package': 'foobar',
  275. 'version': '0.1.0',
  276. }],
  277. 'source': {
  278. 'name': 'Test author',
  279. 'contact': 'mailto:test@example.com',
  280. 'uris': 'http://someplace'
  281. }
  282. })
  283. # Ensure a warning message is generated for a primitive with no contact specified in the metadata.source
  284. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  285. metadata_base.PrimitiveMetadata()._validate_contact_information(bad_metadata.query())
  286. self.assertEqual(len(cm.records), 1)
  287. self.assertEqual(cm.records[0].msg, "%(python_path)s: Contact information such as the email address of the author (e.g., \"mailto:author@example.com\") should be specified in primitive metadata in its \"source.contact\" field.")
  288. # Ensure a warning message is NOT generated for a primitive with a contact specified in the metadata.source
  289. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  290. logger.debug("Dummy log")
  291. metadata_base.PrimitiveMetadata()._validate_contact_information(good_metadata.query())
  292. self.assertEqual(len(cm.records), 1)
  293. def test_will_generate_warning_for_empty_contact(self):
  294. logger = logging.getLogger(metadata_base.__name__)
  295. bad_metadata = metadata_base.PrimitiveMetadata({
  296. 'id': 'id',
  297. 'version': '0.1.0',
  298. 'name': "Test Primitive",
  299. 'python_path': 'path',
  300. 'algorithm_types': [
  301. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  302. ],
  303. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  304. 'installation': [{
  305. 'type': metadata_base.PrimitiveInstallationType.PIP,
  306. 'package': 'foobar',
  307. 'version': '0.1.0',
  308. }],
  309. 'source': {
  310. 'name': 'Test author',
  311. 'contact': '',
  312. 'uris': ['http://someplace']
  313. }
  314. })
  315. good_metadata = metadata_base.PrimitiveMetadata({
  316. 'id': 'id',
  317. 'version': '0.1.0',
  318. 'name': "Test Primitive",
  319. 'python_path': 'path',
  320. 'algorithm_types': [
  321. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  322. ],
  323. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  324. 'installation': [{
  325. 'type': metadata_base.PrimitiveInstallationType.PIP,
  326. 'package': 'foobar',
  327. 'version': '0.1.0',
  328. }],
  329. 'source': {
  330. 'name': 'Test author',
  331. 'contact': 'mailto:test@example.com',
  332. 'uris': ['http://someplace']
  333. }
  334. })
  335. # Ensure a warning message is generated for a primitive with empty contact specified in the metadata.source.
  336. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  337. metadata_base.PrimitiveMetadata()._validate_contact_information(bad_metadata.query())
  338. self.assertEqual(len(cm.records), 1)
  339. self.assertEqual(cm.records[0].msg, "%(python_path)s: Contact information such as the email address of the author (e.g., \"mailto:author@example.com\") should be specified in primitive metadata in its \"source.contact\" field.")
  340. # Ensure a warning message is NOT generated when a contact value is specified.
  341. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  342. logger.debug("Dummy log")
  343. metadata_base.PrimitiveMetadata()._validate_contact_information(good_metadata.query())
  344. self.assertEqual(len(cm.records), 1)
  345. def test_will_not_generate_missing_contact_warning_when_installation_not_specified(self):
  346. logger = logging.getLogger(metadata_base.__name__)
  347. good_metadata = metadata_base.PrimitiveMetadata({
  348. 'id': 'id',
  349. 'version': '0.1.0',
  350. 'name': "Test Primitive",
  351. 'python_path': 'path',
  352. 'algorithm_types': [
  353. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  354. ],
  355. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  356. 'source': {
  357. 'name': 'Test author',
  358. 'uris': ['http://someplace']
  359. }
  360. })
  361. # Ensure a warning message is NOT generated when a contact value is not specified when installation is also
  362. # not specified.
  363. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  364. logger.debug("Dummy log")
  365. metadata_base.PrimitiveMetadata()._validate_contact_information(good_metadata.query())
  366. self.assertEqual(len(cm.records), 1)
  367. def test_will_generate_warning_for_missing_uris(self):
  368. logger = logging.getLogger(metadata_base.__name__)
  369. bad_metadata = metadata_base.PrimitiveMetadata({
  370. 'id': 'id',
  371. 'version': '0.1.0',
  372. 'name': "Test Primitive",
  373. 'python_path': 'path',
  374. 'algorithm_types': [
  375. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  376. ],
  377. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  378. 'installation': [{
  379. 'type': metadata_base.PrimitiveInstallationType.PIP,
  380. 'package': 'foobar',
  381. 'version': '0.1.0',
  382. }],
  383. 'source': {
  384. 'name': 'Test author',
  385. 'contact': 'mailto:test@example.com',
  386. }
  387. })
  388. good_metadata = metadata_base.PrimitiveMetadata({
  389. 'id': 'id',
  390. 'version': '0.1.0',
  391. 'name': "Test Primitive",
  392. 'python_path': 'path',
  393. 'algorithm_types': [
  394. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  395. ],
  396. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  397. 'installation': [{
  398. 'type': metadata_base.PrimitiveInstallationType.PIP,
  399. 'package': 'foobar',
  400. 'version': '0.1.0',
  401. }],
  402. 'source': {
  403. 'name': 'Test author',
  404. 'contact': 'mailto:test@example.com',
  405. 'uris': ['http://someplace'],
  406. }
  407. })
  408. # Ensure a warning message is generated for a primitive with no uris specified in the metadata.source.
  409. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  410. metadata_base.PrimitiveMetadata()._validate_contact_information(bad_metadata.query())
  411. self.assertEqual(len(cm.records), 1)
  412. self.assertEqual(cm.records[0].msg, "%(python_path)s: A bug reporting URI should be specified in primitive metadata in its \"source.uris\" field.")
  413. # Ensure a warning message is NOT generated when uris are specified in the metadata.source.
  414. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  415. logger.debug("Dummy log")
  416. metadata_base.PrimitiveMetadata()._validate_contact_information(good_metadata.query())
  417. self.assertEqual(len(cm.records), 1)
  418. def test_will_generate_warning_for_empty_uris(self):
  419. logger = logging.getLogger(metadata_base.__name__)
  420. bad_metadata = metadata_base.PrimitiveMetadata({
  421. 'id': 'id',
  422. 'version': '0.1.0',
  423. 'name': "Test Primitive",
  424. 'python_path': 'path',
  425. 'algorithm_types': [
  426. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  427. ],
  428. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  429. 'installation': [{
  430. 'type': metadata_base.PrimitiveInstallationType.PIP,
  431. 'package': 'foobar',
  432. 'version': '0.1.0',
  433. }],
  434. 'source': {
  435. 'name': 'Test author',
  436. 'contact': 'mailto:test@example.com',
  437. 'uris': [],
  438. }
  439. })
  440. good_metadata = metadata_base.PrimitiveMetadata({
  441. 'id': 'id',
  442. 'version': '0.1.0',
  443. 'name': "Test Primitive",
  444. 'python_path': 'path',
  445. 'algorithm_types': [
  446. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  447. ],
  448. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  449. 'installation': [{
  450. 'type': metadata_base.PrimitiveInstallationType.PIP,
  451. 'package': 'foobar',
  452. 'version': '0.1.0',
  453. }],
  454. 'source': {
  455. 'name': 'Test author',
  456. 'contact': 'mailto:test@example.com',
  457. 'uris': ['http://someplace'],
  458. }
  459. })
  460. # Ensure a warning message is generated for a primitive with empty uris specified in the metadata.source.
  461. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  462. metadata_base.PrimitiveMetadata()._validate_contact_information(bad_metadata.query())
  463. self.assertEqual(len(cm.records), 1)
  464. self.assertEqual(cm.records[0].msg, "%(python_path)s: A bug reporting URI should be specified in primitive metadata in its \"source.uris\" field.")
  465. # Ensure a warning message is NOT generated when non empty uris are specified in the metadata.source.
  466. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  467. logger.debug("Dummy log")
  468. metadata_base.PrimitiveMetadata()._validate_contact_information(good_metadata.query())
  469. self.assertEqual(len(cm.records), 1)
  470. def test_validation_will_warn_on_missing_source(self):
  471. logger = logging.getLogger(metadata_base.__name__)
  472. bad_metadata = metadata_base.PrimitiveMetadata({
  473. 'id': 'id',
  474. 'version': '0.1.0',
  475. 'name': "Test Primitive",
  476. 'python_path': 'path',
  477. 'algorithm_types': [
  478. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  479. ],
  480. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  481. 'installation': [{
  482. 'type': metadata_base.PrimitiveInstallationType.PIP,
  483. 'package': 'foobar',
  484. 'version': '0.1.0',
  485. }],
  486. })
  487. good_metadata = metadata_base.PrimitiveMetadata({
  488. 'id': 'id',
  489. 'version': '0.1.0',
  490. 'name': "Test Primitive",
  491. 'python_path': 'path',
  492. 'algorithm_types': [
  493. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  494. ],
  495. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  496. 'installation': [{
  497. 'type': metadata_base.PrimitiveInstallationType.PIP,
  498. 'package': 'foobar',
  499. 'version': '0.1.0',
  500. }],
  501. 'source': {
  502. 'name': 'Test author',
  503. 'contact': 'mailto:test@example.com',
  504. 'uris': ['http://someplace'],
  505. }
  506. })
  507. # Ensure a warning message is generated for a primitive with no source
  508. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  509. metadata_base.PrimitiveMetadata()._validate_contact_information(bad_metadata.query())
  510. self.assertEqual(len(cm.records), 1)
  511. self.assertEqual(cm.records[0].msg, "%(python_path)s: No \"source\" field in the primitive metadata. Metadata should contain contact information and bug reporting URI.")
  512. # Ensure a warning message is NOT generated when source is present
  513. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  514. logger.debug("Dummy log")
  515. metadata_base.PrimitiveMetadata()._validate_contact_information(good_metadata.query())
  516. self.assertEqual(len(cm.records), 1)
  def test_validation_will_warn_on_missing_description(self):
      """Check that ``_validate_description`` flags metadata lacking a ``description`` key.

      The silent path is checked by logging one dummy record first (``assertLogs``
      fails if nothing at all is logged) and then requiring that it is the only one.
      """
      logger = logging.getLogger(metadata_base.__name__)

      def make_metadata(**extra_fields):
          # Fresh literal on every call; the two variants differ only in extra_fields.
          fields = {
              'id': 'id',
              'version': '0.1.0',
              'name': "Test Primitive",
              'python_path': 'path',
              'algorithm_types': [
                  metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
              ],
              'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
              'installation': [{
                  'type': metadata_base.PrimitiveInstallationType.PIP,
                  'package': 'foobar',
                  'version': '0.1.0',
              }],
          }
          fields.update(extra_fields)
          return metadata_base.PrimitiveMetadata(fields)

      bad_metadata = make_metadata()
      good_metadata = make_metadata(description='primitive description')

      expected_message = "%(python_path)s: Primitive is not providing a description through its docstring."

      # Without a description, the validator must emit exactly this one message.
      with self.assertLogs(logger=logger, level=logging.DEBUG) as captured:
          metadata_base.PrimitiveMetadata()._validate_description(bad_metadata.query())
      self.assertEqual(len(captured.records), 1)
      self.assertEqual(captured.records[0].msg, expected_message)

      # With a description, only the dummy record may be captured.
      with self.assertLogs(logger=logger, level=logging.DEBUG) as captured:
          logger.debug("Dummy log")
          metadata_base.PrimitiveMetadata()._validate_description(good_metadata.query())
      self.assertEqual(len(captured.records), 1)
  def test_validation_will_warn_on_empty_description(self):
      """An empty ``description`` string must be reported just like a missing one.

      ``_validate_description`` is expected to log the "not providing a description"
      message for ``description=''`` and to stay silent for a non-empty description.
      """
      logger = logging.getLogger(metadata_base.__name__)

      def make_metadata(description):
          # Build a fresh metadata object that differs only in its 'description' field.
          return metadata_base.PrimitiveMetadata({
              'id': 'id',
              'version': '0.1.0',
              'name': "Test Primitive",
              'python_path': 'path',
              'algorithm_types': [
                  metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
              ],
              'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
              'installation': [{
                  'type': metadata_base.PrimitiveInstallationType.PIP,
                  'package': 'foobar',
                  'version': '0.1.0',
              }],
              'description': description,
          })

      bad_metadata = make_metadata('')
      good_metadata = make_metadata('primitive description')

      # Ensure a warning message is generated for a primitive with an EMPTY description.
      with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
          metadata_base.PrimitiveMetadata()._validate_description(bad_metadata.query())
      self.assertEqual(len(cm.records), 1)
      self.assertEqual(cm.records[0].msg, "%(python_path)s: Primitive is not providing a description through its docstring.")

      # Ensure NO message is generated when a non-empty description is present.
      # assertLogs fails if nothing is logged at all, so log one dummy record and
      # require that it is the only record captured.
      with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
          logger.debug("Dummy log")
          metadata_base.PrimitiveMetadata()._validate_description(good_metadata.query())
      self.assertEqual(len(cm.records), 1)
      self.assertEqual(cm.records[0].msg, "Dummy log")
  def test_validation_will_warn_on_inherited_description(self):
      """A description that looks inherited from the primitive base class must be reported.

      The "bad" description starts with "A base class for primitives", presumably
      mimicking the base class's docstring that a subclass without its own docstring
      would inherit (TODO confirm the exact rule in ``_validate_description``).
      A primitive-specific description must not produce any message.
      """
      logger = logging.getLogger(metadata_base.__name__)

      def make_metadata(description):
          # Build a fresh metadata object that differs only in its 'description' field.
          return metadata_base.PrimitiveMetadata({
              'id': 'id',
              'version': '0.1.0',
              'name': "Test Primitive",
              'python_path': 'path',
              'algorithm_types': [
                  metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
              ],
              'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
              'installation': [{
                  'type': metadata_base.PrimitiveInstallationType.PIP,
                  'package': 'foobar',
                  'version': '0.1.0',
              }],
              'description': description,
          })

      bad_metadata = make_metadata('A base class for primitives description')
      good_metadata = make_metadata('primitive description')

      # Ensure a warning message is generated for a primitive whose description
      # appears to be inherited from the base class rather than written for it.
      with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
          metadata_base.PrimitiveMetadata()._validate_description(bad_metadata.query())
      self.assertEqual(len(cm.records), 1)
      self.assertEqual(cm.records[0].msg, "%(python_path)s: Primitive is not providing a description through its docstring.")

      # Ensure NO message is generated for a primitive-specific description.
      # assertLogs fails if nothing is logged at all, so log one dummy record and
      # require that it is the only record captured.
      with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
          logger.debug("Dummy log")
          metadata_base.PrimitiveMetadata()._validate_description(good_metadata.query())
      self.assertEqual(len(cm.records), 1)
      self.assertEqual(cm.records[0].msg, "Dummy log")
  def test_neural_network_mixin(self):
      """Define a primitive combining ``NeuralNetworkModuleMixin`` with an unsupervised learner.

      There are no explicit assertions: the test passes as long as defining
      ``TestPrimitive`` (including the validation that is silenced below) does not raise.
      """
      # Placeholder type passed as the mixin's module type parameter (5th type argument).
      class MyNeuralNetworkModuleBase:
          pass

      class Params(params.Params):
          pass

      # Concrete subclass of the placeholder; an instance is returned by get_module.
      class MyNeuralNetworkModule(MyNeuralNetworkModuleBase):
          pass

      # Silence any validation warnings.
      # NOTE(review): do not add docstrings to the nested class or its methods below;
      # their docstrings may be picked up as primitive metadata and change what is validated.
      with utils.silence():
          class TestPrimitive(
              base.NeuralNetworkModuleMixin[Inputs, Outputs, Params, Hyperparams, MyNeuralNetworkModuleBase],
              unsupervised_learning.UnsupervisedLearnerPrimitiveBase[Inputs, Outputs, Params, Hyperparams],
          ):
              metadata = metadata_base.PrimitiveMetadata({
                  'id': '4164deb6-2418-4c96-9959-3d475dcf9584',
                  'version': '0.1.0',
                  'name': "Test neural network module",
                  'python_path': 'd3m.primitives.layer.super.TestPrimitive',
                  'algorithm_types': [
                      metadata_base.PrimitiveAlgorithmType.CONVOLUTIONAL_NEURAL_NETWORK_LAYER,
                  ],
                  'primitive_family': metadata_base.PrimitiveFamily.LAYER,
              })

              # The methods below are minimal stubs. Only the signatures matter here.
              def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
                  raise exceptions.NotSupportedError

              def set_training_data(self, *, inputs: Inputs) -> None:
                  raise exceptions.NotSupportedError

              def fit(self, *, timeout: float = None, iterations: int = None) -> base.CallResult[None]:
                  raise exceptions.NotSupportedError

              def get_params(self) -> Params:
                  return Params()

              def set_params(self, *, params: Params) -> None:
                  pass

              # Method contributed by NeuralNetworkModuleMixin; returns a fresh module instance.
              def get_module(self, *, input_module: MyNeuralNetworkModuleBase) -> MyNeuralNetworkModuleBase:
                  return MyNeuralNetworkModule()
  if __name__ == '__main__':
      # Run every test case defined in this module when it is executed as a script.
      unittest.main(module='__main__')

全栈的自动化机器学习系统，主要针对多变量时间序列数据的异常检测。TODS提供了详尽的用于构建基于机器学习的异常检测系统的模块，它们包括：数据处理(data processing)、时间序列处理(time series processing)、特征分析(feature analysis)、检测算法(detection algorithms)和强化模块(reinforcement module)。这些模块所提供的功能包括常见的数据预处理、时间序列数据的平滑或变换、从时域或频域中抽取特征、多种多样的检测算