You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

node_classifier.py 28 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723
  1. """
  2. Auto Classifier for Heterogeneous Node Classification
  3. """
  4. import time
  5. import json
  6. from copy import deepcopy
  7. from typing import Sequence
  8. import torch
  9. import numpy as np
  10. import yaml
  11. from ..base import BaseClassifier
  12. from ...base import _parse_hp_space, _initialize_single_model, _parse_model_hp
  13. from ....module.train import TRAINER_DICT, BaseNodeClassificationHetTrainer
  14. from ....module.train import get_feval
  15. from ...utils import LeaderBoard, set_seed
  16. from ....utils import get_logger
  17. LOGGER = get_logger("HeteroNodeClassifier")
  18. class AutoHeteroNodeClassifier(BaseClassifier):
  19. """
  20. Auto Multi-class HeteroGraph Node Classifier.
  21. Used to automatically solve the heterogeneous node classification problems.
  22. Parameters
  23. ----------
  24. feature_module: not accepted by this solver
  25. Feature engineering is currently not supported for heterogeneous node classification,
  26. so the base solver is always initialized with ``feature_module=None``.
  27. graph_models: list of autogl.module.model.BaseModel or list of str
  28. The (name of) models to be optimized as backbone. Default ``('han', 'hgt')``.
  29. hpo_module: autogl.module.hpo.BaseHPOptimizer or str or None
  30. The (name of) hpo module used to search for best hyper parameters. Default ``anneal``.
  31. Disable hpo by setting it to ``None``.
  32. ensemble_module: autogl.module.ensemble.BaseEnsembler or str or None
  33. The (name of) ensemble module used to ensemble the multi-models found. Default ``voting``.
  34. Disable ensemble by setting it to ``None``.
  35. max_evals: int (Optional)
  36. If given, will set the number eval times the hpo module will use.
  37. Only effective when hpo_module is ``str``. Default ``50``.
  38. default_trainer: str (Optional)
  39. The (name of) the trainer used in this solver. Default to ``NodeClassificationHet``.
  40. trainer_hp_space: list of dict (Optional)
  41. trainer hp space or list of trainer hp spaces configuration.
  42. If a single trainer hp is given, will specify the hp space of trainer for every model.
  43. If a list of trainer hp is given, will specify every model with corresponding
  44. trainer hp space.
  45. Default ``None``.
  46. model_hp_spaces: list of list of dict (Optional)
  47. model hp space configuration.
  48. If given, will specify every hp space of every passed model. Default ``None``.
  49. If the encoder(-decoder) is passed, the space should be a dict containing keys "encoder"
  50. and "decoder", specifying the detailed encoder decoder hp spaces.
  51. size: int (Optional)
  52. The max models ensemble module will use. Default ``4``.
  53. device: torch.device or str
  54. The device where model will be running on. If set to ``auto``, will use gpu when available.
  55. You can also specify the device by directly giving ``gpu`` or ``cuda:0``, etc.
  56. Default ``auto``.
  57. """
  def __init__(
      self,
      graph_models=("han", "hgt"),
      hpo_module="anneal",
      ensemble_module="voting",
      max_evals=50,
      default_trainer="NodeClassificationHet",
      trainer_hp_space=None,
      model_hp_spaces=None,
      size=4,
      device="auto",
  ):
      """
      Build the solver by forwarding every option to the base classifier.

      The feature engineer and all NAS related options are always passed to
      the base initializer as ``None``; they cannot be configured here.
      """
      solver_options = dict(
          # feature engineering is not supported for this solver yet
          feature_module=None,
          graph_models=graph_models,
          # NAS is not supported for the heterogeneous node classifier yet
          nas_algorithms=None,
          nas_spaces=None,
          nas_estimators=None,
          hpo_module=hpo_module,
          ensemble_module=ensemble_module,
          max_evals=max_evals,
          default_trainer=default_trainer,
          trainer_hp_space=trainer_hp_space,
          model_hp_spaces=model_hp_spaces,
          size=size,
          device=device,
      )
      super().__init__(**solver_options)

      # dataset remembered by fit() so later predictions can reuse it
      self.dataset = None
  def _init_graph_module(
      self, graph_models, num_classes, num_features, feval, device, loss, dataset
  ) -> "AutoHeteroNodeClassifier":
      """
      Turn every entry of ``graph_models`` into a configured trainer.

      Entries that already are ``BaseNodeClassificationHetTrainer`` instances
      are used directly. Any other entry is attached to the default trainer
      as its encoder (or as encoder and decoder when the entry is a tuple or
      list). The resulting trainers are stored, in order, in
      ``self.graph_model_list``.

      Returns
      -------
      self: AutoHeteroNodeClassifier
          A reference of current solver.

      Raises
      ------
      KeyError
          If the default trainer is given by name and that name is not a key
          of ``TRAINER_DICT``.
      """
      self.graph_model_list = []

      for position, candidate in enumerate(graph_models):
          # resolve the trainer wrapping this candidate
          if isinstance(candidate, BaseNodeClassificationHetTrainer):
              trainer = candidate
          else:
              if isinstance(self._default_trainer, (tuple, list)):
                  trainer = self._default_trainer[position]
              else:
                  trainer = self._default_trainer
              if isinstance(trainer, str):
                  if trainer not in TRAINER_DICT:
                      raise KeyError(f"Does not support trainer {trainer}")
                  trainer = TRAINER_DICT[trainer]()
              if isinstance(candidate, (tuple, list)):
                  trainer.encoder = candidate[0]
                  trainer.decoder = candidate[1]
              else:
                  trainer.encoder = candidate

          # hyper parameter space of the model (encoder / decoder)
          model_space = (
              None
              if self._model_hp_spaces is None
              else self._model_hp_spaces[position]
          )
          if model_space is not None:
              if isinstance(model_space, dict):
                  encoder_space = model_space.get('encoder', None)
                  decoder_space = model_space.get('decoder', None)
              else:
                  # a bare space describes the encoder only
                  encoder_space, decoder_space = model_space, None
              if encoder_space is not None:
                  trainer.encoder.hyper_parameter_space = encoder_space
              if decoder_space is not None:
                  trainer.decoder.hyper_parameter_space = decoder_space

          # hyper parameter space of the trainer itself
          if self._trainer_hp_space is not None:
              # a list of lists holds one space per model, otherwise one
              # space is shared by every trainer
              one_space_per_model = isinstance(self._trainer_hp_space[0], list)
              trainer.hyper_parameter_space = (
                  self._trainer_hp_space[position]
                  if one_space_per_model
                  else self._trainer_hp_space
              )

          trainer.num_features = num_features
          trainer.num_classes = num_classes
          trainer.from_dataset(dataset)
          trainer.loss = loss
          trainer.feval = feval
          trainer.to(device)
          self.graph_model_list.append(trainer)

      return self
  140. # pylint: disable=arguments-differ
  def fit(
      self,
      dataset,
      time_limit=-1,
      evaluation_method="infer",
      seed=None,
  ) -> "AutoHeteroNodeClassifier":
      """
      Fit current solver on given dataset.

      Parameters
      ----------
      dataset: autogl.data.Dataset
          The dataset needed to fit on. Only the first graph (``dataset[0]``) is used,
          and the target node type is read from ``dataset.schema["target_node_type"]``.

      time_limit: int
          The time limit of the whole fit process (in seconds). If set below 0,
          a limit of 24 hours is used. Default ``-1``.

      evaluation_method: (list of) str or autogl.module.train.evaluation
          A (list of) evaluation method for current solver. A single method is wrapped
          into a one-element list. If ``infer``, ``dataset.metric`` is used when the
          dataset has that attribute; otherwise ``auc`` is used for two classes and
          ``acc`` for any other number of classes. Default ``infer``.

      seed: int (Optional)
          The random seed, passed to ``set_seed``. Default ``None``.

      Returns
      -------
      self: AutoHeteroNodeClassifier
          A reference of current solver.

      Raises
      ------
      AssertionError
          If the target nodes have no ``train_mask`` / ``val_mask`` entry, or if their
          ``feat`` entry is ``None``.
      """
      set_seed(seed)

      if time_limit < 0:
          time_limit = 3600 * 24
      time_begin = time.time()

      graph_data = dataset[0]
      field = dataset.schema["target_node_type"]
      all_labels = graph_data.nodes[field].data['label']
      # labels are assumed to be consecutive integers starting from 0 -- TODO confirm
      num_classes = all_labels.max().item() + 1

      # normalise the evaluation methods to a list
      if evaluation_method == "infer":
          if hasattr(dataset, "metric"):
              evaluation_method = [dataset.metric]
          elif num_classes == 2:
              evaluation_method = ["auc"]
          else:
              evaluation_method = ["acc"]
      elif isinstance(evaluation_method, (list, tuple)):
          evaluation_method = list(evaluation_method)
      else:
          # a single method is documented as valid input, so wrap it
          # instead of rejecting it
          evaluation_method = [evaluation_method]

      evaluator_list = get_feval(evaluation_method)
      eval_names = [e.get_eval_name() for e in evaluator_list]

      # initialize leaderboard
      self.leaderboard = LeaderBoard(
          eval_names,
          {e.get_eval_name(): e.is_higher_better() for e in evaluator_list},
      )

      # set up the dataset
      assert ("train_mask" in graph_data.nodes[field].data
              and "val_mask" in graph_data.nodes[field].data), (
          "Currently only support Dataset with default train/val/test split"
      )
      self.dataset = dataset

      # check whether the dataset has features.
      # currently we only support hetero graph classification with features.
      feat = graph_data.nodes[field].data['feat']
      assert feat is not None, (
          "Does not support fit on non node-feature dataset!"
          " Please add node features to dataset or specify feature engineers that generate"
          " node features."
      )
      num_features = feat.size(-1)

      # initialize graph networks
      self._init_graph_module(
          self.gml,
          num_features=num_features,
          num_classes=num_classes,
          feval=evaluator_list,
          device=self.runtime_device,
          loss="nll_loss" if not hasattr(dataset, "loss") else self.dataset.loss,
          dataset=dataset
      )

      # train the models and tune hpo
      result_valid = []
      names = []
      model_count = len(self.graph_model_list)
      for idx, model in enumerate(self.graph_model_list):
          if self.hpo_module is None:
              model.initialize()
              model.train(dataset, True)
              optimized = model
          else:
              # share the remaining time equally among the models left
              time_for_each_model = (time_limit - time.time() + time_begin) / (
                  model_count - idx
              )
              optimized, _ = self.hpo_module.optimize(
                  trainer=model, dataset=dataset, time_limit=time_for_each_model
              )
          # to save memory, all the trainer derived will be mapped to cpu
          optimized.to(torch.device("cpu"))
          name = str(optimized) + "_idx%d" % (idx)
          names.append(name)
          performance_on_valid, _ = optimized.get_valid_score(return_major=False)
          result_valid.append(optimized.get_valid_predict_proba().cpu().numpy())
          self.leaderboard.insert_model_performance(
              name, dict(zip(eval_names, performance_on_valid))
          )
          self.trained_models[name] = optimized

      # fit the ensemble model
      if self.ensemble_module is not None:
          performance = self.ensemble_module.fit(
              result_valid,
              all_labels[graph_data.nodes[field].data["val_mask"]].cpu().numpy(),
              names,
              evaluator_list,
              n_classes=num_classes,
          )
          self.leaderboard.insert_model_performance(
              "ensemble", dict(zip(eval_names, performance))
          )

      return self
  def fit_predict(
      self,
      dataset,
      time_limit=-1,
      evaluation_method="infer",
      use_ensemble=True,
      use_best=True,
      name=None,
      seed=None,
  ) -> np.ndarray:
      """
      Fit current solver on given dataset and return the predicted value.

      Parameters
      ----------
      dataset: autogl.data.Dataset
          The dataset needed to fit on. It is passed to ``fit()`` and then to ``predict()``.

      time_limit: int
          The time limit of the whole fit process (in seconds).
          If set below 0, will ignore time limit. Default ``-1``.

      evaluation_method: (list of) str or autogl.module.train.evaluation
          A (list of) evaluation method for current solver. If ``infer``, will automatically
          determine. Default ``infer``.

      use_ensemble: bool
          Whether to use ensemble to do the predict. Default ``True``.

      use_best: bool
          Whether to use the best single model to do the predict. Will only be effective when
          ``use_ensemble`` is ``False``.
          Default ``True``.

      name: str or None
          The name of model used to predict. Will only be effective when ``use_ensemble`` and
          ``use_best`` both are ``False``.
          Default ``None``.

      seed: int (Optional)
          The random seed forwarded to ``fit()``. Default ``None``.

      Returns
      -------
      result: np.ndarray
          The value returned by ``predict()`` on the given dataset, using its default
          ``mask``.
      """
      self.fit(
          dataset=dataset,
          time_limit=time_limit,
          evaluation_method=evaluation_method,
          # forwarded so a fit started from here can be seeded like fit()
          seed=seed,
      )
      return self.predict(
          dataset=dataset,
          use_ensemble=use_ensemble,
          use_best=use_best,
          name=name,
      )
  def predict_proba(
      self,
      dataset=None,
      use_ensemble=True,
      use_best=True,
      name=None,
      mask="test",
  ) -> np.ndarray:
      """
      Predict the node probability.

      Parameters
      ----------
      dataset: autogl.data.Dataset or None
          The dataset needed to predict. If ``None``, will use the dataset remembered
          by ``fit()`` instead. Default ``None``.

      use_ensemble: bool
          Whether to use ensemble to do the predict. Default ``True``.

      use_best: bool
          Whether to use the best single model to do the predict. Will only be effective when
          ``use_ensemble`` is ``False``. Default ``True``.

      name: str or None
          The name of model used to predict. Will only be effective when ``use_ensemble`` and
          ``use_best`` both are ``False``. Default ``None``.

      mask: str
          The data split to give prediction on. Default ``test``.

      Returns
      -------
      result: np.ndarray
          The predicted probabilities, either from the ensemble module or from one
          trained model.

      Raises
      ------
      ValueError
          If neither ensemble nor best model is requested and no ``name`` is given.
      """
      if dataset is None:
          dataset = self.dataset
      assert dataset is not None, (
          "Please execute fit() first before predicting on remembered dataset"
      )

      if use_ensemble:
          LOGGER.info("Ensemble argument on, will try using ensemble model.")
      elif use_best:
          LOGGER.info(
              "Ensemble argument off and best argument on, will try using best model."
          )

      ensemble_wanted_but_missing = use_ensemble and self.ensemble_module is None

      if (use_ensemble and self.ensemble_module is not None) or (
          not use_best and name == "ensemble"
      ):
          # the ensemble needs the prediction of every trained model
          names = list(self.trained_models)
          predict_result = [
              self._predict_proba_by_name(dataset, model_name, mask)
              for model_name in names
          ]
          return self.ensemble_module.ensemble(predict_result, names)

      if ensemble_wanted_but_missing:
          LOGGER.warning(
              "Cannot use ensemble because no ensebmle module is given."
              " Will use best model instead."
          )

      if use_best or ensemble_wanted_but_missing:
          # fall back to the best model recorded on the leaderboard
          best_name = self.leaderboard.get_best_model()
          return self._predict_proba_by_name(dataset, best_name, mask)

      if name is None:
          LOGGER.error(
              "No model name is given while ensemble and best arguments are off."
          )
          raise ValueError(
              "You need to specify a model name if you do not want use ensemble and best model."
          )

      # an explicit model name was requested
      return self._predict_proba_by_name(dataset, name, mask)
  def _predict_proba_by_name(self, dataset, name, mask="test"):
      """
      Predict probabilities with the trained model stored under ``name``.

      The trainer is moved to ``self.runtime_device`` for the prediction and
      moved back to the CPU afterwards, also when the prediction raises.

      Returns
      -------
      result: np.ndarray
          The output of the trainer's ``predict_proba`` converted to numpy.

      Raises
      ------
      KeyError
          If ``name`` is not a key of ``self.trained_models``.
      """
      trainer = self.trained_models[name]
      trainer.to(self.runtime_device)
      try:
          return trainer.predict_proba(dataset, mask=mask).cpu().numpy()
      finally:
          # hand the trainer back to the CPU even on failure, so a failed
          # prediction does not leave the model on the runtime device
          trainer.to(torch.device("cpu"))
  def predict(
      self,
      dataset=None,
      use_ensemble=True,
      use_best=True,
      name=None,
      mask="test",
  ) -> np.ndarray:
      """
      Predict the node class number.

      Parameters
      ----------
      dataset: autogl.data.Dataset or None
          The dataset needed to predict. If ``None``, will use the dataset remembered
          by ``fit()`` instead. Default ``None``.

      use_ensemble: bool
          Whether to use ensemble to do the predict. Default ``True``.

      use_best: bool
          Whether to use the best single model to do the predict. Will only be effective
          when ``use_ensemble`` is ``False``. Default ``True``.

      name: str or None
          The name of model used to predict. Will only be effective when ``use_ensemble``
          and ``use_best`` both are ``False``. Default ``None``.

      mask: str
          The data split to give prediction on. Default ``test``.

      Returns
      -------
      result: np.ndarray
          For every row of the predicted probabilities, the index of the largest
          entry (argmax over axis 1).
      """
      class_probabilities = self.predict_proba(
          dataset, use_ensemble, use_best, name, mask
      )
      predicted_classes = np.argmax(class_probabilities, axis=1)
      return predicted_classes
  def evaluate(self, dataset=None,
      use_ensemble=True,
      use_best=True,
      name=None,
      mask="test",
      label=None,
      metric="acc"
  ):
      """
      Evaluate the given dataset.

      Parameters
      ----------
      dataset: autogl.data.Dataset or None
          The dataset needed to predict. If ``None``, will use the dataset remembered
          by ``fit()`` instead. Default ``None``.

      use_ensemble: bool
          Whether to use ensemble to do the predict. Default ``True``.

      use_best: bool
          Whether to use the best single model to do the predict. Will only be effective when
          ``use_ensemble`` is ``False``. Default ``True``.

      name: str or None
          The name of model used to predict. Will only be effective when ``use_ensemble`` and
          ``use_best`` both are ``False``. Default ``None``.

      mask: str
          The data split to give prediction on. Default ``test``. When labels are
          extracted from the dataset, ``train`` / ``val`` / ``test`` select the stored
          ``<mask>_mask`` entry; any other value is used directly to index the labels.

      label: torch.Tensor (Optional)
          The ground truth label of the given predicted dataset split. If not passed, will
          extract labels from the input dataset. A tensor is converted to a numpy array
          before it is handed to the evaluator.

      metric: str
          The metric to be used for evaluating the model. Default ``acc``.

      Returns
      -------
      score(s): (list of) evaluation scores
          the evaluation results according to the evaluator passed.
      """
      predicted = self.predict_proba(dataset, use_ensemble, use_best, name, mask)

      if dataset is None:
          dataset = self.dataset

      if label is None:
          graph_nodes = dataset[0].nodes[dataset.schema["target_node_type"]].data
          # only compare split names when mask is a string: membership
          # tests on array-like masks compare element-wise and can raise
          if isinstance(mask, str) and mask in ["train", "val", "test"]:
              mask = graph_nodes[f"{mask}_mask"]
          label = graph_nodes["label"][mask].cpu().numpy()
      elif isinstance(label, torch.Tensor):
          # match the numpy labels produced by the extraction path above
          label = label.detach().cpu().numpy()

      evaluator = get_feval(metric)
      if isinstance(evaluator, Sequence):
          return [evals.evaluate(predicted, label) for evals in evaluator]
      return evaluator.evaluate(predicted, label)
  @classmethod
  def from_config(cls, path_or_dict, filetype="auto") -> "AutoHeteroNodeClassifier":
      """
      Load solver from config file.

      You can use this function to directly load a solver from predefined config dict
      or config file path. Currently, only support file type of ``json`` or ``yaml``,
      if you pass a path.

      Parameters
      ----------
      path_or_dict: str or dict
          The path to the config file or the config dictionary object. The given
          dictionary is deep-copied and never modified.

      filetype: str
          The filetype the given file if the path is specified. Currently only support
          ``json`` or ``yaml``. You can set to ``auto`` to automatically detect the file
          type (from file name). Default ``auto``.

      Returns
      -------
      solver: AutoHeteroNodeClassifier
          The solver that is created from given file or dictionary.

      Raises
      ------
      AssertionError
          If ``filetype`` is not one of ``auto``, ``yaml`` or ``json``, or if a list of
          trainers does not have the same length as the list of models.
      ValueError
          If ``filetype`` is ``auto`` and the file name ends with neither ``.yaml``,
          ``.yml`` nor ``.json``.
      """
      assert filetype in ["auto", "yaml", "json"], (
          "currently only support yaml file or json file type, but get type "
          + filetype
      )

      if isinstance(path_or_dict, str):
          if filetype == "auto":
              if path_or_dict.endswith((".yaml", ".yml")):
                  filetype = "yaml"
              elif path_or_dict.endswith(".json"):
                  filetype = "json"
              else:
                  message = (
                      "cannot parse the type of the given file name, "
                      "please manually set the file type"
                  )
                  LOGGER.error(message)
                  raise ValueError(message)
          # the context manager closes the file handle after parsing
          with open(path_or_dict, "r") as config_file:
              if filetype == "yaml":
                  # NOTE(review): yaml.FullLoader is not meant for untrusted
                  # input; consider yaml.safe_load if configs can come from
                  # outside and only use plain YAML types.
                  path_or_dict = yaml.load(
                      config_file.read(), Loader=yaml.FullLoader
                  )
              else:
                  path_or_dict = json.load(config_file)

      # work on a private copy, the entries are popped below
      path_or_dict = deepcopy(path_or_dict)

      # Placeholder solver, every module is configured below. Keywords are
      # used because positional arguments bind to graph_models, hpo_module,
      # ensemble_module, max_evals in this class, which would leave a list
      # as the hpo module placeholder.
      solver = cls(
          graph_models=[],
          hpo_module=None,
          ensemble_module=None,
          max_evals=None,
      )

      models = path_or_dict.pop("models", [{"name": "hgt"}, {"name": "han"}])
      # models should be a list of model
      # with each element in two cases
      # * a dict describing a certain model
      # * a dict containing {"encoder": encoder, "decoder": decoder}
      model_hp_space = [_parse_model_hp(model) for model in models]
      model_list = [_initialize_single_model(model) for model in models]

      trainer = path_or_dict.pop("trainer", None)
      default_trainer = "NodeClassificationHet"
      trainer_space = None
      if isinstance(trainer, dict):
          # global default: one trainer configuration shared by all models
          default_trainer = trainer.pop("name", "NodeClassificationHet")
          trainer_space = _parse_hp_space(trainer.pop("hp_space", None))
          default_kwargs = {"num_features": None, "num_classes": None}
          default_kwargs.update(trainer)
          default_kwargs["init"] = False
          model_list = [
              TRAINER_DICT[default_trainer](model=model, **default_kwargs)
              for model in model_list
          ]
      elif isinstance(trainer, list):
          # sequential trainer definition: one configuration per model
          assert len(trainer) == len(
              model_list
          ), "The number of trainer and model does not match"
          trainer_space = []
          wrapped_models = []
          for train, model in zip(trainer, model_list):
              default_trainer = train.pop("name", "NodeClassificationHet")
              trainer_space.append(_parse_hp_space(train.pop("hp_space", None)))
              default_kwargs = {"num_features": None, "num_classes": None}
              default_kwargs.update(train)
              default_kwargs["init"] = False
              wrapped_models.append(
                  TRAINER_DICT[default_trainer](model=model, **default_kwargs)
              )
          model_list = wrapped_models

      solver.set_graph_models(
          model_list, default_trainer, trainer_space, model_hp_space
      )

      hpo_dict = path_or_dict.pop("hpo", {"name": "anneal"})
      if hpo_dict is not None:
          name = hpo_dict.pop("name")
          solver.set_hpo_module(name, **hpo_dict)

      ensemble_dict = path_or_dict.pop("ensemble", {"name": "voting"})
      if ensemble_dict is not None:
          name = ensemble_dict.pop("name")
          solver.set_ensemble_module(name, **ensemble_dict)

      return solver