You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

node_classifier.py 34 kB

5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848
  1. """
  2. Auto Classifier for Node Classification
  3. """
  4. import time
  5. import json
  6. from copy import deepcopy
  7. import torch
  8. import torch.nn.functional as F
  9. import numpy as np
  10. import yaml
  11. from .base import BaseClassifier
  12. from ..base import _parse_hp_space, _initialize_single_model
  13. from ...module.feature import FEATURE_DICT
  14. from ...module.model import MODEL_DICT, BaseModel
  15. from ...module.train import TRAINER_DICT, BaseNodeClassificationTrainer
  16. from ...module.train import get_feval
  17. from ...module.nas.space import NAS_SPACE_DICT
  18. from ...module.nas.algorithm import NAS_ALGO_DICT
  19. from ...module.nas.estimator import NAS_ESTIMATOR_DICT, BaseEstimator
  20. from ..utils import LeaderBoard, set_seed
  21. from ...datasets import utils
  22. from ...utils import get_logger
  23. from torch_geometric.nn import GATConv, GCNConv
  24. LOGGER = get_logger("NodeClassifier")
  25. class AutoNodeClassifier(BaseClassifier):
  26. """
  27. Auto Multi-class Graph Node Classifier.
  28. Used to automatically solve the node classification problems.
  29. Parameters
  30. ----------
  31. feature_module: autogl.module.feature.BaseFeatureEngineer or str or None
  32. The (name of) auto feature engineer used to process the given dataset. Default ``None``.
  33. Disable feature engineer by setting it to ``None``.
  34. graph_models: list of autogl.module.model.BaseModel or list of str
  35. The (name of) models to be optimized as backbone. Default ``['gat', 'gcn']``.
  36. nas_algorithms: (list of) autogl.module.nas.algorithm.BaseNAS or str (Optional)
  37. The (name of) nas algorithms used. Default ``None``.
  38. nas_spaces: (list of) autogl.module.nas.space.BaseSpace or str (Optional)
  39. The (name of) nas spaces used. Default ``None``.
  40. nas_estimators: (list of) autogl.module.nas.estimator.BaseEstimator or str (Optional)
  41. The (name of) nas estimators used. Default ``None``.
  42. hpo_module: autogl.module.hpo.BaseHPOptimizer or str or None
  43. The (name of) hpo module used to search for best hyper parameters. Default ``anneal``.
  44. Disable hpo by setting it to ``None``.
  45. ensemble_module: autogl.module.ensemble.BaseEnsembler or str or None
  46. The (name of) ensemble module used to ensemble the multi-models found. Default ``voting``.
  47. Disable ensemble by setting it to ``None``.
  48. max_evals: int (Optional)
  49. If given, will set the number eval times the hpo module will use.
  50. Only effective when hpo_module is a ``str``. Default ``50``.
  51. trainer_hp_space: list of dict (Optional)
  52. trainer hp space or list of trainer hp spaces configuration.
  53. If a single trainer hp is given, will specify the hp space of trainer for every model.
  54. If a list of trainer hp is given, will specify every model with corresponding
  55. trainer hp space.
  56. Default ``None``.
  57. model_hp_spaces: list of list of dict (Optional)
  58. model hp space configuration.
  59. If given, will specify every hp space of every passed model. Default ``None``.
  60. size: int (Optional)
  61. The maximum number of models the ensemble module will use. Default ``4``.
  62. device: torch.device or str
  63. The device where model will be running on. If set to ``auto``, will use gpu when available.
  64. You can also specify the device by directly giving ``gpu`` or ``cuda:0``, etc.
  65. Default ``auto``.
  66. """
  def __init__(
      self,
      feature_module=None,
      graph_models=("gat", "gcn"),
      nas_algorithms=None,
      nas_spaces=None,
      nas_estimators=None,
      hpo_module="anneal",
      ensemble_module="voting",
      max_evals=50,
      default_trainer=None,
      trainer_hp_space=None,
      model_hp_spaces=None,
      size=4,
      device="auto",
  ):
      """
      Create the solver by forwarding every option to the base classifier.

      A falsy ``default_trainer`` is replaced by ``"NodeClassificationFull"``
      before it is handed to the parent constructor.
      """
      # fall back to the ``NodeClassificationFull`` trainer name when none is given
      if default_trainer:
          chosen_trainer = default_trainer
      else:
          chosen_trainer = "NodeClassificationFull"

      parent_options = {
          "feature_module": feature_module,
          "graph_models": graph_models,
          "nas_algorithms": nas_algorithms,
          "nas_spaces": nas_spaces,
          "nas_estimators": nas_estimators,
          "hpo_module": hpo_module,
          "ensemble_module": ensemble_module,
          "max_evals": max_evals,
          "default_trainer": chosen_trainer,
          "trainer_hp_space": trainer_hp_space,
          "model_hp_spaces": model_hp_spaces,
          "size": size,
          "device": device,
      }
      super().__init__(**parent_options)

      # data to be kept when fit
      self.dataset = None
  100. def _init_graph_module(
  101. self, graph_models, num_classes, num_features, feval, device, loss
  102. ) -> "AutoNodeClassifier":
  103. # load graph network module
  104. self.graph_model_list = []
  105. if isinstance(graph_models, (list, tuple)):
  106. for model in graph_models:
  107. if isinstance(model, str):
  108. if model in MODEL_DICT:
  109. self.graph_model_list.append(
  110. MODEL_DICT[model](
  111. num_classes=num_classes,
  112. num_features=num_features,
  113. device=device,
  114. init=False,
  115. )
  116. )
  117. else:
  118. raise KeyError("cannot find model %s" % (model))
  119. elif isinstance(model, type) and issubclass(model, BaseModel):
  120. self.graph_model_list.append(
  121. model(
  122. num_classes=num_classes,
  123. num_features=num_features,
  124. device=device,
  125. init=False,
  126. )
  127. )
  128. elif isinstance(model, BaseModel):
  129. # setup the hp of num_classes and num_features
  130. model.set_num_classes(num_classes)
  131. model.set_num_features(num_features)
  132. self.graph_model_list.append(model.to(device))
  133. elif isinstance(model, BaseNodeClassificationTrainer):
  134. # receive a trainer list, put trainer to list
  135. assert (
  136. model.get_model() is not None
  137. ), "Passed trainer should contain a model"
  138. model.model.set_num_classes(num_classes)
  139. model.model.set_num_features(num_features)
  140. model.update_parameters(
  141. num_classes=num_classes,
  142. num_features=num_features,
  143. loss=loss,
  144. feval=feval,
  145. device=device,
  146. )
  147. self.graph_model_list.append(model)
  148. else:
  149. raise KeyError("cannot find graph network %s." % (model))
  150. else:
  151. raise ValueError(
  152. "need graph network to be (list of) str or a BaseModel class/instance, get",
  153. graph_models,
  154. "instead.",
  155. )
  156. # wrap all model_cls with specified trainer
  157. for i, model in enumerate(self.graph_model_list):
  158. # set model hp space
  159. if self._model_hp_spaces is not None:
  160. if self._model_hp_spaces[i] is not None:
  161. if isinstance(model, BaseNodeClassificationTrainer):
  162. model.model.hyper_parameter_space = self._model_hp_spaces[i]
  163. else:
  164. model.hyper_parameter_space = self._model_hp_spaces[i]
  165. # initialize trainer if needed
  166. if isinstance(model, BaseModel):
  167. name = (
  168. self._default_trainer
  169. if isinstance(self._default_trainer, str)
  170. else self._default_trainer[i]
  171. )
  172. model = TRAINER_DICT[name](
  173. model=model,
  174. num_features=num_features,
  175. num_classes=num_classes,
  176. loss=loss,
  177. feval=feval,
  178. device=device,
  179. init=False,
  180. )
  181. # set trainer hp space
  182. if self._trainer_hp_space is not None:
  183. if isinstance(self._trainer_hp_space[0], list):
  184. current_hp_for_trainer = self._trainer_hp_space[i]
  185. else:
  186. current_hp_for_trainer = self._trainer_hp_space
  187. model.hyper_parameter_space = current_hp_for_trainer
  188. self.graph_model_list[i] = model
  189. return self
  190. def _init_nas_module(self, num_features, num_classes, feval, device, loss):
  191. for algo, space, estimator in zip(
  192. self.nas_algorithms, self.nas_spaces, self.nas_estimators
  193. ):
  194. estimator: BaseEstimator
  195. algo.to(device)
  196. space.instantiate(input_dim=num_features, output_dim=num_classes)
  197. estimator.setEvaluation(feval)
  198. estimator.setLossFunction(loss)
  199. # pylint: disable=arguments-differ
  def fit(
      self,
      dataset,
      time_limit=-1,
      inplace=False,
      train_split=None,
      val_split=None,
      balanced=True,
      evaluation_method="infer",
      seed=None,
  ) -> "AutoNodeClassifier":
      """
      Fit current solver on given dataset.

      Parameters
      ----------
      dataset: torch_geometric.data.dataset.Dataset
          The dataset needed to fit on. This dataset must have only one graph.

      time_limit: int
          The time limit of the whole fit process (in seconds). If set below 0,
          will ignore time limit. Default ``-1``.

      inplace: bool
          Whether we process the given dataset in inplace manner. Default ``False``.
          Set it to True if you want to save memory by modifying the given dataset directly.

      train_split: float or int (Optional)
          The train ratio (in ``float``) or number (in ``int``) of dataset. If you want to
          use default train/val/test split in dataset, please set this to ``None``.
          Default ``None``.

      val_split: float or int (Optional)
          The validation ratio (in ``float``) or number (in ``int``) of dataset. If you want
          to use default train/val/test split in dataset, please set this to ``None``.
          Default ``None``.

      balanced: bool
          Whether to create the train/valid/test split in a balanced way.
          If set to ``True``, the train/valid will have the same number of different classes.
          Default ``True``.

      evaluation_method: (list of) str or autogl.module.train.evaluation
          A (list of) evaluation method for current solver. A single method is
          wrapped into a list. If ``infer``, will automatically determine.
          Default ``infer``.

      seed: int (Optional)
          The random seed. If set to ``None``, will run everything at random.
          Default ``None``.

      Returns
      -------
      self: autogl.solver.AutoNodeClassifier
          A reference of current solver.
      """
      set_seed(seed)

      # a negative limit means "no limit": one day is used as the budget
      if time_limit < 0:
          time_limit = 3600 * 24
      time_begin = time.time()

      # resolve the evaluation methods into a list
      if evaluation_method == "infer":
          if hasattr(dataset, "metric"):
              evaluation_method = [dataset.metric]
          elif dataset.num_classes == 2:
              evaluation_method = ["auc"]
          else:
              evaluation_method = ["acc"]
      elif isinstance(evaluation_method, tuple):
          evaluation_method = list(evaluation_method)
      elif not isinstance(evaluation_method, list):
          # the documented contract accepts a single method as well as a list
          evaluation_method = [evaluation_method]

      evaluator_list = get_feval(evaluation_method)
      eval_names = [e.get_eval_name() for e in evaluator_list]

      # initialize leaderboard
      self.leaderboard = LeaderBoard(
          list(eval_names),
          {e.get_eval_name(): e.is_higher_better() for e in evaluator_list},
      )

      # set up the dataset
      if train_split is not None and val_split is not None:
          size = dataset.data.x.shape[0]
          if balanced:
              # values above 1 are absolute node counts, otherwise ratios
              train_split = (
                  train_split if train_split > 1 else int(train_split * size)
              )
              val_split = val_split if val_split > 1 else int(val_split * size)
              utils.random_splits_mask_class(
                  dataset,
                  num_train_per_class=train_split // dataset.num_classes,
                  num_val_per_class=val_split // dataset.num_classes,
                  seed=seed,
              )
          else:
              # values below 1 are ratios, otherwise absolute node counts
              train_split = train_split if train_split < 1 else train_split / size
              val_split = val_split if val_split < 1 else val_split / size
              utils.random_splits_mask(
                  dataset, train_ratio=train_split, val_ratio=val_split
              )
      else:
          assert hasattr(dataset.data, "train_mask") and hasattr(
              dataset.data, "val_mask"
          ), (
              "The dataset has no default train/val split! Please manually pass "
              "train and val ratio."
          )
          LOGGER.info("Use the default train/val/test ratio in given dataset")

      # feature engineering
      if self.feature_module is not None:
          dataset = self.feature_module.fit_transform(dataset, inplace=inplace)

      self.dataset = dataset
      assert self.dataset[0].x is not None, (
          "Does not support fit on non node-feature dataset!"
          " Please add node features to dataset or specify feature engineers that generate"
          " node features."
      )

      # values shared by every model / trainer set up below
      num_features = self.dataset[0].x.shape[1]
      num_classes = self.dataset.num_classes
      loss = getattr(dataset, "loss", "nll_loss")

      # initialize graph networks
      self._init_graph_module(
          self.gml,
          num_features=num_features,
          num_classes=num_classes,
          feval=evaluator_list,
          device=self.runtime_device,
          loss=loss,
      )

      if self.nas_algorithms is not None:
          # perform neural architecture search
          self._init_nas_module(
              num_features=num_features,
              num_classes=num_classes,
              feval=evaluator_list,
              device=self.runtime_device,
              loss=loss,
          )

          assert not isinstance(self._default_trainer, list) or len(
              self.nas_algorithms
          ) == len(self._default_trainer) - len(
              self.graph_model_list
          ), "length of default trainer should match total graph models and nas models passed"

          # perform nas and add them to model list
          # trainers for NAS models follow those of the plain graph models
          idx_trainer = len(self.graph_model_list)
          for algo, space, estimator in zip(
              self.nas_algorithms, self.nas_spaces, self.nas_estimators
          ):
              model = algo.search(space, self.dataset, estimator)
              # insert model into default trainer
              if isinstance(self._default_trainer, list):
                  train_name = self._default_trainer[idx_trainer]
                  idx_trainer += 1
              else:
                  train_name = self._default_trainer
              if isinstance(train_name, str):
                  trainer = TRAINER_DICT[train_name](
                      model=model,
                      num_features=num_features,
                      num_classes=num_classes,
                      loss=loss,
                      feval=evaluator_list,
                      device=self.runtime_device,
                      init=False,
                  )
              else:
                  # an already constructed trainer: plug the found model in
                  trainer = train_name
                  trainer.model = model
                  trainer.update_parameters(
                      num_classes=num_classes,
                      num_features=num_features,
                      loss=loss,
                      feval=evaluator_list,
                      device=self.runtime_device,
                  )
              self.graph_model_list.append(trainer)

      # train the models and tune hpo
      result_valid = []
      names = []
      num_models = len(self.graph_model_list)
      for idx, model in enumerate(self.graph_model_list):
          # split the remaining budget evenly over the models still to train
          time_for_each_model = (time_limit - time.time() + time_begin) / (
              num_models - idx
          )
          if self.hpo_module is None:
              model.initialize()
              model.train(self.dataset, True)
              optimized = model
          else:
              optimized, _ = self.hpo_module.optimize(
                  trainer=model, dataset=self.dataset, time_limit=time_for_each_model
              )
          # to save memory, all the trainer derived will be mapped to cpu
          optimized.to(torch.device("cpu"))
          name = str(optimized) + "_idx%d" % (idx)
          names.append(name)
          performance_on_valid, _ = optimized.get_valid_score(return_major=False)
          result_valid.append(optimized.get_valid_predict_proba().cpu().numpy())
          self.leaderboard.insert_model_performance(
              name,
              dict(zip(eval_names, performance_on_valid)),
          )
          self.trained_models[name] = optimized

      # fit the ensemble model
      if self.ensemble_module is not None:
          performance = self.ensemble_module.fit(
              result_valid,
              self.dataset[0].y[self.dataset[0].val_mask].cpu().numpy(),
              names,
              evaluator_list,
              n_classes=num_classes,
          )
          self.leaderboard.insert_model_performance(
              "ensemble",
              dict(zip(eval_names, performance)),
          )

      return self
  def fit_predict(
      self,
      dataset,
      time_limit=-1,
      inplace=False,
      train_split=None,
      val_split=None,
      balanced=True,
      evaluation_method="infer",
      use_ensemble=True,
      use_best=True,
      name=None,
      seed=None,
  ) -> np.ndarray:
      """
      Fit current solver on given dataset and return the predicted value.

      Parameters
      ----------
      dataset: torch_geometric.data.dataset.Dataset
          The dataset needed to fit on. This dataset must have only one graph.

      time_limit: int
          The time limit of the whole fit process (in seconds).
          If set below 0, will ignore time limit. Default ``-1``.

      inplace: bool
          Whether we process the given dataset in inplace manner. Default ``False``.
          Set it to True if you want to save memory by modifying the given dataset directly.

      train_split: float or int (Optional)
          The train ratio (in ``float``) or number (in ``int``) of dataset. If you want to
          use default train/val/test split in dataset, please set this to ``None``.
          Default ``None``.

      val_split: float or int (Optional)
          The validation ratio (in ``float``) or number (in ``int``) of dataset. If you want
          to use default train/val/test split in dataset, please set this to ``None``.
          Default ``None``.

      balanced: bool
          Whether to create the train/valid/test split in a balanced way.
          If set to ``True``, the train/valid will have the same number of different classes.
          Default ``True``.

      evaluation_method: (list of) str or autogl.module.train.evaluation
          A (list of) evaluation method for current solver. If ``infer``, will automatically
          determine. Default ``infer``.

      use_ensemble: bool
          Whether to use ensemble to do the predict. Default ``True``.

      use_best: bool
          Whether to use the best single model to do the predict. Will only be effective when
          ``use_ensemble`` is ``False``.
          Default ``True``.

      name: str or None
          The name of model used to predict. Will only be effective when ``use_ensemble`` and
          ``use_best`` both are ``False``.
          Default ``None``.

      seed: int (Optional)
          The random seed forwarded to ``fit()``. Default ``None``.

      Returns
      -------
      result: np.ndarray
          An array of shape ``(N,)``, where ``N`` is the number of test nodes. The prediction
          on given dataset.
      """
      self.fit(
          dataset=dataset,
          time_limit=time_limit,
          inplace=inplace,
          train_split=train_split,
          val_split=val_split,
          balanced=balanced,
          evaluation_method=evaluation_method,
          seed=seed,
      )
      # when fit() processed the dataset in place it is already transformed,
      # so ``inplaced`` mirrors ``inplace`` to avoid a second transform
      return self.predict(
          dataset=dataset,
          inplaced=inplace,
          inplace=inplace,
          use_ensemble=use_ensemble,
          use_best=use_best,
          name=name,
      )
  def predict_proba(
      self,
      dataset=None,
      inplaced=False,
      inplace=False,
      use_ensemble=True,
      use_best=True,
      name=None,
      mask="test",
  ) -> np.ndarray:
      """
      Predict the node probability.

      Parameters
      ----------
      dataset: torch_geometric.data.dataset.Dataset or None
          The dataset needed to predict. If ``None``, will use the processed dataset passed
          to ``fit()`` instead. Default ``None``.

      inplaced: bool
          Whether the given dataset is processed. Only be effective when ``dataset``
          is not ``None``. If you pass the dataset to ``fit()`` with ``inplace=True``, and
          you pass the dataset again to this method, you should set this argument to ``True``.
          Otherwise ``False``. Default ``False``.

      inplace: bool
          Whether we process the given dataset in inplace manner. Default ``False``. Set it to
          True if you want to save memory by modifying the given dataset directly.

      use_ensemble: bool
          Whether to use ensemble to do the predict. Default ``True``.

      use_best: bool
          Whether to use the best single model to do the predict. Will only be effective when
          ``use_ensemble`` is ``False``. Default ``True``.

      name: str or None
          The name of model used to predict. Will only be effective when ``use_ensemble`` and
          ``use_best`` both are ``False``. Default ``None``.

      mask: str
          The data split to give prediction on. Default ``test``.

      Returns
      -------
      result: np.ndarray
          An array of shape ``(N,C,)``, where ``N`` is the number of test nodes and ``C`` is
          the number of classes. The prediction on given dataset.

      Raises
      ------
      ValueError
          If the ensemble is requested by ``name`` while no ensemble module is
          set, or if neither ensemble, best model nor a model name is selected.
      """
      if dataset is None:
          dataset = self.dataset
          assert (
              dataset is not None
          ), "Please execute fit() first before predicting on remembered dataset"
      elif not inplaced and self.feature_module is not None:
          dataset = self.feature_module.transform(dataset, inplace=inplace)

      if use_ensemble:
          LOGGER.info("Ensemble argument on, will try using ensemble model.")

      if not use_ensemble and use_best:
          LOGGER.info(
              "Ensemble argument off and best argument on, will try using best model."
          )

      ensemble_by_name = not use_best and name == "ensemble"
      if (use_ensemble and self.ensemble_module is not None) or ensemble_by_name:
          if self.ensemble_module is None:
              # reachable only through name == "ensemble"; fail with a clear
              # message instead of dereferencing ``None`` below
              raise ValueError(
                  "Cannot use the ensemble model for prediction because no "
                  "ensemble module is given."
              )
          # we need to get all the prediction of every model trained
          predict_result = []
          names = []
          for model_name in self.trained_models:
              predict_result.append(
                  self._predict_proba_by_name(dataset, model_name, mask)
              )
              names.append(model_name)
          return self.ensemble_module.ensemble(predict_result, names)

      if use_ensemble and self.ensemble_module is None:
          LOGGER.warning(
              "Cannot use ensemble because no ensebmle module is given."
              " Will use best model instead."
          )

      if use_best or (use_ensemble and self.ensemble_module is None):
          # just return the best model we have found
          name = self.leaderboard.get_best_model()
          return self._predict_proba_by_name(dataset, name, mask)

      if name is not None:
          # return model performance by name
          return self._predict_proba_by_name(dataset, name, mask)

      LOGGER.error(
          "No model name is given while ensemble and best arguments are off."
      )
      raise ValueError(
          "You need to specify a model name if you do not want use ensemble and best model."
      )
  565. def _predict_proba_by_name(self, dataset, name, mask="test"):
  566. self.trained_models[name].to(self.runtime_device)
  567. predicted = (
  568. self.trained_models[name].predict_proba(dataset, mask=mask).cpu().numpy()
  569. )
  570. self.trained_models[name].to(torch.device("cpu"))
  571. return predicted
  def predict(
      self,
      dataset=None,
      inplaced=False,
      inplace=False,
      use_ensemble=True,
      use_best=True,
      name=None,
      mask="test",
  ) -> np.ndarray:
      """
      Predict the node class number.

      All arguments are handed to ``predict_proba()``; the index of the
      largest value along axis 1 of its result is returned.

      Parameters
      ----------
      dataset: torch_geometric.data.dataset.Dataset or None
          The dataset to predict on. ``None`` means the processed dataset
          remembered from ``fit()``. Default ``None``.

      inplaced: bool
          Whether the given dataset is already processed. Only matters when
          ``dataset`` is not ``None``. Default ``False``.

      inplace: bool
          Whether the given dataset is processed in place. Default ``False``.

      use_ensemble: bool
          Whether to predict with the ensemble. Default ``True``.

      use_best: bool
          Whether to predict with the best single model. Only matters when
          ``use_ensemble`` is ``False``. Default ``True``.

      name: str or None
          The name of the model used to predict. Only matters when both
          ``use_ensemble`` and ``use_best`` are ``False``. Default ``None``.

      mask: str
          The data split to give prediction on. Default ``test``.

      Returns
      -------
      result: np.ndarray
          An array of shape ``(N,)``, where ``N`` is the number of test nodes.
          The prediction on given dataset.
      """
      class_scores = self.predict_proba(
          dataset=dataset,
          inplaced=inplaced,
          inplace=inplace,
          use_ensemble=use_ensemble,
          use_best=use_best,
          name=name,
          mask=mask,
      )
      predicted_labels = np.argmax(class_scores, axis=1)
      return predicted_labels
  617. @classmethod
  618. def from_config(cls, path_or_dict, filetype="auto") -> "AutoNodeClassifier":
  619. """
  620. Load solver from config file.
  621. You can use this function to directly load a solver from predefined config dict
  622. or config file path. Currently, only support file type of ``json`` or ``yaml``,
  623. if you pass a path.
  624. Parameters
  625. ----------
  626. path_or_dict: str or dict
  627. The path to the config file or the config dictionary object
  628. filetype: str
  629. The filetype the given file if the path is specified. Currently only support
  630. ``json`` or ``yaml``. You can set to ``auto`` to automatically detect the file
  631. type (from file name). Default ``auto``.
  632. Returns
  633. -------
  634. solver: autogl.solver.AutoGraphClassifier
  635. The solver that is created from given file or dictionary.
  636. """
  637. assert filetype in ["auto", "yaml", "json"], (
  638. "currently only support yaml file or json file type, but get type "
  639. + filetype
  640. )
  641. if isinstance(path_or_dict, str):
  642. if filetype == "auto":
  643. if path_or_dict.endswith(".yaml") or path_or_dict.endswith(".yml"):
  644. filetype = "yaml"
  645. elif path_or_dict.endswith(".json"):
  646. filetype = "json"
  647. else:
  648. LOGGER.error(
  649. "cannot parse the type of the given file name, "
  650. "please manually set the file type"
  651. )
  652. raise ValueError(
  653. "cannot parse the type of the given file name, "
  654. "please manually set the file type"
  655. )
  656. if filetype == "yaml":
  657. path_or_dict = yaml.load(
  658. open(path_or_dict, "r").read(), Loader=yaml.FullLoader
  659. )
  660. else:
  661. path_or_dict = json.load(open(path_or_dict, "r"))
  662. path_or_dict = deepcopy(path_or_dict)
  663. solver = cls(None, [], None, None)
  664. fe_list = path_or_dict.pop("feature", None)
  665. if fe_list is not None:
  666. fe_list_ele = []
  667. for feature_engineer in fe_list:
  668. name = feature_engineer.pop("name")
  669. if name is not None:
  670. fe_list_ele.append(FEATURE_DICT[name](**feature_engineer))
  671. if fe_list_ele != []:
  672. solver.set_feature_module(fe_list_ele)
  673. models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}])
  674. model_hp_space = [
  675. _parse_hp_space(model.pop("hp_space", None)) for model in models
  676. ]
  677. model_list = [
  678. _initialize_single_model(model.pop("name"), model) for model in models
  679. ]
  680. trainer = path_or_dict.pop("trainer", None)
  681. default_trainer = "NodeClassificationFull"
  682. trainer_space = None
  683. if isinstance(trainer, dict):
  684. # global default
  685. default_trainer = trainer.pop("name", "NodeClassificationFull")
  686. trainer_space = _parse_hp_space(trainer.pop("hp_space", None))
  687. default_kwargs = {"num_features": None, "num_classes": None}
  688. default_kwargs.update(trainer)
  689. default_kwargs["init"] = False
  690. for i in range(len(model_list)):
  691. model = model_list[i]
  692. trainer_wrap = TRAINER_DICT[default_trainer](
  693. model=model, **default_kwargs
  694. )
  695. model_list[i] = trainer_wrap
  696. elif isinstance(trainer, list):
  697. # sequential trainer definition
  698. assert len(trainer) == len(
  699. model_list
  700. ), "The number of trainer and model does not match"
  701. trainer_space = []
  702. for i in range(len(model_list)):
  703. train, model = trainer[i], model_list[i]
  704. default_trainer = train.pop("name", "NodeClassificationFull")
  705. trainer_space.append(_parse_hp_space(train.pop("hp_space", None)))
  706. default_kwargs = {"num_features": None, "num_classes": None}
  707. default_kwargs.update(train)
  708. default_kwargs["init"] = False
  709. trainer_wrap = TRAINER_DICT[default_trainer](
  710. model=model, **default_kwargs
  711. )
  712. model_list[i] = trainer_wrap
  713. solver.set_graph_models(
  714. model_list, default_trainer, trainer_space, model_hp_space
  715. )
  716. hpo_dict = path_or_dict.pop("hpo", {"name": "anneal"})
  717. if hpo_dict is not None:
  718. name = hpo_dict.pop("name")
  719. solver.set_hpo_module(name, **hpo_dict)
  720. ensemble_dict = path_or_dict.pop("ensemble", {"name": "voting"})
  721. if ensemble_dict is not None:
  722. name = ensemble_dict.pop("name")
  723. solver.set_ensemble_module(name, **ensemble_dict)
  724. nas_dict = path_or_dict.pop("nas", None)
  725. if nas_dict is not None:
  726. keys: set = set(nas_dict.keys())
  727. needed = {"space", "algorithm", "estimator"}
  728. if keys != needed:
  729. LOGGER.error("Key mismatch, we need %s, you give %s", needed, keys)
  730. raise KeyError("Key mismatch, we need %s, you give %s" % (needed, keys))
  731. spaces, algorithms, estimators = [], [], []
  732. for container, indexer, k in zip(
  733. [spaces, algorithms, estimators],
  734. [NAS_SPACE_DICT, NAS_ALGO_DICT, NAS_ESTIMATOR_DICT],
  735. ["space", "algorithm", "estimator"],
  736. ):
  737. configs = nas_dict[k]
  738. if isinstance(configs, list):
  739. for item in configs:
  740. container.append(indexer[item.pop("name")](**item))
  741. else:
  742. container.append(indexer[configs.pop("name")](**configs))
  743. solver.set_nas_module(algorithms, spaces, estimators)
  744. return solver