You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hyperparams.py 90 kB

first commit Former-commit-id: 08bc23ba02cffbce3cf63962390a65459a132e48 [formerly 0795edd4834b9b7dc66db8d10d4cbaf42bbf82cb] [formerly b5010b42541add7e2ea2578bf2da537efc457757 [formerly a7ca09c2c34c4fc8b3d8e01fcfa08eeeb2cae99d]] [formerly 615058473a2177ca5b89e9edbb797f4c2a59c7e5 [formerly 743d8dfc6843c4c205051a8ab309fbb2116c895e] [formerly bb0ea98b1e14154ef464e2f7a16738705894e54b [formerly 960a69da74b81ef8093820e003f2d6c59a34974c]]] [formerly 2fa3be52c1b44665bc81a7cc7d4cea4bbf0d91d5 [formerly 2054589f0898627e0a17132fd9d4cc78efc91867] [formerly 3b53730e8a895e803dfdd6ca72bc05e17a4164c1 [formerly 8a2fa8ab7baf6686d21af1f322df46fd58c60e69]] [formerly 87d1e3a07a19d03c7d7c94d93ab4fa9f58dada7c [formerly f331916385a5afac1234854ee8d7f160f34b668f] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18 [formerly 386086f05aa9487f65bce2ee54438acbdce57650]]]] Former-commit-id: a00aed8c934a6460c4d9ac902b9a74a3d6864697 [formerly 26fdeca29c2f07916d837883983ca2982056c78e] [formerly 0e3170d41a2f99ecf5c918183d361d4399d793bf [formerly 3c12ad4c88ac5192e0f5606ac0d88dd5bf8602dc]] [formerly d5894f84f2fd2e77a6913efdc5ae388cf1be0495 [formerly ad3e7bc670ff92c992730d29c9d3aa1598d844e8] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18]] Former-commit-id: 3c19c9fae64f6106415fbc948a4dc613b9ee12f8 [formerly 467ddc0549c74bb007e8f01773bb6dc9103b417d] [formerly 5fa518345d958e2760e443b366883295de6d991c [formerly 3530e130b9fdb7280f638dbc2e785d2165ba82aa]] Former-commit-id: 9f5d473d42a435ec0d60149939d09be1acc25d92 [formerly be0b25c4ec2cde052a041baf0e11f774a158105d] Former-commit-id: 9eca71cb73ba9edccd70ac06a3b636b8d4093b04
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795
  1. import json
  2. import logging
  3. import os
  4. import typing
  5. import pickle
  6. import subprocess
  7. import sys
  8. import unittest
  9. from collections import OrderedDict
  10. import frozendict
  11. import numpy
  12. from sklearn.utils import validation as sklearn_validation
  13. from d3m import container, exceptions, index, utils
  14. from d3m.metadata import base as metadata_base, hyperparams
  15. from d3m.primitive_interfaces import base, transformer
  16. TEST_PRIMITIVES_DIR = os.path.join(os.path.dirname(__file__), 'data', 'primitives')
  17. sys.path.insert(0, TEST_PRIMITIVES_DIR)
  18. from test_primitives.monomial import MonomialPrimitive
  19. from test_primitives.random import RandomPrimitive
  20. from test_primitives.sum import SumPrimitive
  21. from test_primitives.increment import IncrementPrimitive
  22. # It's defined at global scope so it can be pickled.
  23. class TestPicklingHyperparams(hyperparams.Hyperparams):
  24. choice = hyperparams.Choice(
  25. choices={
  26. 'alpha': hyperparams.Hyperparams.define(OrderedDict(
  27. value=hyperparams.Union(
  28. OrderedDict(
  29. float=hyperparams.Hyperparameter[float](0),
  30. int=hyperparams.Hyperparameter[int](0)
  31. ),
  32. default='float'
  33. ),
  34. ))
  35. },
  36. default='alpha',
  37. semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
  38. )
  39. class TestHyperparams(unittest.TestCase):
    def test_hyperparameter(self):
        """Exercise the base ``Hyperparameter`` class.

        Covers the default value, sampling (which returns the default and is
        expected to log), export via ``to_simple_structure``, JSON
        round-tripping, structural-type validation and the ``max_samples``
        limit of ``sample_multiple``.
        """
        hyperparameter = hyperparams.Hyperparameter[str]('nothing')

        self.assertEqual(hyperparameter.get_default(), 'nothing')

        # Each sampling call returns the default and is expected to emit
        # exactly one record on ``hyperparams.logger``.
        with self.assertLogs(hyperparams.logger) as cm:
            self.assertEqual(hyperparameter.sample(42), 'nothing')

        self.assertEqual(len(cm.records), 1)

        with self.assertLogs(hyperparams.logger) as cm:
            self.assertEqual(hyperparameter.sample_multiple(0, 1, 42), ('nothing',))

        self.assertEqual(len(cm.records), 1)

        # Even an empty request (max_samples=0) is expected to log once.
        with self.assertLogs(hyperparams.logger) as cm:
            self.assertEqual(hyperparameter.sample_multiple(0, 0, 42), ())

        self.assertEqual(len(cm.records), 1)

        self.assertEqual(hyperparameter.to_simple_structure(), {
            'default': 'nothing',
            'semantic_types': [],
            'structural_type': str,
            'type': hyperparams.Hyperparameter,
        })

        # JSON round-trips of both the default and a sampled value.
        self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.get_default()), 'nothing')
        with self.assertLogs(hyperparams.logger) as cm:
            self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.sample(42)), 'nothing')

        self.assertEqual(len(cm.records), 1)

        self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.get_default())), hyperparameter.get_default())
        # NOTE(review): ``sample(42)`` is called twice on the next line, yet only
        # one log record is expected. This presumably relies on the library
        # de-duplicating identical warnings from the same call site, so keep
        # both calls on a single line -- TODO confirm.
        with self.assertLogs(hyperparams.logger) as cm:
            self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.sample(42))), hyperparameter.sample(42))

        self.assertEqual(len(cm.records), 1)

        # The default must be an instance of the declared structural type.
        with self.assertRaisesRegex(TypeError, 'Value \'.*\' is not an instance of the structural type'):
            hyperparams.Hyperparameter[int]('nothing')

        # sample_multiple(0, 1) yields only the default; asking for up to two
        # samples is rejected.
        with self.assertRaisesRegex(ValueError, '\'max_samples\' cannot be larger than'):
            hyperparameter.sample_multiple(0, 2, 42)
  70. def test_constant(self):
  71. hyperparameter = hyperparams.Constant(12345)
  72. self.assertEqual(hyperparameter.get_default(), 12345)
  73. self.assertEqual(hyperparameter.sample(), 12345)
  74. self.assertEqual(hyperparameter.sample_multiple(0, 1, 42), (12345,))
  75. self.assertEqual(hyperparameter.sample_multiple(0, 0, 42), ())
  76. self.assertEqual(hyperparameter.to_simple_structure(), {
  77. 'default': 12345,
  78. 'semantic_types': [],
  79. 'structural_type': int,
  80. 'type': hyperparams.Constant,
  81. })
  82. self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.get_default()), 12345)
  83. self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.sample(42)), 12345)
  84. self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.get_default())), hyperparameter.get_default())
  85. self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.sample(42))), hyperparameter.sample(42))
  86. with self.assertRaisesRegex(TypeError, 'Value \'.*\' is not an instance of the structural type'):
  87. hyperparams.Hyperparameter[int]('different')
  88. with self.assertRaisesRegex(ValueError, 'Value \'.*\' is not the constant default value'):
  89. hyperparameter.validate(54321)
  90. with self.assertRaisesRegex(ValueError, '\'max_samples\' cannot be larger than'):
  91. self.assertEqual(hyperparameter.sample_multiple(0, 2, 42), {12345})
  92. hyperparameter = hyperparams.Constant('constant')
  93. with self.assertRaisesRegex(ValueError, 'Value \'.*\' is not the constant default value'):
  94. hyperparameter.validate('different')
  95. def test_bounded(self):
  96. hyperparameter = hyperparams.Bounded[float](0.0, 1.0, 0.2)
  97. self.assertEqual(hyperparameter.get_default(), 0.2)
  98. with self.assertLogs(hyperparams.logger) as cm:
  99. self.assertEqual(hyperparameter.sample(42), 0.37454011884736255)
  100. self.assertEqual(len(cm.records), 1)
  101. with self.assertLogs(hyperparams.logger) as cm:
  102. self.assertEqual(hyperparameter.sample_multiple(0, 1, 7), (0.22733907982646523,))
  103. self.assertEqual(len(cm.records), 1)
  104. self.assertEqual(hyperparameter.sample_multiple(0, 0, 42), ())
  105. self.assertEqual(hyperparameter.to_simple_structure(), {
  106. 'default': 0.2,
  107. 'semantic_types': [],
  108. 'structural_type': float,
  109. 'type': hyperparams.Bounded,
  110. 'lower': 0.0,
  111. 'upper': 1.0,
  112. 'lower_inclusive': True,
  113. 'upper_inclusive': True,
  114. })
  115. self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.get_default()), 0.2)
  116. with self.assertLogs(hyperparams.logger) as cm:
  117. self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.sample(42)), 0.37454011884736255)
  118. self.assertEqual(len(cm.records), 1)
  119. self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.get_default())), hyperparameter.get_default())
  120. with self.assertLogs(hyperparams.logger) as cm:
  121. self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.sample(42))), hyperparameter.sample(42))
  122. self.assertEqual(len(cm.records), 1)
  123. with self.assertRaisesRegex(TypeError, 'Value \'.*\' is not an instance of the structural type'):
  124. hyperparams.Bounded[str]('lower', 'upper', 0.2)
  125. with self.assertRaisesRegex(TypeError, 'Lower bound \'.*\' is not an instance of the structural type'):
  126. hyperparams.Bounded[str](0.0, 'upper', 'default')
  127. with self.assertRaisesRegex(TypeError, 'Upper bound \'.*\' is not an instance of the structural type'):
  128. hyperparams.Bounded[str]('lower', 1.0, 'default')
  129. with self.assertRaisesRegex(ValueError, 'Value \'.*\' is outside of range'):
  130. hyperparams.Bounded[str]('lower', 'upper', 'default')
  131. with self.assertRaisesRegex(ValueError, 'Value \'.*\' is outside of range'):
  132. hyperparams.Bounded[float](0.0, 1.0, 1.2)
  133. hyperparams.Bounded[typing.Optional[float]](0.0, None, 0.2)
  134. hyperparams.Bounded[typing.Optional[float]](None, 1.0, 0.2)
  135. with self.assertRaisesRegex(ValueError, 'Lower and upper bounds cannot both be None'):
  136. hyperparams.Bounded[typing.Optional[float]](None, None, 0.2)
  137. with self.assertRaisesRegex(TypeError, 'Value \'.*\' is not an instance of the structural type'):
  138. hyperparams.Bounded[float](0.0, 1.0, None)
  139. with self.assertRaises(TypeError):
  140. hyperparams.Bounded[typing.Optional[float]](0.0, 1.0, None)
  141. hyperparams.Bounded[typing.Optional[float]](None, 1.0, None)
  142. hyperparams.Bounded[typing.Optional[float]](0.0, None, None)
  143. hyperparameter = hyperparams.Bounded[float](0.0, None, 0.2)
  144. with self.assertRaisesRegex(ValueError, '\'max_samples\' cannot be larger than'):
  145. hyperparameter.sample_multiple(0, 2, 42)
  146. with self.assertRaisesRegex(exceptions.InvalidArgumentValueError, 'must be finite'):
  147. hyperparams.Bounded[typing.Optional[float]](0.0, numpy.nan, 0)
  148. with self.assertRaisesRegex(exceptions.InvalidArgumentValueError, 'must be finite'):
  149. hyperparams.Bounded[typing.Optional[float]](numpy.inf, 0.0, 0)
  150. def test_enumeration(self):
  151. hyperparameter = hyperparams.Enumeration(['a', 'b', 1, 2, None], None)
  152. self.assertEqual(hyperparameter.get_default(), None)
  153. self.assertEqual(hyperparameter.sample(42), 2)
  154. self.assertEqual(hyperparameter.sample_multiple(0, 1, 42), ())
  155. self.assertEqual(hyperparameter.sample_multiple(0, 2, 42), ('b', None))
  156. self.assertEqual(hyperparameter.sample_multiple(0, 3, 42), ('b', None))
  157. self.assertEqual(hyperparameter.to_simple_structure(), {
  158. 'default': None,
  159. 'semantic_types': [],
  160. 'structural_type': typing.Union[str, int, type(None)],
  161. 'type': hyperparams.Enumeration,
  162. 'values': ['a', 'b', 1, 2, None],
  163. })
  164. self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.get_default()), None)
  165. self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.sample(42)), 2)
  166. self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.get_default())), hyperparameter.get_default())
  167. self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.sample(42))), hyperparameter.sample(42))
  168. with self.assertRaisesRegex(ValueError, 'Value \'.*\' is not among values'):
  169. hyperparams.Enumeration(['a', 'b', 1, 2], None)
  170. with self.assertRaisesRegex(TypeError, 'Value \'.*\' is not an instance of the structural type'):
  171. hyperparams.Enumeration[typing.Union[str, int]](['a', 'b', 1, 2, None], None)
  172. with self.assertRaisesRegex(ValueError, '\'max_samples\' cannot be larger than'):
  173. self.assertEqual(hyperparameter.sample_multiple(0, 6, 42), ())
  174. hyperparameter = hyperparams.Enumeration(['a', 'b', 'c'], 'a')
  175. self.assertEqual(hyperparameter.value_to_json_structure('c'), 'c')
  176. self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure('c')), 'c')
  177. with self.assertRaisesRegex(exceptions.InvalidArgumentValueError, 'contain duplicates'):
  178. hyperparams.Enumeration([1.0, 1], 1)
  179. hyperparameter = hyperparams.Enumeration([1.0, float('nan'), float('infinity'), float('-infinity')], 1.0)
  180. hyperparameter.validate(float('nan'))
  181. self.assertEqual(utils.to_json_structure(hyperparameter.to_simple_structure()), {
  182. 'type': 'd3m.metadata.hyperparams.Enumeration',
  183. 'default': 1.0,
  184. 'structural_type': 'float',
  185. 'semantic_types': [],
  186. 'values': [1.0, 'nan', 'inf', '-inf'],
  187. })
  188. self.assertEqual(json.dumps(hyperparameter.value_to_json_structure(float('nan')), allow_nan=False), '{"encoding": "pickle", "value": "gANHf/gAAAAAAAAu"}')
  189. self.assertEqual(json.dumps(hyperparameter.value_to_json_structure(float('inf')), allow_nan=False), '{"encoding": "pickle", "value": "gANHf/AAAAAAAAAu"}')
  190. def test_other(self):
  191. hyperparameter = hyperparams.UniformInt(1, 10, 2)
  192. self.assertEqual(hyperparameter.get_default(), 2)
  193. self.assertEqual(hyperparameter.sample(42), 7)
  194. self.assertEqual(hyperparameter.sample_multiple(0, 1, 42), ())
  195. self.assertEqual(hyperparameter.sample_multiple(0, 2, 42), (4, 8))
  196. self.assertEqual(hyperparameter.to_simple_structure(), {
  197. 'default': 2,
  198. 'semantic_types': [],
  199. 'structural_type': int,
  200. 'type': hyperparams.UniformInt,
  201. 'lower': 1,
  202. 'upper': 10,
  203. 'lower_inclusive': True,
  204. 'upper_inclusive': False,
  205. })
  206. with self.assertRaisesRegex(ValueError, 'Value \'.*\' is outside of range'):
  207. hyperparams.UniformInt(1, 10, 0)
  208. with self.assertRaisesRegex(ValueError, '\'max_samples\' cannot be larger than'):
  209. self.assertEqual(hyperparameter.sample_multiple(0, 10, 42), ())
  210. hyperparameter = hyperparams.Uniform(1.0, 10.0, 2.0)
  211. self.assertEqual(hyperparameter.get_default(), 2.0)
  212. self.assertEqual(hyperparameter.sample(42), 4.370861069626263)
  213. self.assertEqual(hyperparameter.to_simple_structure(), {
  214. 'default': 2.0,
  215. 'semantic_types': [],
  216. 'structural_type': float,
  217. 'type': hyperparams.Uniform,
  218. 'lower': 1.0,
  219. 'upper': 10.0,
  220. 'lower_inclusive': True,
  221. 'upper_inclusive': False,
  222. })
  223. with self.assertRaisesRegex(ValueError, 'Value \'.*\' is outside of range'):
  224. hyperparams.Uniform(1.0, 10.0, 0.0)
  225. hyperparameter = hyperparams.LogUniform(1.0, 10.0, 2.0)
  226. self.assertEqual(hyperparameter.get_default(), 2.0)
  227. self.assertEqual(hyperparameter.sample(42), 2.368863950364078)
  228. self.assertEqual(hyperparameter.to_simple_structure(), {
  229. 'default': 2.0,
  230. 'semantic_types': [],
  231. 'structural_type': float,
  232. 'type': hyperparams.LogUniform,
  233. 'lower': 1.0,
  234. 'upper': 10.0,
  235. 'lower_inclusive': True,
  236. 'upper_inclusive': False,
  237. })
  238. with self.assertRaisesRegex(ValueError, 'Value \'.*\' is outside of range'):
  239. hyperparams.LogUniform(1.0, 10.0, 0.0)
  240. hyperparameter = hyperparams.UniformBool(True)
  241. self.assertEqual(hyperparameter.get_default(), True)
  242. self.assertEqual(hyperparameter.sample(42), True)
  243. self.assertEqual(hyperparameter.to_simple_structure(), {
  244. 'default': True,
  245. 'semantic_types': [],
  246. 'structural_type': bool,
  247. 'type': hyperparams.UniformBool,
  248. })
  249. with self.assertRaises(exceptions.InvalidArgumentValueError):
  250. hyperparams.UniformInt(0, 1, 1, lower_inclusive=False, upper_inclusive=False)
  251. hyperparameter = hyperparams.UniformInt(0, 2, 1, lower_inclusive=False, upper_inclusive=False)
  252. self.assertEqual(hyperparameter.sample(42), 1)
  253. with self.assertRaises(exceptions.InvalidArgumentValueError):
  254. hyperparameter.sample_multiple(2, 2, 42)
  255. self.assertEqual(hyperparameter.sample_multiple(2, 2, 42, with_replacement=True), (1, 1))
  256. def test_union(self):
  257. hyperparameter = hyperparams.Union(
  258. OrderedDict(
  259. none=hyperparams.Hyperparameter(None),
  260. range=hyperparams.UniformInt(1, 10, 2)
  261. ),
  262. 'none',
  263. )
  264. self.assertEqual(hyperparameter.get_default(), None)
  265. self.assertEqual(hyperparameter.sample(45), 4)
  266. self.assertEqual(hyperparameter.to_simple_structure(), {
  267. 'default': None,
  268. 'semantic_types': [],
  269. 'structural_type': typing.Optional[int],
  270. 'type': hyperparams.Union,
  271. 'configuration': {
  272. 'none': {
  273. 'default': None,
  274. 'semantic_types': [],
  275. 'structural_type': type(None),
  276. 'type': hyperparams.Hyperparameter,
  277. },
  278. 'range': {
  279. 'default': 2,
  280. 'semantic_types': [],
  281. 'structural_type': int,
  282. 'type': hyperparams.UniformInt,
  283. 'lower': 1,
  284. 'upper': 10,
  285. 'lower_inclusive': True,
  286. 'upper_inclusive': False,
  287. }
  288. }
  289. })
  290. self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.get_default()), {'case': 'none', 'value': None})
  291. self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.sample(45)), {'case': 'range', 'value': 4})
  292. self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.get_default())), hyperparameter.get_default())
  293. with self.assertLogs(hyperparams.logger) as cm:
  294. self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.sample(42))), hyperparameter.sample(42))
  295. self.assertEqual(len(cm.records), 1)
  296. with self.assertRaisesRegex(TypeError, 'Hyper-parameter name is not a string'):
  297. hyperparams.Union(OrderedDict({1: hyperparams.Hyperparameter(None)}), 1)
  298. with self.assertRaisesRegex(TypeError, 'Hyper-parameter description is not an instance of the Hyperparameter class'):
  299. hyperparams.Union(OrderedDict(none=None), 'none')
  300. with self.assertRaisesRegex(ValueError, 'Default value \'.*\' is not in configuration'):
  301. hyperparams.Union(OrderedDict(range=hyperparams.UniformInt(1, 10, 2)), 'none')
  302. hyperparams.Union(OrderedDict(range=hyperparams.UniformInt(1, 10, 2), default=hyperparams.Hyperparameter('nothing')), 'default')
  303. hyperparams.Union[typing.Union[str, int]](OrderedDict(range=hyperparams.UniformInt(1, 10, 2), default=hyperparams.Hyperparameter('nothing')), 'default')
  304. with self.assertRaisesRegex(TypeError, 'Hyper-parameter \'.*\' is not a subclass of the structural type'):
  305. hyperparams.Union[str](OrderedDict(range=hyperparams.UniformInt(1, 10, 2), default=hyperparams.Hyperparameter('nothing')), 'default')
  306. def test_hyperparams(self):
  307. class TestHyperparams(hyperparams.Hyperparams):
  308. a = hyperparams.Union(OrderedDict(
  309. range=hyperparams.UniformInt(1, 10, 2),
  310. none=hyperparams.Hyperparameter(None),
  311. ), 'range')
  312. b = hyperparams.Uniform(1.0, 10.0, 2.0)
  313. testCls = hyperparams.Hyperparams.define(OrderedDict(
  314. a=hyperparams.Union(OrderedDict(
  315. range=hyperparams.UniformInt(1, 10, 2),
  316. none=hyperparams.Hyperparameter(None),
  317. ), 'range'),
  318. b=hyperparams.Uniform(1.0, 10.0, 2.0),
  319. ), set_names=True)
  320. for cls in (TestHyperparams, testCls):
  321. self.assertEqual(cls.configuration['a'].name, 'a', cls)
  322. self.assertEqual(cls.defaults(), {'a': 2, 'b': 2.0}, cls)
  323. self.assertEqual(cls.defaults(), cls({'a': 2, 'b': 2.0}), cls)
  324. self.assertEqual(cls.sample(42), {'a': 4, 'b': 9.556428757689245}, cls)
  325. self.assertEqual(cls.sample(42), cls({'a': 4, 'b': 9.556428757689245}), cls)
  326. self.assertEqual(cls(cls.defaults(), b=3.0), {'a': 2, 'b': 3.0}, cls)
  327. self.assertEqual(cls(cls.defaults(), **{'b': 4.0}), {'a': 2, 'b': 4.0}, cls)
  328. self.assertEqual(cls.defaults('a'), 2, cls)
  329. self.assertEqual(cls.defaults('b'), 2.0, cls)
  330. self.assertEqual(cls.to_simple_structure(), {
  331. 'a': {
  332. 'default': 2,
  333. 'semantic_types': [],
  334. 'structural_type': typing.Optional[int],
  335. 'type': hyperparams.Union,
  336. 'configuration': {
  337. 'none': {
  338. 'default': None,
  339. 'semantic_types': [],
  340. 'structural_type': type(None),
  341. 'type': hyperparams.Hyperparameter,
  342. },
  343. 'range': {
  344. 'default': 2,
  345. 'lower': 1,
  346. 'semantic_types': [],
  347. 'structural_type': int,
  348. 'type': hyperparams.UniformInt,
  349. 'upper': 10,
  350. 'lower_inclusive': True,
  351. 'upper_inclusive': False,
  352. },
  353. },
  354. },
  355. 'b': {
  356. 'default': 2.0,
  357. 'semantic_types': [],
  358. 'structural_type': float,
  359. 'type': hyperparams.Uniform,
  360. 'lower': 1.0,
  361. 'upper': 10.0,
  362. 'lower_inclusive': True,
  363. 'upper_inclusive': False,
  364. }
  365. }, cls)
  366. test_hyperparams = cls({'a': cls.configuration['a'].get_default(), 'b': cls.configuration['b'].get_default()})
  367. self.assertEqual(test_hyperparams['a'], 2, cls)
  368. self.assertEqual(test_hyperparams['b'], 2.0, cls)
  369. self.assertEqual(test_hyperparams.values_to_json_structure(), {'a': {'case': 'range', 'value': 2}, 'b': 2.0})
  370. self.assertEqual(cls.values_from_json_structure(test_hyperparams.values_to_json_structure()), test_hyperparams)
  371. with self.assertRaisesRegex(ValueError, 'Not all hyper-parameters are specified', msg=cls):
  372. cls({'a': cls.configuration['a'].get_default()})
  373. with self.assertRaisesRegex(ValueError, 'Additional hyper-parameters are specified', msg=cls):
  374. cls({'a': cls.configuration['a'].get_default(), 'b': cls.configuration['b'].get_default(), 'c': 'two'})
  375. cls({'a': 3, 'b': 3.0})
  376. cls({'a': None, 'b': 3.0})
  377. test_hyperparams = cls(a=None, b=3.0)
  378. self.assertEqual(test_hyperparams['a'], None, cls)
  379. self.assertEqual(test_hyperparams['b'], 3.0, cls)
  380. with self.assertRaisesRegex(ValueError, 'Value \'.*\' for hyper-parameter \'.*\' has not validated with any of configured hyper-parameters', msg=cls):
  381. cls({'a': 0, 'b': 3.0})
  382. with self.assertRaisesRegex(ValueError, 'Value \'.*\' for hyper-parameter \'.*\' is outside of range', msg=cls):
  383. cls({'a': 3, 'b': 100.0})
  384. class SubTestHyperparams(cls):
  385. c = hyperparams.Hyperparameter[int](0)
  386. self.assertEqual(SubTestHyperparams.defaults(), {'a': 2, 'b': 2.0, 'c': 0}, cls)
  387. testSubCls = cls.define(OrderedDict(
  388. c=hyperparams.Hyperparameter[int](0),
  389. ), set_names=True)
  390. self.assertEqual(testSubCls.defaults(), {'a': 2, 'b': 2.0, 'c': 0}, cls)
  391. class ConfigurationHyperparams(hyperparams.Hyperparams):
  392. configuration = hyperparams.Uniform(1.0, 10.0, 2.0)
  393. self.assertEqual(ConfigurationHyperparams.configuration['configuration'].to_simple_structure(), hyperparams.Uniform(1.0, 10.0, 2.0).to_simple_structure())
  394. def test_numpy(self):
  395. class TestHyperparams(hyperparams.Hyperparams):
  396. value = hyperparams.Hyperparameter[container.ndarray](
  397. default=container.ndarray([0], generate_metadata=True),
  398. )
  399. values = TestHyperparams(value=container.ndarray([1, 2, 3], generate_metadata=True))
  400. self.assertEqual(values.values_to_json_structure(), {'value': {'encoding': 'pickle', 'value': 'gANjbnVtcHkuY29yZS5tdWx0aWFycmF5Cl9yZWNvbnN0cnVjdApxAGNkM20uY29udGFpbmVyLm51bXB5Cm5kYXJyYXkKcQFLAIVxAkMBYnEDh3EEUnEFfXEGKFgFAAAAbnVtcHlxByhLAUsDhXEIY251bXB5CmR0eXBlCnEJWAIAAABpOHEKSwBLAYdxC1JxDChLA1gBAAAAPHENTk5OSv////9K/////0sAdHEOYolDGAEAAAAAAAAAAgAAAAAAAAADAAAAAAAAAHEPdHEQWAgAAABtZXRhZGF0YXERY2QzbS5tZXRhZGF0YS5iYXNlCkRhdGFNZXRhZGF0YQpxEimBcRN9cRQoWBEAAABfY3VycmVudF9tZXRhZGF0YXEVY2QzbS5tZXRhZGF0YS5iYXNlCk1ldGFkYXRhRW50cnkKcRYpgXEXTn1xGChYCAAAAGVsZW1lbnRzcRljZDNtLnV0aWxzCnBtYXAKcRp9cRuFcRxScR1YDAAAAGFsbF9lbGVtZW50c3EeaBYpgXEfTn1xIChoGWgdaB5OaBFjZnJvemVuZGljdApGcm96ZW5PcmRlcmVkRGljdApxISmBcSJ9cSMoWAUAAABfZGljdHEkY2NvbGxlY3Rpb25zCk9yZGVyZWREaWN0CnElKVJxJlgPAAAAc3RydWN0dXJhbF90eXBlcSdjbnVtcHkKaW50NjQKcShzWAUAAABfaGFzaHEpTnViWAgAAABpc19lbXB0eXEqiVgRAAAAaXNfZWxlbWVudHNfZW1wdHlxK4h1hnEsYmgRaCEpgXEtfXEuKGgkaCUpUnEvKFgGAAAAc2NoZW1hcTBYQgAAAGh0dHBzOi8vbWV0YWRhdGEuZGF0YWRyaXZlbmRpc2NvdmVyeS5vcmcvc2NoZW1hcy92MC9jb250YWluZXIuanNvbnExaCdoAVgJAAAAZGltZW5zaW9ucTJoISmBcTN9cTQoaCRoJSlScTVYBgAAAGxlbmd0aHE2SwNzaClOdWJ1aClOdWJoKoloK4h1hnE3YmgpTnVidWIu'}})
  401. self.assertTrue(numpy.array_equal(TestHyperparams.values_from_json_structure(values.values_to_json_structure())['value'], values['value']))
  def test_set(self):
      """Test the ``hyperparams.Set`` hyper-parameter.

      Covers sampling (including an element hyper-parameter with a single
      possible value), defaults, ``get_max_samples``, ``to_simple_structure``,
      JSON round-tripping, rejection of invalid defaults, path lookups after
      ``contribute_to_class``, and an unbounded ``max_size`` (``None``).
      """
      # The element hyper-parameter has a single value, so the only distinct
      # sets are (1,) and (); sampling both is expected to log exactly one record.
      set_hyperparameter = hyperparams.Set(hyperparams.Hyperparameter[int](1), [])
      with self.assertLogs(hyperparams.logger) as cm:
          self.assertEqual(set(set_hyperparameter.sample_multiple(min_samples=2, max_samples=2)), {(1,), ()})
      self.assertEqual(len(cm.records), 1)

      # min_size == max_size == number of enumeration values (and duplicates are
      # rejected, see below), so every valid value holds all five elements and
      # get_max_samples() is 1. Sampled tuples may still list the elements in a
      # different order than the default.
      elements = hyperparams.Enumeration(['a', 'b', 1, 2, None], None)
      set_hyperparameter = hyperparams.Set(elements, ('a', 'b', 1, 2, None), 5, 5)
      self.assertEqual(set_hyperparameter.get_default(), ('a', 'b', 1, 2, None))
      self.assertEqual(set_hyperparameter.sample(45), ('b', None, 'a', 1, 2))
      self.assertEqual(set_hyperparameter.get_max_samples(), 1)
      self.assertEqual(set_hyperparameter.sample_multiple(1, 1, 42), (('b', None, 1, 'a', 2),))
      self.assertEqual(set_hyperparameter.sample_multiple(0, 1, 42), ())
      self.maxDiff = None
      self.assertEqual(set_hyperparameter.to_simple_structure(), {
          'default': ('a', 'b', 1, 2, None),
          'semantic_types': [],
          'structural_type': typing.Sequence[typing.Union[str, int, type(None)]],
          'type': hyperparams.Set,
          'min_size': 5,
          'max_size': 5,
          'elements': {
              'default': None,
              'semantic_types': [],
              'structural_type': typing.Union[str, int, type(None)],
              'type': hyperparams.Enumeration,
              'values': ['a', 'b', 1, 2, None],
          },
          'is_configuration': False,
      })
      # Tuples are serialized to JSON lists and must round-trip back unchanged.
      self.assertEqual(set_hyperparameter.value_to_json_structure(set_hyperparameter.get_default()), ['a', 'b', 1, 2, None])
      self.assertEqual(set_hyperparameter.value_to_json_structure(set_hyperparameter.sample(45)), ['b', None, 'a', 1, 2])
      self.assertEqual(set_hyperparameter.value_from_json_structure(set_hyperparameter.value_to_json_structure(set_hyperparameter.get_default())), set_hyperparameter.get_default())
      self.assertEqual(set_hyperparameter.value_from_json_structure(set_hyperparameter.value_to_json_structure(set_hyperparameter.sample(45))), set_hyperparameter.sample(45))

      # Invalid defaults are rejected at construction time: too few elements,
      # an element outside the enumeration, and duplicate elements.
      with self.assertRaisesRegex(ValueError, 'Value \'.*\' has less than 5 elements'):
          elements = hyperparams.Enumeration(['a', 'b', 1, 2, None], None)
          hyperparams.Set(elements, (), 5, 5)
      with self.assertRaisesRegex(ValueError, 'Value \'.*\' is not among values'):
          elements = hyperparams.Enumeration(['a', 'b', 1, 2, None], None)
          hyperparams.Set(elements, ('a', 'b', 1, 2, 3), 5, 5)
      with self.assertRaisesRegex(ValueError, 'Value \'.*\' has duplicate elements'):
          elements = hyperparams.Enumeration(['a', 'b', 1, 2, None], None)
          hyperparams.Set(elements, ('a', 'b', 1, 2, 2), 5, 5)

      # After naming the hyper-parameter 'foo', a path lookup of 'foo' through
      # get_default is expected to raise KeyError (presumably because a Set of an
      # Enumeration has no nested hyper-parameter of that name — confirm).
      set_hyperparameter.contribute_to_class('foo')
      with self.assertRaises(KeyError):
          set_hyperparameter.get_default('foo')

      # A Set over three metafeature values with sizes 0..3; the empty default is
      # valid and get_max_samples() == 8 == 2 ** 3, the number of subsets.
      list_of_supported_metafeatures = ['f1', 'f2', 'f3']
      metafeature = hyperparams.Enumeration(list_of_supported_metafeatures, list_of_supported_metafeatures[0], semantic_types=['https://metadata.datadrivendiscovery.org/types/MetafeatureParameter'])
      set_hyperparameter = hyperparams.Set(metafeature, (), 0, 3)
      self.assertEqual(set_hyperparameter.get_default(), ())
      self.assertEqual(set_hyperparameter.sample(42), ('f2', 'f3'))
      self.assertEqual(set_hyperparameter.get_max_samples(), 8)
      self.assertEqual(set_hyperparameter.sample_multiple(0, 3, 42), (('f2', 'f3', 'f1'), ('f2', 'f3')))
      self.assertEqual(set_hyperparameter.value_to_json_structure(set_hyperparameter.get_default()), [])
      self.assertEqual(set_hyperparameter.value_to_json_structure(set_hyperparameter.sample(42)), ['f2', 'f3'])
      self.assertEqual(set_hyperparameter.value_from_json_structure(set_hyperparameter.value_to_json_structure(set_hyperparameter.get_default())), set_hyperparameter.get_default())
      self.assertEqual(set_hyperparameter.value_from_json_structure(set_hyperparameter.value_to_json_structure(set_hyperparameter.sample(42))), set_hyperparameter.sample(42))

      # max_size=None is expected to behave exactly like max_size=3 here,
      # because there are only three possible element values.
      set_hyperparameter = hyperparams.Set(metafeature, (), 0, None)
      self.assertEqual(set_hyperparameter.get_default(), ())
      self.assertEqual(set_hyperparameter.sample(42), ('f2', 'f3'))
      self.assertEqual(set_hyperparameter.get_max_samples(), 8)
      self.assertEqual(set_hyperparameter.sample_multiple(0, 3, 42), (('f2', 'f3', 'f1'), ('f2', 'f3')))
  def test_set_with_hyperparams(self):
      """Test ``hyperparams.Set`` whose elements are ``Hyperparams`` instances.

      Also covers using such sets as attributes of another ``Hyperparams``
      class (together with a set of configurations containing a ``Choice``),
      hyper-parameter naming, dotted-path ``defaults`` lookups, ``traverse``,
      JSON round-tripping, constructing with overrides, and ``replace``.
      """
      # Each element of the set is a configuration of this Hyperparams class.
      elements = hyperparams.Hyperparams.define(OrderedDict(
          range=hyperparams.UniformInt(1, 10, 2),
          enum=hyperparams.Enumeration(['a', 'b', 1, 2, None], None),
      ))
      set_hyperparameter = hyperparams.Set(elements, (elements(range=2, enum='a'),), 0, 5)
      self.assertEqual(set_hyperparameter.get_default(), ({'range': 2, 'enum': 'a'},))
      self.assertEqual(set_hyperparameter.sample(45), ({'range': 4, 'enum': None}, {'range': 1, 'enum': 2}, {'range': 5, 'enum': 'b'}))
      self.assertEqual(set_hyperparameter.get_max_samples(), 1385980)
      self.assertEqual(set_hyperparameter.sample_multiple(1, 1, 42), (({'range': 8, 'enum': None}, {'range': 5, 'enum': 'b'}, {'range': 3, 'enum': 1}),))
      self.assertEqual(set_hyperparameter.sample_multiple(0, 1, 42), ())
      self.maxDiff = None
      # With Hyperparams elements, "elements" describes every nested
      # hyper-parameter and "is_configuration" is True.
      self.assertEqual(set_hyperparameter.to_simple_structure(), {
          'default': ({'range': 2, 'enum': 'a'},),
          'elements': {
              'enum': {
                  'default': None,
                  'semantic_types': [],
                  'structural_type': typing.Union[str, int, type(None)],
                  'type': hyperparams.Enumeration,
                  'values': ['a', 'b', 1, 2, None],
              },
              'range': {
                  'default': 2,
                  'lower': 1,
                  'semantic_types': [],
                  'structural_type': int,
                  'type': hyperparams.UniformInt,
                  'upper': 10,
                  'lower_inclusive': True,
                  'upper_inclusive': False,
              },
          },
          'is_configuration': True,
          'max_size': 5,
          'min_size': 0,
          'semantic_types': [],
          'structural_type': typing.Sequence[elements],
          'type': hyperparams.Set,
      })
      # Configurations serialize to lists of plain dicts and round-trip unchanged.
      self.assertEqual(set_hyperparameter.value_to_json_structure(set_hyperparameter.get_default()), [{'range': 2, 'enum': 'a'}])
      self.assertEqual(set_hyperparameter.value_to_json_structure(set_hyperparameter.sample(45)), [{'range': 4, 'enum': None}, {'range': 1, 'enum': 2}, {'range': 5, 'enum': 'b'}])
      self.assertEqual(set_hyperparameter.value_from_json_structure(set_hyperparameter.value_to_json_structure(set_hyperparameter.get_default())), set_hyperparameter.get_default())
      self.assertEqual(set_hyperparameter.value_from_json_structure(set_hyperparameter.value_to_json_structure(set_hyperparameter.sample(45))), set_hyperparameter.sample(45))

      # We have to explicitly disable setting names if we want to use it for "Set" hyper-parameter.
      class SetHyperparams(hyperparams.Hyperparams, set_names=False):
          choice = hyperparams.Choice({
              'none': hyperparams.Hyperparams,
              'range': hyperparams.Hyperparams.define(OrderedDict(
                  value=hyperparams.UniformInt(1, 10, 2),
              )),
          }, 'none')

      class TestHyperparams(hyperparams.Hyperparams):
          a = set_hyperparameter
          b = hyperparams.Set(SetHyperparams, (SetHyperparams({'choice': {'choice': 'none'}}),), 0, 3)

      self.assertEqual(TestHyperparams.to_simple_structure(), {
          'a': {
              'type': hyperparams.Set,
              'default': ({'range': 2, 'enum': 'a'},),
              'structural_type': typing.Sequence[elements],
              'semantic_types': [],
              'elements': {
                  'range': {
                      'type': hyperparams.UniformInt,
                      'default': 2,
                      'structural_type': int,
                      'semantic_types': [],
                      'lower': 1,
                      'upper': 10,
                      'lower_inclusive': True,
                      'upper_inclusive': False,
                  },
                  'enum': {
                      'type': hyperparams.Enumeration,
                      'default': None,
                      'structural_type': typing.Union[str, int, type(None)],
                      'semantic_types': [],
                      'values': ['a', 'b', 1, 2, None],
                  },
              },
              'is_configuration': True,
              'min_size': 0,
              'max_size': 5,
          },
          'b': {
              'type': hyperparams.Set,
              'default': ({'choice': {'choice': 'none'}},),
              'structural_type': typing.Sequence[SetHyperparams],
              'semantic_types': [],
              'elements': {
                  'choice': {
                      'type': hyperparams.Choice,
                      'default': {'choice': 'none'},
                      'structural_type': typing.Dict,
                      'semantic_types': [],
                      'choices': {
                          'none': {
                              'choice': {
                                  'type': hyperparams.Hyperparameter,
                                  'default': 'none',
                                  'structural_type': str,
                                  'semantic_types': ['https://metadata.datadrivendiscovery.org/types/ChoiceParameter'],
                              },
                          },
                          'range': {
                              'value': {
                                  'type': hyperparams.UniformInt,
                                  'default': 2,
                                  'structural_type': int,
                                  'semantic_types': [],
                                  'lower': 1,
                                  'upper': 10,
                                  'lower_inclusive': True,
                                  'upper_inclusive': False,
                              },
                              'choice': {
                                  'type': hyperparams.Hyperparameter,
                                  'default': 'range',
                                  'structural_type': str,
                                  'semantic_types': ['https://metadata.datadrivendiscovery.org/types/ChoiceParameter'],
                              },
                          },
                      },
                  },
              },
              'is_configuration': True,
              'min_size': 0,
              'max_size': 3,
          },
      })
      # Nested hyper-parameters are named by their dotted path within TestHyperparams.
      self.assertEqual(TestHyperparams.configuration['b'].elements.configuration['choice'].choices['range'].configuration['value'].name, 'b.choice.range.value')
      self.assertEqual(TestHyperparams.defaults(), {
          'a': ({'range': 2, 'enum': 'a'},),
          'b': ({'choice': {'choice': 'none'}},),
      })
      # Default values conform to the declared structural types.
      self.assertTrue(utils.is_instance(TestHyperparams.defaults()['a'], typing.Sequence[elements]))
      self.assertTrue(utils.is_instance(TestHyperparams.defaults()['b'], typing.Sequence[SetHyperparams]))

      # Each sampling of TestHyperparams is expected to log exactly one record.
      with self.assertLogs(hyperparams.logger) as cm:
          self.assertEqual(TestHyperparams.sample(42), {
              'a': ({'range': 8, 'enum': None}, {'range': 5, 'enum': 'b'}, {'range': 3, 'enum': 1}),
              'b': (
                  {
                      'choice': {'value': 5, 'choice': 'range'},
                  }, {
                      'choice': {'value': 8, 'choice': 'range'},
                  },
              ),
          })
      self.assertEqual(len(cm.records), 1)
      with self.assertLogs(hyperparams.logger) as cm:
          self.assertEqual(TestHyperparams.sample(42).values_to_json_structure(), {
              'a': [{'range': 8, 'enum': None}, {'range': 5, 'enum': 'b'}, {'range': 3, 'enum': 1}],
              'b': [
                  {
                      'choice': {'value': 5, 'choice': 'range'},
                  }, {
                      'choice': {'value': 8, 'choice': 'range'},
                  },
              ],
          })
      self.assertEqual(len(cm.records), 1)
      with self.assertLogs(hyperparams.logger) as cm:
          self.assertEqual(TestHyperparams.values_from_json_structure(TestHyperparams.sample(42).values_to_json_structure()), TestHyperparams.sample(42))
      self.assertEqual(len(cm.records), 1)

      self.assertEqual(len(list(TestHyperparams.traverse())), 8)

      # Dotted-path lookups of defaults, including inside Set elements and Choice alternatives.
      self.assertEqual(TestHyperparams.defaults('a'), ({'range': 2, 'enum': 'a'},))
      self.assertEqual(TestHyperparams.defaults('a.range'), 2)
      # Default of a whole "Set" hyper-parameter can be different than of nested hyper-parameters.
      self.assertEqual(TestHyperparams.defaults('a.enum'), None)
      self.assertEqual(TestHyperparams.defaults('b'), ({'choice': {'choice': 'none'}},))
      self.assertEqual(TestHyperparams.defaults('b.choice'), {'choice': 'none'})
      self.assertEqual(TestHyperparams.defaults('b.choice.none'), {'choice': 'none'})
      self.assertEqual(TestHyperparams.defaults('b.choice.none.choice'), 'none')
      self.assertEqual(TestHyperparams.defaults('b.choice.range'), {'choice': 'range', 'value': 2})
      self.assertEqual(TestHyperparams.defaults('b.choice.range.value'), 2)
      self.assertEqual(TestHyperparams.defaults('b.choice.range.choice'), 'range')

      # Constructing from defaults with a keyword override replaces only 'b'.
      self.assertEqual(TestHyperparams(TestHyperparams.defaults(), b=(
          SetHyperparams({
              'choice': {'value': 5, 'choice': 'range'},
          }),
          SetHyperparams({
              'choice': {'value': 8, 'choice': 'range'},
          }),
      )), {
          'a': ({'range': 2, 'enum': 'a'},),
          'b': (
              {
                  'choice': {'value': 5, 'choice': 'range'},
              },
              {
                  'choice': {'value': 8, 'choice': 'range'},
              },
          ),
      })
      # Same for 'a', passing the override through ** unpacking.
      self.assertEqual(TestHyperparams(TestHyperparams.defaults(), **{'a': (
          elements({'range': 8, 'enum': None}),
          elements({'range': 5, 'enum': 'b'}),
          elements({'range': 3, 'enum': 1}),
      )}), {
          'a': (
              {'range': 8, 'enum': None},
              {'range': 5, 'enum': 'b'},
              {'range': 3, 'enum': 1},
          ),
          'b': ({'choice': {'choice': 'none'}},)
      })
      # replace() is expected to give the same result as constructing with overrides.
      self.assertEqual(TestHyperparams.defaults().replace({'a': (
          elements({'range': 8, 'enum': None}),
          elements({'range': 5, 'enum': 'b'}),
          elements({'range': 3, 'enum': 1}),
      )}), {
          'a': (
              {'range': 8, 'enum': None},
              {'range': 5, 'enum': 'b'},
              {'range': 3, 'enum': 1},
          ),
          'b': ({'choice': {'choice': 'none'}},),
      })
  671. elements({'range': 5, 'enum': 'b'}),
  672. elements({'range': 3, 'enum': 1}),
  673. )}), {
  674. 'a': (
  675. {'range': 8, 'enum': None},
  676. {'range': 5, 'enum': 'b'},
  677. {'range': 3, 'enum': 1},
  678. ),
  679. 'b': ({'choice': {'choice': 'none'}},),
  680. })
  681. def test_choice(self):
  682. choices_hyperparameter = hyperparams.Choice({
  683. 'none': hyperparams.Hyperparams,
  684. 'range': hyperparams.Hyperparams.define(OrderedDict(
  685. # To test that we can use this name.
  686. configuration=hyperparams.UniformInt(1, 10, 2),
  687. )),
  688. }, 'none')
  689. # Class should not be changed directly (when adding "choice").
  690. self.assertEqual(hyperparams.Hyperparams.configuration, {})
  691. self.assertEqual(choices_hyperparameter.get_default(), {'choice': 'none'})
  692. with self.assertLogs(hyperparams.logger) as cm:
  693. self.assertEqual(choices_hyperparameter.sample(45), {'choice': 'range', 'configuration': 4})
  694. self.assertEqual(len(cm.records), 1)
  695. self.assertEqual(choices_hyperparameter.get_max_samples(), 10)
  696. with self.assertLogs(hyperparams.logger) as cm:
  697. self.assertEqual(choices_hyperparameter.sample_multiple(0, 3, 42), (frozendict.frozendict({'choice': 'range', 'configuration': 8}), frozendict.frozendict({'choice': 'none'})))
  698. self.assertEqual(len(cm.records), 1)
  699. self.maxDiff = None
  700. self.assertEqual(choices_hyperparameter.to_simple_structure(), {
  701. 'default': {'choice': 'none'},
  702. 'semantic_types': [],
  703. 'structural_type': typing.Dict,
  704. 'type': hyperparams.Choice,
  705. 'choices': {
  706. 'none': {
  707. 'choice': {
  708. 'default': 'none',
  709. 'semantic_types': ['https://metadata.datadrivendiscovery.org/types/ChoiceParameter'],
  710. 'structural_type': str,
  711. 'type': hyperparams.Hyperparameter,
  712. },
  713. },
  714. 'range': {
  715. 'choice': {
  716. 'default': 'range',
  717. 'semantic_types': ['https://metadata.datadrivendiscovery.org/types/ChoiceParameter'],
  718. 'structural_type': str,
  719. 'type': hyperparams.Hyperparameter,
  720. },
  721. 'configuration': {
  722. 'default': 2,
  723. 'lower': 1,
  724. 'lower_inclusive': True,
  725. 'upper': 10,
  726. 'upper_inclusive': False,
  727. 'semantic_types': [],
  728. 'structural_type': int,
  729. 'type': hyperparams.UniformInt,
  730. },
  731. },
  732. },
  733. })
  734. self.assertEqual(choices_hyperparameter.value_to_json_structure(choices_hyperparameter.get_default()), {'choice': 'none'})
  735. with self.assertLogs(hyperparams.logger) as cm:
  736. self.assertEqual(choices_hyperparameter.value_to_json_structure(choices_hyperparameter.sample(45)), {'configuration': 4, 'choice': 'range'})
  737. self.assertEqual(len(cm.records), 1)
  738. self.assertEqual(choices_hyperparameter.value_from_json_structure(choices_hyperparameter.value_to_json_structure(choices_hyperparameter.get_default())), choices_hyperparameter.get_default())
  739. with self.assertLogs(hyperparams.logger) as cm:
  740. self.assertEqual(choices_hyperparameter.value_from_json_structure(choices_hyperparameter.value_to_json_structure(choices_hyperparameter.sample(45))), choices_hyperparameter.sample(45))
  741. self.assertEqual(len(cm.records), 1)
  742. # We have to explicitly disable setting names if we want to use it for "Choice" hyper-parameter.
  743. class ChoicesHyperparams(hyperparams.Hyperparams, set_names=False):
  744. foo = hyperparams.UniformInt(5, 20, 10)
  745. class TestHyperparams(hyperparams.Hyperparams):
  746. a = choices_hyperparameter
  747. b = hyperparams.Choice({
  748. 'nochoice': ChoicesHyperparams,
  749. }, 'nochoice')
  750. self.assertEqual(TestHyperparams.configuration['a'].choices['range'].configuration['configuration'].name, 'a.range.configuration')
  751. self.assertEqual(TestHyperparams.defaults(), {'a': {'choice': 'none'}, 'b': {'choice': 'nochoice', 'foo': 10}})
  752. self.assertIsInstance(TestHyperparams.defaults()['a'], hyperparams.Hyperparams)
  753. self.assertIsInstance(TestHyperparams.defaults()['b'], ChoicesHyperparams)
  754. with self.assertLogs(hyperparams.logger) as cm:
  755. self.assertEqual(TestHyperparams.sample(42), {'a': {'choice': 'none'}, 'b': {'choice': 'nochoice', 'foo': 8}})
  756. self.assertEqual(len(cm.records), 1)
  757. with self.assertLogs(hyperparams.logger) as cm:
  758. self.assertEqual(TestHyperparams.sample(42).values_to_json_structure(), {'a': {'choice': 'none'}, 'b': {'choice': 'nochoice', 'foo': 8}})
  759. self.assertEqual(len(cm.records), 1)
  760. with self.assertLogs(hyperparams.logger) as cm:
  761. self.assertEqual(TestHyperparams.values_from_json_structure(TestHyperparams.sample(42).values_to_json_structure()), TestHyperparams.sample(42))
  762. self.assertEqual(len(cm.records), 1)
  763. self.assertEqual(len(list(TestHyperparams.traverse())), 7)
  764. self.assertEqual(TestHyperparams.defaults('a'), {'choice': 'none'})
  765. self.assertEqual(TestHyperparams.defaults('a.none'), {'choice': 'none'})
  766. self.assertEqual(TestHyperparams.defaults('a.none.choice'), 'none')
  767. self.assertEqual(TestHyperparams.defaults('a.range'), {'choice': 'range', 'configuration': 2})
  768. self.assertEqual(TestHyperparams.defaults('a.range.configuration'), 2)
  769. self.assertEqual(TestHyperparams.defaults('a.range.choice'), 'range')
  770. self.assertEqual(TestHyperparams.defaults('b'), {'choice': 'nochoice', 'foo': 10})
  771. self.assertEqual(TestHyperparams.defaults('b.nochoice'), {'choice': 'nochoice', 'foo': 10})
  772. self.assertEqual(TestHyperparams.defaults('b.nochoice.foo'), 10)
  773. self.assertEqual(TestHyperparams.defaults('b.nochoice.choice'), 'nochoice')
  774. def test_primitive(self):
  775. # To hide any logging or stdout output.
  776. with utils.silence():
  777. index.register_primitive('d3m.primitives.regression.monomial.Test', MonomialPrimitive)
  778. index.register_primitive('d3m.primitives.data_generation.random.Test', RandomPrimitive)
  779. index.register_primitive('d3m.primitives.operator.sum.Test', SumPrimitive)
  780. index.register_primitive('d3m.primitives.operator.increment.Test', IncrementPrimitive)
  781. hyperparameter = hyperparams.Primitive(MonomialPrimitive)
  782. self.assertEqual(hyperparameter.structural_type, MonomialPrimitive)
  783. self.assertEqual(hyperparameter.get_default(), MonomialPrimitive)
  784. # To hide any logging or stdout output.
  785. with utils.silence():
  786. self.assertEqual(hyperparameter.sample(42), MonomialPrimitive)
  787. hyperparams_class = MonomialPrimitive.metadata.get_hyperparams()
  788. primitive = MonomialPrimitive(hyperparams=hyperparams_class.defaults())
  789. hyperparameter = hyperparams.Enumeration([MonomialPrimitive, RandomPrimitive, SumPrimitive, IncrementPrimitive, None], None)
  790. self.assertEqual(hyperparameter.structural_type, typing.Union[MonomialPrimitive, RandomPrimitive, SumPrimitive, IncrementPrimitive, type(None)])
  791. self.assertEqual(hyperparameter.get_default(), None)
  792. self.assertEqual(hyperparameter.sample(42), IncrementPrimitive)
  793. hyperparameter = hyperparams.Enumeration[typing.Optional[base.PrimitiveBase]]([MonomialPrimitive, RandomPrimitive, SumPrimitive, IncrementPrimitive, None], None)
  794. self.assertEqual(hyperparameter.structural_type, typing.Optional[base.PrimitiveBase])
  795. self.assertEqual(hyperparameter.get_default(), None)
  796. self.assertEqual(hyperparameter.sample(42), IncrementPrimitive)
  797. set_hyperparameter = hyperparams.Set(hyperparameter, (MonomialPrimitive, RandomPrimitive), 2, 4)
  798. self.assertEqual(set_hyperparameter.get_default(), (MonomialPrimitive, RandomPrimitive))
  799. self.assertEqual(set_hyperparameter.sample(42), (RandomPrimitive, None, SumPrimitive, MonomialPrimitive))
  800. union_hyperparameter = hyperparams.Union(OrderedDict(
  801. none=hyperparams.Hyperparameter(None),
  802. primitive=hyperparams.Enumeration[base.PrimitiveBase]([MonomialPrimitive, RandomPrimitive, SumPrimitive, IncrementPrimitive], MonomialPrimitive),
  803. ), 'none')
  804. self.assertEqual(union_hyperparameter.get_default(), None)
  805. self.assertEqual(union_hyperparameter.sample(45), SumPrimitive)
  806. hyperparameter = hyperparams.Enumeration([primitive, RandomPrimitive, SumPrimitive, IncrementPrimitive, None], None)
  807. self.assertEqual(hyperparameter.structural_type, typing.Union[MonomialPrimitive, RandomPrimitive, SumPrimitive, IncrementPrimitive, type(None)])
  808. self.assertEqual(hyperparameter.get_default(), None)
  809. self.assertEqual(hyperparameter.sample(42), IncrementPrimitive)
  810. hyperparameter = hyperparams.Enumeration[typing.Optional[base.PrimitiveBase]]([primitive, RandomPrimitive, SumPrimitive, IncrementPrimitive, None], None)
  811. self.assertEqual(hyperparameter.structural_type, typing.Optional[base.PrimitiveBase])
  812. self.assertEqual(hyperparameter.get_default(), None)
  813. self.assertEqual(hyperparameter.sample(42), IncrementPrimitive)
  814. set_hyperparameter = hyperparams.Set(hyperparameter, (primitive, RandomPrimitive), 2, 4)
  815. self.assertEqual(set_hyperparameter.get_default(), (primitive, RandomPrimitive))
  816. self.assertEqual(set_hyperparameter.sample(42), (RandomPrimitive, None, SumPrimitive, primitive))
  817. union_hyperparameter = hyperparams.Union(OrderedDict(
  818. none=hyperparams.Hyperparameter(None),
  819. primitive=hyperparams.Enumeration[base.PrimitiveBase]([primitive, RandomPrimitive, SumPrimitive, IncrementPrimitive], primitive),
  820. ), 'none')
  821. self.assertEqual(union_hyperparameter.get_default(), None)
  822. self.assertEqual(union_hyperparameter.sample(45), SumPrimitive)
  823. hyperparameter = hyperparams.Primitive(primitive)
  824. self.assertEqual(hyperparameter.structural_type, MonomialPrimitive)
  825. self.assertEqual(hyperparameter.get_default(), primitive)
  826. # To hide any logging or stdout output.
  827. with utils.silence():
  828. self.assertEqual(hyperparameter.sample(42), primitive)
  829. hyperparameter = hyperparams.Primitive[base.PrimitiveBase](MonomialPrimitive)
  830. self.assertEqual(hyperparameter.get_default(), MonomialPrimitive)
  831. # To hide any logging or stdout output.
  832. with utils.silence():
  833. # There might be additional primitives available in the system,
  834. # so we cannot know which one will really be returned.
  835. self.assertTrue(hyperparameter.sample(42), hyperparameter.matching_primitives)
  836. self.maxDiff = None
  837. self.assertEqual(hyperparameter.to_simple_structure(), {
  838. 'default': MonomialPrimitive,
  839. 'semantic_types': [],
  840. 'structural_type': base.PrimitiveBase,
  841. 'type': hyperparams.Primitive,
  842. 'primitive_families': [],
  843. 'algorithm_types': [],
  844. 'produce_methods': [],
  845. })
  846. self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.get_default()), {'class': 'd3m.primitives.regression.monomial.Test'})
  847. self.assertEqual(hyperparameter.value_from_json_structure(hyperparameter.value_to_json_structure(hyperparameter.get_default())), hyperparameter.get_default())
  848. self.assertTrue(hyperparameter.get_max_samples() >= 4, hyperparameter.get_max_samples())
  849. hyperparameter = hyperparams.Primitive[base.PrimitiveBase](primitive)
  850. self.assertEqual(hyperparameter.get_default(), primitive)
  851. self.assertEqual(hyperparameter.to_simple_structure(), {
  852. 'default': primitive,
  853. 'semantic_types': [],
  854. 'structural_type': base.PrimitiveBase,
  855. 'type': hyperparams.Primitive,
  856. 'primitive_families': [],
  857. 'algorithm_types': [],
  858. 'produce_methods': [],
  859. })
  860. self.assertEqual(hyperparameter.value_to_json_structure(hyperparameter.get_default()), {'instance': 'gANjdGVzdF9wcmltaXRpdmVzLm1vbm9taWFsCk1vbm9taWFsUHJpbWl0aXZlCnEAKYFxAX1xAihYCwAAAGNvbnN0cnVjdG9ycQN9cQQoWAsAAABoeXBlcnBhcmFtc3EFY3Rlc3RfcHJpbWl0aXZlcy5tb25vbWlhbApIeXBlcnBhcmFtcwpxBimBcQd9cQhYBAAAAGJpYXNxCUcAAAAAAAAAAHNiWAsAAAByYW5kb21fc2VlZHEKSwB1WAYAAABwYXJhbXNxC2N0ZXN0X3ByaW1pdGl2ZXMubW9ub21pYWwKUGFyYW1zCnEMKYFxDVgBAAAAYXEOSwBzdWIu'})
  861. set_hyperparameter = hyperparams.Set(hyperparameter, (MonomialPrimitive, RandomPrimitive), 2, 4)
  862. self.assertEqual(set_hyperparameter.get_default(), (MonomialPrimitive, RandomPrimitive))
  863. union_hyperparameter = hyperparams.Union(OrderedDict(
  864. none=hyperparams.Hyperparameter(None),
  865. primitive=hyperparameter,
  866. ), 'none')
  867. self.assertEqual(union_hyperparameter.get_default(), None)
  868. def test_invalid_name(self):
  869. with self.assertRaisesRegex(ValueError, 'Hyper-parameter name \'.*\' contains invalid characters.'):
  870. hyperparams.Hyperparams.define({
  871. 'foo.bar': hyperparams.Uniform(1.0, 10.0, 2.0),
  872. })
  873. def test_class_as_default(self):
  874. class Foo:
  875. pass
  876. foo = Foo()
  877. hyperparameter = hyperparams.Enumeration(['a', 'b', 1, 2, foo], foo)
  878. self.assertEqual(hyperparameter.value_to_json_structure(1), {'encoding': 'pickle', 'value': 'gANLAS4='})
  879. hyperparameter = hyperparams.Enumeration(['a', 'b', 1, 2], 2)
  880. self.assertEqual(hyperparameter.value_to_json_structure(1), 1)
  881. def test_configuration_immutability(self):
  882. class TestHyperparams(hyperparams.Hyperparams):
  883. a = hyperparams.Union(OrderedDict(
  884. range=hyperparams.UniformInt(1, 10, 2),
  885. none=hyperparams.Hyperparameter(None),
  886. ), 'range')
  887. b = hyperparams.Uniform(1.0, 10.0, 2.0)
  888. with self.assertRaisesRegex(TypeError, '\'FrozenOrderedDict\' object does not support item assignment'):
  889. TestHyperparams.configuration['c'] = hyperparams.Enumeration(['a', 'b', 1, 2, None], None)
  890. with self.assertRaisesRegex(AttributeError, 'Hyper-parameters configuration is immutable'):
  891. TestHyperparams.configuration = OrderedDict(
  892. range=hyperparams.UniformInt(1, 10, 2),
  893. none=hyperparams.Hyperparameter(None),
  894. )
  895. def test_dict_as_default(self):
  896. Inputs = container.DataFrame
  897. Outputs = container.DataFrame
  898. class Hyperparams(hyperparams.Hyperparams):
  899. value = hyperparams.Hyperparameter({}, semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'])
  900. # Silence any validation warnings.
  901. with utils.silence():
  902. class Primitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
  903. metadata = metadata_base.PrimitiveMetadata({
  904. 'id': '152ea984-d8a4-4a37-87a0-29829b082e54',
  905. 'version': '0.1.0',
  906. 'name': "Test Primitive",
  907. 'python_path': 'd3m.primitives.test.dict_as_default',
  908. 'algorithm_types': [
  909. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  910. ],
  911. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  912. })
  913. def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
  914. pass
  915. self.assertEqual(Primitive.metadata.query()['primitive_code']['hyperparams']['value']['default'], {})
  916. def test_comma_warning(self):
  917. logger = logging.getLogger('d3m.metadata.hyperparams')
  918. with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
  919. class Hyperparams(hyperparams.Hyperparams):
  920. value = hyperparams.Hyperparameter({}, semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter']),
  921. self.assertEqual(len(cm.records), 1)
  922. self.assertEqual(cm.records[0].message, 'Probably invalid definition of a hyper-parameter. Hyper-parameter should be defined as class attribute without a trailing comma.')
  923. def test_json_schema(self):
  924. Inputs = container.DataFrame
  925. Outputs = container.DataFrame
  926. # Silence any validation warnings.
  927. with utils.silence():
  928. # Defining primitive triggers checking against JSON schema.
  929. class TestJsonPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, TestPicklingHyperparams]):
  930. metadata = metadata_base.PrimitiveMetadata({
  931. 'id': 'cdfada09-5161-4f2e-bc7f-223d843d59c1',
  932. 'version': '0.1.0',
  933. 'name': "Test Primitive",
  934. 'python_path': 'd3m.primitives.test.json_schema',
  935. 'algorithm_types': [
  936. metadata_base.PrimitiveAlgorithmType.PRINCIPAL_COMPONENT_ANALYSIS,
  937. ],
  938. 'primitive_family': metadata_base.PrimitiveFamily.FEATURE_SELECTION,
  939. })
  940. def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
  941. pass
  942. def test_pickling(self):
  943. pickle.loads(pickle.dumps(TestPicklingHyperparams))
  944. unpickled = pickle.loads(pickle.dumps(TestPicklingHyperparams.defaults()))
  945. self.assertEqual(unpickled['choice'].configuration['value'].structural_type, typing.Union[float, int])
  946. def test_sorted_set(self):
  947. set_hyperparameter = hyperparams.SortedSet(hyperparams.Hyperparameter[int](1), [])
  948. with self.assertLogs(hyperparams.logger) as cm:
  949. self.assertEqual(set(set_hyperparameter.sample_multiple(min_samples=2, max_samples=2)), {(1,), ()})
  950. self.assertEqual(len(cm.records), 1)
  951. elements = hyperparams.Enumeration(['a', 'b', 'c', 'd', 'e'], 'e')
  952. set_hyperparameter = hyperparams.SortedSet(elements, ('a', 'b', 'c', 'd', 'e'), 5, 5)
  953. self.assertEqual(set_hyperparameter.get_default(), ('a', 'b', 'c', 'd', 'e'))
  954. self.assertEqual(set_hyperparameter.sample(45), ('a', 'b', 'c', 'd', 'e'))
  955. self.assertEqual(set_hyperparameter.get_max_samples(), 1)
  956. self.assertEqual(set_hyperparameter.sample_multiple(1, 1, 42), (('a', 'b', 'c', 'd', 'e'),))
  957. self.assertEqual(set_hyperparameter.sample_multiple(0, 1, 42), ())
  958. self.maxDiff = None
  959. self.assertEqual(set_hyperparameter.to_simple_structure(), {
  960. 'default': ('a', 'b', 'c', 'd', 'e'),
  961. 'semantic_types': [],
  962. 'structural_type': typing.Sequence[str],
  963. 'type': hyperparams.SortedSet,
  964. 'min_size': 5,
  965. 'max_size': 5,
  966. 'elements': {
  967. 'default': 'e',
  968. 'semantic_types': [],
  969. 'structural_type': str,
  970. 'type': hyperparams.Enumeration,
  971. 'values': ['a', 'b', 'c', 'd', 'e'],
  972. },
  973. 'ascending': True,
  974. })
  975. self.assertEqual(set_hyperparameter.value_to_json_structure(set_hyperparameter.get_default()), ['a', 'b', 'c', 'd', 'e'])
  976. self.assertEqual(set_hyperparameter.value_to_json_structure(set_hyperparameter.sample(45)), ['a', 'b', 'c', 'd', 'e'])
  977. self.assertEqual(set_hyperparameter.value_from_json_structure(set_hyperparameter.value_to_json_structure(set_hyperparameter.get_default())), set_hyperparameter.get_default())
  978. self.assertEqual(set_hyperparameter.value_from_json_structure(set_hyperparameter.value_to_json_structure(set_hyperparameter.sample(45))), set_hyperparameter.sample(45))
  979. with self.assertRaisesRegex(ValueError, 'Value \'.*\' has less than 5 elements'):
  980. elements = hyperparams.Enumeration(['a', 'b', 'c', 'd', 'e'], 'e')
  981. hyperparams.SortedSet(elements, (), 5, 5)
  982. with self.assertRaisesRegex(ValueError, 'Value \'.*\' is not among values'):
  983. elements = hyperparams.Enumeration(['a', 'b', 'c', 'd', 'e'], 'e')
  984. hyperparams.SortedSet(elements, ('a', 'b', 'c', 'd', 'f'), 5, 5)
  985. with self.assertRaisesRegex(ValueError, 'Value \'.*\' has duplicate elements'):
  986. elements = hyperparams.Enumeration(['a', 'b', 'c', 'd', 'e'], 'e')
  987. hyperparams.SortedSet(elements, ('a', 'b', 'c', 'd', 'd'), 5, 5)
  988. set_hyperparameter.contribute_to_class('foo')
  989. with self.assertRaises(KeyError):
  990. set_hyperparameter.get_default('foo')
  991. list_of_supported_metafeatures = ['f1', 'f2', 'f3']
  992. metafeature = hyperparams.Enumeration(list_of_supported_metafeatures, list_of_supported_metafeatures[0], semantic_types=['https://metadata.datadrivendiscovery.org/types/MetafeatureParameter'])
  993. set_hyperparameter = hyperparams.SortedSet(metafeature, (), 0, 3)
  994. self.assertEqual(set_hyperparameter.get_default(), ())
  995. self.assertEqual(set_hyperparameter.sample(42), ('f2', 'f3'))
  996. self.assertEqual(set_hyperparameter.get_max_samples(), 8)
  997. self.assertEqual(set_hyperparameter.sample_multiple(0, 3, 42), (('f1', 'f2', 'f3'), ('f2', 'f3')))
  998. self.assertEqual(set_hyperparameter.value_to_json_structure(set_hyperparameter.get_default()), [])
  999. self.assertEqual(set_hyperparameter.value_to_json_structure(set_hyperparameter.sample(42)), ['f2', 'f3'])
  1000. self.assertEqual(set_hyperparameter.value_from_json_structure(set_hyperparameter.value_to_json_structure(set_hyperparameter.get_default())), set_hyperparameter.get_default())
  1001. self.assertEqual(set_hyperparameter.value_from_json_structure(set_hyperparameter.value_to_json_structure(set_hyperparameter.sample(42))), set_hyperparameter.sample(42))
  1002. set_hyperparameter = hyperparams.SortedSet(metafeature, (), 0, None)
  1003. self.assertEqual(set_hyperparameter.get_default(), ())
  1004. self.assertEqual(set_hyperparameter.sample(42), ('f2', 'f3'))
  1005. self.assertEqual(set_hyperparameter.get_max_samples(), 8)
  1006. self.assertEqual(set_hyperparameter.sample_multiple(0, 3, 42), (('f1', 'f2', 'f3'), ('f2', 'f3')))
  1007. set_hyperparameter = hyperparams.SortedSet(hyperparams.Hyperparameter[int](0), (0, 1), min_size=2, max_size=2)
  1008. with self.assertLogs(hyperparams.logger) as cm:
  1009. self.assertEqual(set_hyperparameter.sample_multiple(1, 1, 42), ((0, 1),))
  1010. self.assertEqual(len(cm.records), 1)
  1011. with self.assertLogs(hyperparams.logger) as cm:
  1012. self.assertEqual(set_hyperparameter.sample(42), (0, 1))
  1013. self.assertEqual(len(cm.records), 1)
  1014. set_hyperparameter = hyperparams.SortedSet(hyperparams.Hyperparameter[int](0), (0,), min_size=1, max_size=1)
  1015. with self.assertLogs(hyperparams.logger) as cm:
  1016. set_hyperparameter.sample(42)
  1017. self.assertEqual(len(cm.records), 1)
  1018. set_hyperparameter = hyperparams.SortedSet(hyperparams.Uniform(0.0, 100.0, 50.0, lower_inclusive=False, upper_inclusive=False), (25.0, 75.0), min_size=2, max_size=2)
  1019. self.assertEqual(set_hyperparameter.sample(42), (37.454011884736246, 95.07143064099162))
  def test_sorted_set_with_hyperparams(self):
      """``SortedSet`` must reject a ``Hyperparams`` class as its element type."""
      # Assemble the element class field by field; insertion order matches the
      # original keyword order (range first, then enum).
      fields = OrderedDict()
      fields['range'] = hyperparams.UniformInt(1, 10, 2)
      fields['enum'] = hyperparams.Enumeration(['a', 'b', 'c', 'd', 'e'], 'e')
      element_class = hyperparams.Hyperparams.define(fields)

      with self.assertRaises(exceptions.NotSupportedError):
          # Both the default value and the SortedSet are built inside the
          # context so any NotSupportedError raised along the way is caught.
          default_value = (element_class(range=2, enum='a'),)
          hyperparams.SortedSet(element_class, default_value, 0, 5)
  def test_list(self):
      """Exercise the ``List`` hyper-parameter end to end.

      Covers defaults, seeded sampling, sample-space sizes, the simple
      structure description, JSON round-trips and validation errors.
      """

      def check_json_round_trip(hp, produce_value):
          # Encode one produced value, decode it, and compare with a second
          # produced value (the producers used below are seeded/deterministic).
          encoded = hp.value_to_json_structure(produce_value())
          self.assertEqual(hp.value_from_json_structure(encoded), produce_value())

      # Element default is 1 and the list holds at most one item; the test
      # expects the two lists (1,) and () and exactly one log record.
      list_hp = hyperparams.List(hyperparams.Hyperparameter[int](1), [], 0, 1)
      with self.assertLogs(hyperparams.logger) as log_capture:
          drawn = list_hp.sample_multiple(min_samples=2, max_samples=2)
          self.assertEqual(set(drawn), {(1,), ()})
      self.assertEqual(len(log_capture.records), 1)

      # Fixed-size (exactly five items) list over a heterogeneous enumeration.
      mixed_values = ('a', 'b', 1, 2, None)
      list_hp = hyperparams.List(hyperparams.Enumeration(list(mixed_values), None), mixed_values, 5, 5)
      self.assertEqual(list_hp.get_default(), mixed_values)
      self.assertEqual(list_hp.sample(45), (2, 2, None, 'a', 2))
      # 5 ** 5: five positions, each one of five values.
      self.assertEqual(list_hp.get_max_samples(), 3125)
      self.assertEqual(list_hp.sample_multiple(1, 1, 42), ((2, None, 1, None, None),))
      self.assertEqual(list_hp.sample_multiple(0, 1, 42), ())

      self.maxDiff = None
      element_type = typing.Union[str, int, type(None)]
      expected_structure = {
          'default': mixed_values,
          'semantic_types': [],
          'structural_type': typing.Sequence[element_type],
          'type': hyperparams.List,
          'min_size': 5,
          'max_size': 5,
          'elements': {
              'default': None,
              'semantic_types': [],
              'structural_type': element_type,
              'type': hyperparams.Enumeration,
              'values': list(mixed_values),
          },
          'is_configuration': False,
      }
      self.assertEqual(list_hp.to_simple_structure(), expected_structure)

      self.assertEqual(list_hp.value_to_json_structure(list_hp.get_default()), list(mixed_values))
      self.assertEqual(list_hp.value_to_json_structure(list_hp.sample(45)), [2, 2, None, 'a', 2])
      check_json_round_trip(list_hp, list_hp.get_default)
      check_json_round_trip(list_hp, lambda: list_hp.sample(45))

      # Invalid defaults: too few items, and an item outside the enumeration.
      with self.assertRaisesRegex(ValueError, 'Value \'.*\' has less than 5 elements'):
          enum_hp = hyperparams.Enumeration(['a', 'b', 1, 2, None], None)
          hyperparams.List(enum_hp, (), 5, 5)
      with self.assertRaisesRegex(ValueError, 'Value \'.*\' is not among values'):
          enum_hp = hyperparams.Enumeration(['a', 'b', 1, 2, None], None)
          hyperparams.List(enum_hp, ('a', 'b', 1, 2, 3), 5, 5)

      # After contribute_to_class('foo'), the test expects get_default('foo')
      # to raise KeyError.
      list_hp.contribute_to_class('foo')
      with self.assertRaises(KeyError):
          list_hp.get_default('foo')

      metafeature_names = ['f1', 'f2', 'f3']
      metafeature = hyperparams.Enumeration(
          metafeature_names,
          metafeature_names[0],
          semantic_types=['https://metadata.datadrivendiscovery.org/types/MetafeatureParameter'],
      )

      list_hp = hyperparams.List(metafeature, (), 0, 3)
      self.assertEqual(list_hp.get_default(), ())
      self.assertEqual(list_hp.sample(42), ('f1', 'f3'))
      # 3**0 + 3**1 + 3**2 + 3**3 ordered lists of length 0..3.
      self.assertEqual(list_hp.get_max_samples(), 40)
      self.assertEqual(list_hp.sample_multiple(0, 3, 42), (('f1', 'f3', 'f3'), ('f1', 'f1', 'f3')))
      self.assertEqual(list_hp.value_to_json_structure(list_hp.get_default()), [])
      self.assertEqual(list_hp.value_to_json_structure(list_hp.sample(42)), ['f1', 'f3'])
      check_json_round_trip(list_hp, list_hp.get_default)
      check_json_round_trip(list_hp, lambda: list_hp.sample(42))

      list_hp = hyperparams.List(metafeature, (), 0, 10)
      self.assertEqual(list_hp.get_default(), ())
      self.assertEqual(list_hp.sample(42), ('f1', 'f3', 'f3', 'f1', 'f1', 'f3'))
      # Sum of 3**k for k = 0..10.
      self.assertEqual(list_hp.get_max_samples(), 88573)
      self.assertEqual(list_hp.sample_multiple(0, 3, 42), (('f1', 'f3', 'f3'), ('f1', 'f1', 'f3', 'f2', 'f3', 'f3', 'f3')))

      # Unbounded-above element and unbounded list size; the test expects one
      # log record while sampling.
      list_hp = hyperparams.List(hyperparams.Bounded(1, None, 100), (100,), min_size=1, max_size=None)
      with self.assertLogs(hyperparams.logger) as log_capture:
          self.assertEqual(list_hp.sample(42), (100,))
      self.assertEqual(len(log_capture.records), 1)
  def test_list_with_hyperparams(self):
      """Check a ``List`` whose elements are a ``Hyperparams`` class.

      Covers defaults, seeded sampling, the simple structure description,
      JSON round-trips, and nesting such lists (including a ``Choice``-based
      element class) inside another ``Hyperparams`` class: naming, defaults
      lookup by dotted path, sampling, and value replacement.
      """
      # Element type: a Hyperparams class with two fields, "range" and "enum".
      elements = hyperparams.Hyperparams.define(OrderedDict(
          range=hyperparams.UniformInt(1, 10, 2),
          enum=hyperparams.Enumeration(['a', 'b', 1, 2, None], None),
      ))
      # List of 0 to 5 such elements; the default holds a single element.
      list_hyperparameter = hyperparams.List(elements, (elements(range=2, enum='a'),), 0, 5)
      self.assertEqual(list_hyperparameter.get_default(), ({'range': 2, 'enum': 'a'},))
      self.assertEqual(list_hyperparameter.sample(45), ({'range': 4, 'enum': None}, {'range': 1, 'enum': 2}, {'range': 5, 'enum': 'b'}))
      # Sum of 45**k for k = 0..5, where 45 = 9 "range" values (1..9) * 5 "enum" values.
      self.assertEqual(list_hyperparameter.get_max_samples(), 188721946)
      self.assertEqual(list_hyperparameter.sample_multiple(1, 1, 42), (({'range': 8, 'enum': None}, {'range': 5, 'enum': 'b'}, {'range': 3, 'enum': 1}),))
      self.assertEqual(list_hyperparameter.sample_multiple(0, 1, 42), ())
      self.maxDiff = None
      self.assertEqual(list_hyperparameter.to_simple_structure(), {
          'default': ({'range': 2, 'enum': 'a'},),
          'elements': {
              'enum': {
                  'default': None,
                  'semantic_types': [],
                  'structural_type': typing.Union[str, int, type(None)],
                  'type': hyperparams.Enumeration,
                  'values': ['a', 'b', 1, 2, None],
              },
              'range': {
                  'default': 2,
                  'lower': 1,
                  'semantic_types': [],
                  'structural_type': int,
                  'type': hyperparams.UniformInt,
                  'upper': 10,
                  'lower_inclusive': True,
                  'upper_inclusive': False,
              },
          },
          'is_configuration': True,
          'max_size': 5,
          'min_size': 0,
          'semantic_types': [],
          'structural_type': typing.Sequence[elements],
          'type': hyperparams.List,
      })
      # JSON encoding of the default and of a seeded sample, then round-trips.
      self.assertEqual(list_hyperparameter.value_to_json_structure(list_hyperparameter.get_default()), [{'range': 2, 'enum': 'a'}])
      self.assertEqual(list_hyperparameter.value_to_json_structure(list_hyperparameter.sample(45)), [{'range': 4, 'enum': None}, {'range': 1, 'enum': 2}, {'range': 5, 'enum': 'b'}])
      self.assertEqual(list_hyperparameter.value_from_json_structure(list_hyperparameter.value_to_json_structure(list_hyperparameter.get_default())), list_hyperparameter.get_default())
      self.assertEqual(list_hyperparameter.value_from_json_structure(list_hyperparameter.value_to_json_structure(list_hyperparameter.sample(45))), list_hyperparameter.sample(45))

      # We have to explicitly disable setting names if we want to use it for "List" hyper-parameter.
      class ListHyperparams(hyperparams.Hyperparams, set_names=False):
          # A choice between an empty configuration ("none", the default) and
          # one with a single integer "value".
          choice = hyperparams.Choice({
              'none': hyperparams.Hyperparams,
              'range': hyperparams.Hyperparams.define(OrderedDict(
                  value=hyperparams.UniformInt(1, 10, 2),
              )),
          }, 'none')

      # Outer configuration: "a" reuses the list defined above; "b" is a list
      # of 0 to 3 ListHyperparams whose default holds one "none" choice.
      class TestHyperparams(hyperparams.Hyperparams):
          a = list_hyperparameter
          b = hyperparams.List(ListHyperparams, (ListHyperparams({'choice': {'choice': 'none'}}),), 0, 3)

      self.assertEqual(TestHyperparams.to_simple_structure(), {
          'a': {
              'type': hyperparams.List,
              'default': ({'range': 2, 'enum': 'a'},),
              'structural_type': typing.Sequence[elements],
              'semantic_types': [],
              'elements': {
                  'range': {
                      'type': hyperparams.UniformInt,
                      'default': 2,
                      'structural_type': int,
                      'semantic_types': [],
                      'lower': 1,
                      'upper': 10,
                      'lower_inclusive': True,
                      'upper_inclusive': False,
                  },
                  'enum': {
                      'type': hyperparams.Enumeration,
                      'default': None,
                      'structural_type': typing.Union[str, int, type(None)],
                      'semantic_types': [],
                      'values': ['a', 'b', 1, 2, None],
                  },
              },
              'is_configuration': True,
              'min_size': 0,
              'max_size': 5,
          },
          'b': {
              'type': hyperparams.List,
              'default': ({'choice': {'choice': 'none'}},),
              'structural_type': typing.Sequence[ListHyperparams],
              'semantic_types': [],
              'elements': {
                  'choice': {
                      'type': hyperparams.Choice,
                      'default': {'choice': 'none'},
                      'structural_type': typing.Dict,
                      'semantic_types': [],
                      'choices': {
                          'none': {
                              'choice': {
                                  'type': hyperparams.Hyperparameter,
                                  'default': 'none',
                                  'structural_type': str,
                                  'semantic_types': ['https://metadata.datadrivendiscovery.org/types/ChoiceParameter'],
                              },
                          },
                          'range': {
                              'value': {
                                  'type': hyperparams.UniformInt,
                                  'default': 2,
                                  'structural_type': int,
                                  'semantic_types': [],
                                  'lower': 1,
                                  'upper': 10,
                                  'lower_inclusive': True,
                                  'upper_inclusive': False,
                              },
                              'choice': {
                                  'type': hyperparams.Hyperparameter,
                                  'default': 'range',
                                  'structural_type': str,
                                  'semantic_types': ['https://metadata.datadrivendiscovery.org/types/ChoiceParameter'],
                              },
                          },
                      },
                  },
              },
              'is_configuration': True,
              'min_size': 0,
              'max_size': 3,
          },
      })
      # The deeply nested "value" hyper-parameter is expected to be named by
      # its dotted path from the outer class.
      self.assertEqual(TestHyperparams.configuration['b'].elements.configuration['choice'].choices['range'].configuration['value'].name, 'b.choice.range.value')
      self.assertEqual(TestHyperparams.defaults(), {
          'a': ({'range': 2, 'enum': 'a'},),
          'b': ({'choice': {'choice': 'none'}},),
      })
      self.assertTrue(utils.is_instance(TestHyperparams.defaults()['a'], typing.Sequence[elements]))
      self.assertTrue(utils.is_instance(TestHyperparams.defaults()['b'], typing.Sequence[ListHyperparams]))
      # Each seeded sampling of the whole configuration is expected to emit
      # exactly one log record.
      with self.assertLogs(hyperparams.logger) as cm:
          self.assertEqual(TestHyperparams.sample(42), {
              'a': ({'range': 8, 'enum': None}, {'range': 5, 'enum': 'b'}, {'range': 3, 'enum': 1}),
              'b': (
                  {
                      'choice': {'value': 5, 'choice': 'range'},
                  }, {
                      'choice': {'value': 8, 'choice': 'range'},
                  },
              ),
          })
      self.assertEqual(len(cm.records), 1)
      with self.assertLogs(hyperparams.logger) as cm:
          self.assertEqual(TestHyperparams.sample(42).values_to_json_structure(), {
              'a': [{'range': 8, 'enum': None}, {'range': 5, 'enum': 'b'}, {'range': 3, 'enum': 1}],
              'b': [
                  {
                      'choice': {'value': 5, 'choice': 'range'},
                  }, {
                      'choice': {'value': 8, 'choice': 'range'},
                  },
              ],
          })
      self.assertEqual(len(cm.records), 1)
      with self.assertLogs(hyperparams.logger) as cm:
          self.assertEqual(TestHyperparams.values_from_json_structure(TestHyperparams.sample(42).values_to_json_structure()), TestHyperparams.sample(42))
      self.assertEqual(len(cm.records), 1)
      self.assertEqual(len(list(TestHyperparams.traverse())), 8)
      # Defaults can be looked up by dotted path into nested hyper-parameters.
      self.assertEqual(TestHyperparams.defaults('a'), ({'range': 2, 'enum': 'a'},))
      self.assertEqual(TestHyperparams.defaults('a.range'), 2)
      # The default of a whole "List" hyper-parameter can differ from the defaults of its nested hyper-parameters.
      self.assertEqual(TestHyperparams.defaults('a.enum'), None)
      self.assertEqual(TestHyperparams.defaults('b'), ({'choice': {'choice': 'none'}},))
      self.assertEqual(TestHyperparams.defaults('b.choice'), {'choice': 'none'})
      self.assertEqual(TestHyperparams.defaults('b.choice.none'), {'choice': 'none'})
      self.assertEqual(TestHyperparams.defaults('b.choice.none.choice'), 'none')
      self.assertEqual(TestHyperparams.defaults('b.choice.range'), {'choice': 'range', 'value': 2})
      self.assertEqual(TestHyperparams.defaults('b.choice.range.value'), 2)
      self.assertEqual(TestHyperparams.defaults('b.choice.range.choice'), 'range')
      # Override "b" via keyword argument while keeping the other defaults.
      self.assertEqual(TestHyperparams(TestHyperparams.defaults(), b=(
          ListHyperparams({
              'choice': {'value': 5, 'choice': 'range'},
          }),
          ListHyperparams({
              'choice': {'value': 8, 'choice': 'range'},
          }),
      )), {
          'a': ({'range': 2, 'enum': 'a'},),
          'b': (
              {
                  'choice': {'value': 5, 'choice': 'range'},
              },
              {
                  'choice': {'value': 8, 'choice': 'range'},
              },
          ),
      })
      # Override "a" via an unpacked keyword dict.
      self.assertEqual(TestHyperparams(TestHyperparams.defaults(), **{'a': (
          elements({'range': 8, 'enum': None}),
          elements({'range': 5, 'enum': 'b'}),
          elements({'range': 3, 'enum': 1}),
      )}), {
          'a': (
              {'range': 8, 'enum': None},
              {'range': 5, 'enum': 'b'},
              {'range': 3, 'enum': 1},
          ),
          'b': ({'choice': {'choice': 'none'}},)
      })
      # The same override expressed with "replace" on the defaults.
      self.assertEqual(TestHyperparams.defaults().replace({'a': (
          elements({'range': 8, 'enum': None}),
          elements({'range': 5, 'enum': 'b'}),
          elements({'range': 3, 'enum': 1}),
      )}), {
          'a': (
              {'range': 8, 'enum': None},
              {'range': 5, 'enum': 'b'},
              {'range': 3, 'enum': 1},
          ),
          'b': ({'choice': {'choice': 'none'}},),
      })
  def test_sorted_list(self):
      """Exercise the ``SortedList`` hyper-parameter end to end.

      Covers defaults, seeded sampling, sample-space sizes, the simple
      structure description, JSON round-trips and validation errors.
      """

      def check_json_round_trip(hp, produce_value):
          # Encode one produced value, decode it, and compare with a second
          # produced value (the producers used below are seeded/deterministic).
          encoded = hp.value_to_json_structure(produce_value())
          self.assertEqual(hp.value_from_json_structure(encoded), produce_value())

      # Element default is 1 and the list holds at most one item; the test
      # expects the two lists (1,) and () and exactly one log record.
      list_hp = hyperparams.SortedList(hyperparams.Hyperparameter[int](1), [], 0, 1)
      with self.assertLogs(hyperparams.logger) as log_capture:
          drawn = list_hp.sample_multiple(min_samples=2, max_samples=2)
          self.assertEqual(set(drawn), {(1,), ()})
      self.assertEqual(len(log_capture.records), 1)

      # Fixed-size (exactly five items) sorted list over five letters.
      letters = ('a', 'b', 'c', 'd', 'e')
      list_hp = hyperparams.SortedList(hyperparams.Enumeration(list(letters), 'e'), letters, 5, 5)
      self.assertEqual(list_hp.get_default(), letters)
      self.assertEqual(list_hp.sample(45), ('a', 'd', 'd', 'd', 'e'))
      # C(9, 5): size-5 multisets drawn from 5 values.
      self.assertEqual(list_hp.get_max_samples(), 126)
      self.assertEqual(list_hp.sample_multiple(1, 1, 42), (('c', 'd', 'e', 'e', 'e'),))
      self.assertEqual(list_hp.sample_multiple(0, 1, 42), ())

      self.maxDiff = None
      expected_structure = {
          'default': letters,
          'semantic_types': [],
          'structural_type': typing.Sequence[str],
          'type': hyperparams.SortedList,
          'min_size': 5,
          'max_size': 5,
          'elements': {
              'default': 'e',
              'semantic_types': [],
              'structural_type': str,
              'type': hyperparams.Enumeration,
              'values': list(letters),
          },
          'ascending': True,
      }
      self.assertEqual(list_hp.to_simple_structure(), expected_structure)

      self.assertEqual(list_hp.value_to_json_structure(list_hp.get_default()), list(letters))
      self.assertEqual(list_hp.value_to_json_structure(list_hp.sample(45)), ['a', 'd', 'd', 'd', 'e'])
      check_json_round_trip(list_hp, list_hp.get_default)
      check_json_round_trip(list_hp, lambda: list_hp.sample(45))

      # Invalid defaults: too few items, and an item outside the enumeration.
      with self.assertRaisesRegex(ValueError, 'Value \'.*\' has less than 5 elements'):
          enum_hp = hyperparams.Enumeration(['a', 'b', 1, 2, None], None)
          hyperparams.SortedList(enum_hp, (), 5, 5)
      with self.assertRaisesRegex(ValueError, 'Value \'.*\' is not among values'):
          enum_hp = hyperparams.Enumeration(['a', 'b', 1, 2, None], None)
          hyperparams.SortedList(enum_hp, ('a', 'b', 1, 2, 3), 5, 5)

      # After contribute_to_class('foo'), the test expects get_default('foo')
      # to raise KeyError.
      list_hp.contribute_to_class('foo')
      with self.assertRaises(KeyError):
          list_hp.get_default('foo')

      metafeature_names = ['f1', 'f2', 'f3']
      metafeature = hyperparams.Enumeration(
          metafeature_names,
          metafeature_names[0],
          semantic_types=['https://metadata.datadrivendiscovery.org/types/MetafeatureParameter'],
      )

      list_hp = hyperparams.SortedList(metafeature, (), 0, 3)
      self.assertEqual(list_hp.get_default(), ())
      self.assertEqual(list_hp.sample(42), ('f1', 'f3'))
      # 1 + 3 + 6 + 10 multisets of size 0..3 over three values.
      self.assertEqual(list_hp.get_max_samples(), 20)
      self.assertEqual(list_hp.sample_multiple(0, 3, 42), (('f1', 'f3', 'f3'), ('f1', 'f1', 'f3')))
      self.assertEqual(list_hp.value_to_json_structure(list_hp.get_default()), [])
      self.assertEqual(list_hp.value_to_json_structure(list_hp.sample(42)), ['f1', 'f3'])
      check_json_round_trip(list_hp, list_hp.get_default)
      check_json_round_trip(list_hp, lambda: list_hp.sample(42))

      list_hp = hyperparams.SortedList(metafeature, (), 0, 10)
      self.assertEqual(list_hp.get_default(), ())
      self.assertEqual(list_hp.sample(42), ('f1', 'f1', 'f1', 'f3', 'f3', 'f3'))
      # C(13, 3): multisets of size 0..10 over three values.
      self.assertEqual(list_hp.get_max_samples(), 286)
      self.assertEqual(list_hp.sample_multiple(0, 3, 42), (('f1', 'f3', 'f3'), ('f1', 'f1', 'f2', 'f3', 'f3', 'f3', 'f3')))

      # Unbounded-above integer element; the test expects one log record
      # while sampling.
      list_hp = hyperparams.SortedList(hyperparams.Bounded[int](1, None, 1), (1, 1), min_size=2, max_size=2)
      with self.assertLogs(hyperparams.logger) as log_capture:
          self.assertEqual(list_hp.sample(42), (1, 1))
      self.assertEqual(len(log_capture.records), 1)
  def test_sorted_list_with_hyperparams(self):
      """``SortedList`` must reject a ``Hyperparams`` class as its element type."""
      # Assemble the element class field by field; insertion order matches the
      # original keyword order (range first, then enum).
      fields = OrderedDict()
      fields['range'] = hyperparams.UniformInt(1, 10, 2)
      fields['enum'] = hyperparams.Enumeration(['a', 'b', 'c', 'd', 'e'], 'e')
      element_class = hyperparams.Hyperparams.define(fields)

      with self.assertRaises(exceptions.NotSupportedError):
          # Both the default value and the SortedList are built inside the
          # context so any NotSupportedError raised along the way is caught.
          default_value = (element_class(range=2, enum='a'),)
          hyperparams.SortedList(element_class, default_value, 0, 5)
  def test_import_cycle(self):
      """Each module must import cleanly on its own in a fresh interpreter."""
      # All references to "hyperparams_module" in "d3m.metadata.base" should be lazy
      # (for example, a string in the typing signature), because the two modules
      # import each other. A fresh subprocess ensures nothing is pre-imported.
      for statement in ('import d3m.metadata.base', 'import d3m.metadata.hyperparams'):
          subprocess.run([sys.executable, '-c', statement], check=True)
  def test_union_float_int(self):
      """JSON decoding coerces numbers to the hyper-parameter's declared type."""

      def check_decoded(hp, json_value, expected, expected_type):
          # Decode, then require both value equality and the exact type
          # (e.g. float, not int).
          decoded = hp.value_from_json_structure(json_value)
          self.assertEqual(decoded, expected)
          self.assertIs(type(decoded), expected_type)

      float_hp = hyperparams.Uniform(1, 10, 2)
      int_hp = hyperparams.UniformInt(1, 10, 2)

      for json_value, expected in ((2.0, 2.0), (2, 2.0), (2.1, 2.1)):
          check_decoded(float_hp, json_value, expected, float)
      for json_value in (2.0, 2):
          check_decoded(int_hp, json_value, 2, int)
      # A non-integral float cannot be decoded as an int.
      with self.assertRaises(exceptions.InvalidArgumentTypeError):
          int_hp.value_from_json_structure(2.1)

      union_hp = hyperparams.Union(
          OrderedDict(
              float=hyperparams.Uniform(1, 5, 2),
              int=hyperparams.UniformInt(6, 10, 7),
          ),
          'float',
      )
      # Encoding tags each value with the name of the matching case.
      self.assertEqual(union_hp.value_to_json_structure(2.0), {'case': 'float', 'value': 2.0})
      self.assertEqual(union_hp.value_to_json_structure(7), {'case': 'int', 'value': 7})

      union_cases = (
          ({'case': 'float', 'value': 2.0}, 2.0, float),
          ({'case': 'float', 'value': 2.1}, 2.1, float),
          ({'case': 'float', 'value': 2}, 2.0, float),
          ({'case': 'int', 'value': 7}, 7, int),
          ({'case': 'int', 'value': 7.0}, 7, int),
      )
      for payload, expected, expected_type in union_cases:
          check_decoded(union_hp, payload, expected, expected_type)
  def test_can_serialize_to_json(self):
      """A sampled value's JSON structure must be accepted by ``json.dumps``.

      See: https://gitlab.com/datadrivendiscovery/d3m/-/issues/440
      """
      # UniformBool is an enumeration internally, so this also checks that
      # enumeration values are kept as-is (still ``bool``) when sampled.
      bool_hp = hyperparams.UniformBool(True)
      sampled = bool_hp.sample()
      self.assertIsInstance(sampled, bool)
      json.dumps(bool_hp.value_to_json_structure(sampled))
  def test_sampling_type(self):
      """``Uniform`` constructed from int arguments still samples a ``float``."""
      uniform_hp = hyperparams.Uniform(0, 10, 5)
      self.assertIsInstance(uniform_hp.sample(), float)
  def test_numpy_sampling(self):
      """A ``Bounded[numpy.int64]`` subclass samples and serializes ``numpy.int64`` values."""

      class UniformInt64(hyperparams.Bounded[numpy.int64]):
          # Uniform integer hyper-parameter whose values are numpy.int64.

          def __init__(
              self, lower: numpy.int64, upper: numpy.int64, default: numpy.int64, *, lower_inclusive: bool = True, upper_inclusive: bool = False,
              semantic_types: typing.Sequence[str] = None, description: str = None,
          ) -> None:
              # Both bounds are required; reject a missing one up front.
              if any(bound is None for bound in (lower, upper)):
                  raise exceptions.InvalidArgumentValueError("Bounds cannot be None.")
              super().__init__(
                  lower, upper, default,
                  lower_inclusive=lower_inclusive, upper_inclusive=upper_inclusive,
                  semantic_types=semantic_types, description=description,
              )

          def _initialize_effective_bounds(self) -> None:
              # Run the integer variant of the effective-bounds setup first,
              # then the base implementation.
              self._initialize_effective_bounds_int()
              super()._initialize_effective_bounds()

          def sample(self, random_state: numpy.random.RandomState = None) -> int:
              rng = sklearn_validation.check_random_state(random_state)
              drawn = rng.randint(self._effective_lower, self._effective_upper)
              # Convert via structural_type; the test below expects numpy.int64.
              return self.structural_type(drawn)

          def get_max_samples(self) -> typing.Optional[int]:
              return self._effective_upper - self._effective_lower

      # The test expects plain Python ints to be rejected for a numpy.int64 hyper-parameter.
      with self.assertRaises(exceptions.InvalidArgumentTypeError):
          UniformInt64(0, 10, 5)

      lower, upper, default = numpy.int64(0), numpy.int64(10), numpy.int64(5)
      int64_hp = UniformInt64(lower, upper, default)
      sampled = int64_hp.sample()
      self.assertIsInstance(sampled, numpy.int64)
      json.dumps(int64_hp.value_to_json_structure(sampled))
  if __name__ == '__main__':
      # Discover and run every TestCase defined in this module when executed directly.
      unittest.main(module='__main__')

全栈的自动化机器学习系统，主要针对多变量时间序列数据的异常检测。TODS 提供了详尽的用于构建基于机器学习的异常检测系统的模块，它们包括：数据处理（data processing）、时间序列处理（time series processing）、特征分析（feature analysis）、检测算法（detection algorithms）和强化模块（reinforcement module）。这些模块所提供的功能包括常见的数据预处理、时间序列数据的平滑或变换、从时域或频域中抽取特征、多种多样的检测算