You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

node_classifier.py 36 kB

5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897
  1. """
  2. Auto Classifier for Node Classification
  3. """
  4. import time
  5. import json
  6. from copy import deepcopy
  7. from typing import Sequence
  8. import torch
  9. import numpy as np
  10. import yaml
  11. from .base import BaseClassifier
  12. from ..base import _parse_hp_space, _initialize_single_model, _parse_model_hp
  13. from ...module.feature import FEATURE_DICT
  14. from ...module.train import TRAINER_DICT, BaseNodeClassificationTrainer
  15. from ...module.train import get_feval
  16. from ...module.nas.space import NAS_SPACE_DICT
  17. from ...module.nas.algorithm import NAS_ALGO_DICT
  18. from ...module.nas.estimator import NAS_ESTIMATOR_DICT, BaseEstimator
  19. from ..utils import LeaderBoard, get_graph_from_dataset, get_graph_labels, get_graph_masks, get_graph_node_features, get_graph_node_number, set_seed, convert_dataset
  20. from ...datasets import utils
  21. from ...utils import get_logger
  22. LOGGER = get_logger("NodeClassifier")
class AutoNodeClassifier(BaseClassifier):
    """
    Auto Multi-class Graph Node Classifier.

    Used to automatically solve the node classification problems.

    Parameters
    ----------
    feature_module: autogl.module.feature.BaseFeatureEngineer or str or None
        The (name of) auto feature engineer used to process the given dataset.
        Default ``None`` (feature engineering disabled).
    graph_models: Sequence of models
        Models can be ``str``, ``autogl.module.model.BaseAutoModel``,
        ``autogl.module.model.encoders.BaseEncoderMaintainer`` or a tuple of (encoder, decoder)
        if need to specify both encoder and decoder. Encoder can be ``str`` or
        ``autogl.module.model.encoders.BaseEncoderMaintainer``, and decoder can be ``str``
        or ``autogl.module.model.decoders.BaseDecoderMaintainer``.
    nas_algorithms: (list of) autogl.module.nas.algorithm.BaseNAS or str (Optional)
        The (name of) nas algorithms used. Default ``None``.
    nas_spaces: (list of) autogl.module.nas.space.BaseSpace or str (Optional)
        The (name of) nas spaces used. Default ``None``.
    nas_estimators: (list of) autogl.module.nas.estimator.BaseEstimator or str (Optional)
        The (name of) nas estimators used. Default ``None``.
    hpo_module: autogl.module.hpo.BaseHPOptimizer or str or None
        The (name of) hpo module used to search for best hyper parameters. Default ``anneal``.
        Disable hpo by setting it to ``None``.
    ensemble_module: autogl.module.ensemble.BaseEnsembler or str or None
        The (name of) ensemble module used to ensemble the multi-models found. Default ``voting``.
        Disable ensemble by setting it to ``None``.
    max_evals: int (Optional)
        If given, will set the number eval times the hpo module will use.
        Only be effective when hpo_module is ``str``. Default ``50``.
    default_trainer: str (Optional)
        The (name of) the trainer used in this solver. Default to ``NodeClassificationFull``.
    trainer_hp_space: list of dict (Optional)
        trainer hp space or list of trainer hp spaces configuration.
        If a single trainer hp is given, will specify the hp space of trainer for every model.
        If a list of trainer hp is given, will specify every model with corresponding
        trainer hp space. Default ``None``.
    model_hp_spaces: list of list of dict (Optional)
        model hp space configuration.
        If given, will specify every hp space of every passed model. Default ``None``.
        If the encoder(-decoder) is passed, the space should be a dict containing keys "encoder"
        and "decoder", specifying the detailed encoder decoder hp spaces.
    size: int (Optional)
        The max models ensemble module will use. Default ``4``.
    device: torch.device or str
        The device where model will be running on. If set to ``auto``, will use gpu when available.
        You can also specify the device by directly giving ``gpu`` or ``cuda:0``, etc.
        Default ``auto``.
    """

    def __init__(
        self,
        feature_module=None,
        graph_models=("gat", "gcn"),  # TODO: support a list of model
        nas_algorithms=None,
        nas_spaces=None,
        nas_estimators=None,
        hpo_module="anneal",
        ensemble_module="voting",
        max_evals=50,
        default_trainer="NodeClassificationFull",
        trainer_hp_space=None,
        model_hp_spaces=None,
        size=4,
        device="auto",
    ):
        # All construction logic lives in BaseClassifier; this class only adds
        # node-classification-specific fitting/prediction behaviour.
        super().__init__(
            feature_module=feature_module,
            graph_models=graph_models,
            nas_algorithms=nas_algorithms,
            nas_spaces=nas_spaces,
            nas_estimators=nas_estimators,
            hpo_module=hpo_module,
            ensemble_module=ensemble_module,
            max_evals=max_evals,
            default_trainer=default_trainer,
            trainer_hp_space=trainer_hp_space,
            model_hp_spaces=model_hp_spaces,
            size=size,
            device=device,
        )
        # data to be kept when fit; reused by predict()/evaluate() when the
        # caller passes dataset=None.
        self.dataset = None
    def _init_graph_module(
        self, graph_models, num_classes, num_features, feval, device, loss
    ) -> "AutoNodeClassifier":
        """
        Build one trainer per passed model and store them in ``self.graph_model_list``.

        Each entry of ``graph_models`` may be a ready trainer, a single model
        (encoder), or an (encoder, decoder) tuple; non-trainer entries are wrapped
        in the configured default trainer. Returns ``self`` for chaining.
        """
        # load graph network module
        self.graph_model_list = []
        for i, model in enumerate(graph_models):
            # init the trainer
            if not isinstance(model, BaseNodeClassificationTrainer):
                # pick the per-model default trainer when a list was configured
                trainer = (
                    self._default_trainer if not isinstance(self._default_trainer, (tuple, list))
                    else self._default_trainer[i]
                )
                if isinstance(trainer, str):
                    if trainer not in TRAINER_DICT:
                        raise KeyError(f"Does not support trainer {trainer}")
                    trainer = TRAINER_DICT[trainer]()
                if isinstance(model, (tuple, list)):
                    # (encoder, decoder) pair given explicitly
                    trainer.encoder = model[0]
                    trainer.decoder = model[1]
                else:
                    trainer.encoder = model
            else:
                # already a trainer: use it as-is
                trainer = model
            # set model hp space
            if self._model_hp_spaces is not None:
                if self._model_hp_spaces[i] is not None:
                    if isinstance(self._model_hp_spaces[i], dict):
                        # dict form carries separate encoder/decoder spaces
                        encoder_hp_space = self._model_hp_spaces[i].get('encoder', None)
                        decoder_hp_space = self._model_hp_spaces[i].get('decoder', None)
                    else:
                        encoder_hp_space = self._model_hp_spaces[i]
                        decoder_hp_space = None
                    if encoder_hp_space is not None:
                        trainer.encoder.hyper_parameter_space = encoder_hp_space
                    if decoder_hp_space is not None:
                        trainer.decoder.hyper_parameter_space = decoder_hp_space
            # set trainer hp space
            if self._trainer_hp_space is not None:
                # a list-of-lists means one space per model; otherwise shared
                if isinstance(self._trainer_hp_space[0], list):
                    current_hp_for_trainer = self._trainer_hp_space[i]
                else:
                    current_hp_for_trainer = self._trainer_hp_space
                trainer.hyper_parameter_space = current_hp_for_trainer
            trainer.num_features = num_features
            trainer.num_classes = num_classes
            trainer.loss = loss
            trainer.feval = feval
            trainer.to(device)
            self.graph_model_list.append(trainer)
        return self
  156. def _init_nas_module(self, num_features, num_classes, feval, device, loss):
  157. for algo, space, estimator in zip(
  158. self.nas_algorithms, self.nas_spaces, self.nas_estimators
  159. ):
  160. estimator: BaseEstimator
  161. algo.to(device)
  162. space.instantiate(input_dim=num_features, output_dim=num_classes)
  163. estimator.setEvaluation(feval)
  164. estimator.setLossFunction(loss)
    # pylint: disable=arguments-differ
    def fit(
        self,
        dataset,
        time_limit=-1,
        inplace=False,
        train_split=None,
        val_split=None,
        balanced=True,
        evaluation_method="infer",
        seed=None,
    ) -> "AutoNodeClassifier":
        """
        Fit current solver on given dataset.

        Parameters
        ----------
        dataset: autogl.data.Dataset
            The dataset needed to fit on. This dataset must have only one graph.
        time_limit: int
            The time limit of the whole fit process (in seconds). If set below 0,
            will ignore time limit. Default ``-1``.
        inplace: bool
            Whether we process the given dataset in inplace manner. Default ``False``.
            Set it to True if you want to save memory by modifying the given dataset directly.
        train_split: float or int (Optional)
            The train ratio (in ``float``) or number (in ``int``) of dataset. If you want to
            use default train/val/test split in dataset, please set this to ``None``.
            Default ``None``.
        val_split: float or int (Optional)
            The validation ratio (in ``float``) or number (in ``int``) of dataset. If you want
            to use default train/val/test split in dataset, please set this to ``None``.
            Default ``None``.
        balanced: bool
            Whether to create the train/valid/test split in a balanced way.
            If set to ``True``, the train/valid will have the same number of different classes.
            Default ``True``.
        evaluation_method: (list of) str or autogl.module.train.evaluation
            A (list of) evaluation method for current solver. If ``infer``, will automatically
            determine. Default ``infer``.
        seed: int (Optional)
            The random seed. If set to ``None``, will run everything at random.
            Default ``None``.

        Returns
        -------
        self: autogl.solver.AutoNodeClassifier
            A reference of current solver.
        """
        set_seed(seed)
        if time_limit < 0:
            # "no limit" is modelled as one full day
            time_limit = 3600 * 24
        time_begin = time.time()
        graph_data = get_graph_from_dataset(dataset, 0)
        all_labels = get_graph_labels(graph_data)
        # labels are assumed to be 0-based consecutive class ids
        num_classes = all_labels.max().item() + 1
        # initialize leaderboard
        if evaluation_method == "infer":
            if hasattr(dataset, "metric"):
                evaluation_method = [dataset.metric]
            else:
                # binary tasks default to AUC, multi-class to accuracy
                num_of_label = num_classes
                if num_of_label == 2:
                    evaluation_method = ["auc"]
                else:
                    evaluation_method = ["acc"]
        assert isinstance(evaluation_method, list)
        evaluator_list = get_feval(evaluation_method)
        self.leaderboard = LeaderBoard(
            [e.get_eval_name() for e in evaluator_list],
            {e.get_eval_name(): e.is_higher_better() for e in evaluator_list},
        )
        # set up the dataset: re-split only when both ratios/counts are given
        if train_split is not None and val_split is not None:
            size = get_graph_node_number(graph_data)
            if balanced:
                # convert ratios (values <= 1) to absolute node counts
                train_split = (
                    train_split if train_split > 1 else int(train_split * size)
                )
                val_split = val_split if val_split > 1 else int(val_split * size)
                utils.random_splits_mask_class(
                    dataset,
                    num_train_per_class=train_split // num_classes,
                    num_val_per_class=val_split // num_classes,
                    seed=seed,
                )
            else:
                # convert absolute counts (values >= 1) to ratios
                train_split = train_split if train_split < 1 else train_split / size
                val_split = val_split if val_split < 1 else val_split / size
                utils.random_splits_mask(
                    dataset, train_ratio=train_split, val_ratio=val_split
                )
        else:
            assert get_graph_masks(graph_data, 'train') is not None and get_graph_masks(graph_data, 'val') is not None, (
                "The dataset has no default train/val split! Please manually pass "
                "train and val ratio."
            )
            LOGGER.info("Use the default train/val/test ratio in given dataset")
        # feature engineering
        if self.feature_module is not None:
            dataset = self.feature_module.fit_transform(dataset, inplace=inplace)
        self.dataset = dataset
        # check whether the dataset has features.
        # currently we only support graph classification with features.
        graph_data = get_graph_from_dataset(dataset, 0)
        feat = get_graph_node_features(graph_data)
        assert feat is not None, (
            "Does not support fit on non node-feature dataset!"
            " Please add node features to dataset or specify feature engineers that generate"
            " node features."
        )
        num_features = feat.size(-1)
        # initialize graph networks
        self._init_graph_module(
            self.gml,
            num_features=num_features,
            num_classes=num_classes,
            feval=evaluator_list,
            device=self.runtime_device,
            loss="nll_loss" if not hasattr(dataset, "loss") else self.dataset.loss,
        )
        if self.nas_algorithms is not None:
            # perform neural architecture search
            self._init_nas_module(
                num_features=num_features,
                num_classes=num_classes,
                feval=evaluator_list,
                device=self.runtime_device,
                loss="nll_loss" if not hasattr(dataset, "loss") else dataset.loss,
            )
            # when trainers are given as a list, the extra entries (beyond the
            # plain graph models) must cover the NAS-discovered models
            assert not isinstance(self._default_trainer, list) or len(
                self.nas_algorithms
            ) == len(self._default_trainer) - len(
                self.graph_model_list
            ), "length of default trainer should match total graph models and nas models passed"
            # perform nas and add them to model list
            idx_trainer = len(self.graph_model_list)
            for algo, space, estimator in zip(
                self.nas_algorithms, self.nas_spaces, self.nas_estimators
            ):
                model = algo.search(space, convert_dataset(self.dataset), estimator)
                # insert model into default trainer
                if isinstance(self._default_trainer, list):
                    train_name = self._default_trainer[idx_trainer]
                    idx_trainer += 1
                else:
                    train_name = self._default_trainer
                if isinstance(train_name, str):
                    trainer = TRAINER_DICT[train_name](
                        model=model,
                        num_features=num_features,
                        num_classes=num_classes,
                        loss="nll_loss"
                        if not hasattr(dataset, "loss")
                        else dataset.loss,
                        feval=evaluator_list,
                        device=self.runtime_device,
                        init=False,
                    )
                elif isinstance(train_name, BaseNodeClassificationTrainer):
                    trainer = train_name
                    trainer.encoder = model
                    trainer.num_features = num_features
                    trainer.num_classes = num_classes
                    trainer.loss = "nll_loss" if not hasattr(dataset, "loss") else dataset.loss
                    trainer.feval = evaluator_list
                    trainer.to(self.runtime_device)
                else:
                    raise ValueError()
                self.graph_model_list.append(trainer)
        # train the models and tune hpo
        result_valid = []
        names = []
        for idx, model in enumerate(self.graph_model_list):
            model: BaseNodeClassificationTrainer
            # split the remaining time budget evenly over the remaining models
            time_for_each_model = (time_limit - time.time() + time_begin) / (
                len(self.graph_model_list) - idx
            )
            if self.hpo_module is None:
                model.initialize()
                model.train(convert_dataset(self.dataset), True)
                optimized = model
            else:
                optimized, _ = self.hpo_module.optimize(
                    trainer=model, dataset=convert_dataset(self.dataset), time_limit=time_for_each_model
                )
            # to save memory, all the trainer derived will be mapped to cpu
            optimized.to(torch.device("cpu"))
            name = str(optimized) + "_idx%d" % (idx)
            names.append(name)
            performance_on_valid, _ = optimized.get_valid_score(return_major=False)
            result_valid.append(optimized.get_valid_predict_proba().cpu().numpy())
            self.leaderboard.insert_model_performance(
                name,
                dict(
                    zip(
                        [e.get_eval_name() for e in evaluator_list],
                        performance_on_valid,
                    )
                ),
            )
            self.trained_models[name] = optimized
        # fit the ensemble model on the validation predictions collected above
        if self.ensemble_module is not None:
            performance = self.ensemble_module.fit(
                result_valid,
                all_labels[get_graph_masks(graph_data, 'val')].cpu().numpy(),
                names,
                evaluator_list,
                n_classes=num_classes,
            )
            self.leaderboard.insert_model_performance(
                "ensemble",
                dict(zip([e.get_eval_name() for e in evaluator_list], performance)),
            )
        return self
  379. def fit_predict(
  380. self,
  381. dataset,
  382. time_limit=-1,
  383. inplace=False,
  384. train_split=None,
  385. val_split=None,
  386. balanced=True,
  387. evaluation_method="infer",
  388. use_ensemble=True,
  389. use_best=True,
  390. name=None,
  391. ) -> np.ndarray:
  392. """
  393. Fit current solver on given dataset and return the predicted value.
  394. Parameters
  395. ----------
  396. dataset: torch_geometric.data.dataset.Dataset
  397. The dataset needed to fit on. This dataset must have only one graph.
  398. time_limit: int
  399. The time limit of the whole fit process (in seconds).
  400. If set below 0, will ignore time limit. Default ``-1``.
  401. inplace: bool
  402. Whether we process the given dataset in inplace manner. Default ``False``.
  403. Set it to True if you want to save memory by modifying the given dataset directly.
  404. train_split: float or int (Optional)
  405. The train ratio (in ``float``) or number (in ``int``) of dataset. If you want to
  406. use default train/val/test split in dataset, please set this to ``None``.
  407. Default ``None``.
  408. val_split: float or int (Optional)
  409. The validation ratio (in ``float``) or number (in ``int``) of dataset. If you want
  410. to use default train/val/test split in dataset, please set this to ``None``.
  411. Default ``None``.
  412. balanced: bool
  413. Wether to create the train/valid/test split in a balanced way.
  414. If set to ``True``, the train/valid will have the same number of different classes.
  415. Default ``False``.
  416. evaluation_method: (list of) str or autogl.module.train.evaluation
  417. A (list of) evaluation method for current solver. If ``infer``, will automatically
  418. determine. Default ``infer``.
  419. use_ensemble: bool
  420. Whether to use ensemble to do the predict. Default ``True``.
  421. use_best: bool
  422. Whether to use the best single model to do the predict. Will only be effective when
  423. ``use_ensemble`` is ``False``.
  424. Default ``True``.
  425. name: str or None
  426. The name of model used to predict. Will only be effective when ``use_ensemble`` and
  427. ``use_best`` both are ``False``.
  428. Default ``None``.
  429. Returns
  430. -------
  431. result: np.ndarray
  432. An array of shape ``(N,)``, where ``N`` is the number of test nodes. The prediction
  433. on given dataset.
  434. """
  435. self.fit(
  436. dataset=dataset,
  437. time_limit=time_limit,
  438. inplace=inplace,
  439. train_split=train_split,
  440. val_split=val_split,
  441. balanced=balanced,
  442. evaluation_method=evaluation_method,
  443. )
  444. return self.predict(
  445. dataset=dataset,
  446. inplaced=inplace,
  447. inplace=inplace,
  448. use_ensemble=use_ensemble,
  449. use_best=use_best,
  450. name=name,
  451. )
  452. def predict_proba(
  453. self,
  454. dataset=None,
  455. inplaced=False,
  456. inplace=False,
  457. use_ensemble=True,
  458. use_best=True,
  459. name=None,
  460. mask="test",
  461. ) -> np.ndarray:
  462. """
  463. Predict the node probability.
  464. Parameters
  465. ----------
  466. dataset: torch_geometric.data.dataset.Dataset or None
  467. The dataset needed to predict. If ``None``, will use the processed dataset passed
  468. to ``fit()`` instead. Default ``None``.
  469. inplaced: bool
  470. Whether the given dataset is processed. Only be effective when ``dataset``
  471. is not ``None``. If you pass the dataset to ``fit()`` with ``inplace=True``, and
  472. you pass the dataset again to this method, you should set this argument to ``True``.
  473. Otherwise ``False``. Default ``False``.
  474. inplace: bool
  475. Whether we process the given dataset in inplace manner. Default ``False``. Set it to
  476. True if you want to save memory by modifying the given dataset directly.
  477. use_ensemble: bool
  478. Whether to use ensemble to do the predict. Default ``True``.
  479. use_best: bool
  480. Whether to use the best single model to do the predict. Will only be effective when
  481. ``use_ensemble`` is ``False``. Default ``True``.
  482. name: str or None
  483. The name of model used to predict. Will only be effective when ``use_ensemble`` and
  484. ``use_best`` both are ``False``. Default ``None``.
  485. mask: str
  486. The data split to give prediction on. Default ``test``.
  487. Returns
  488. -------
  489. result: np.ndarray
  490. An array of shape ``(N,C,)``, where ``N`` is the number of test nodes and ``C`` is
  491. the number of classes. The prediction on given dataset.
  492. """
  493. if dataset is None:
  494. dataset = self.dataset
  495. assert dataset is not None, (
  496. "Please execute fit() first before" " predicting on remembered dataset"
  497. )
  498. elif not inplaced and self.feature_module is not None:
  499. dataset = self.feature_module.transform(dataset, inplace=inplace)
  500. if use_ensemble:
  501. LOGGER.info("Ensemble argument on, will try using ensemble model.")
  502. if not use_ensemble and use_best:
  503. LOGGER.info(
  504. "Ensemble argument off and best argument on, will try using best model."
  505. )
  506. if (use_ensemble and self.ensemble_module is not None) or (
  507. not use_best and name == "ensemble"
  508. ):
  509. # we need to get all the prediction of every model trained
  510. predict_result = []
  511. names = []
  512. for model_name in self.trained_models:
  513. predict_result.append(
  514. self._predict_proba_by_name(dataset, model_name, mask)
  515. )
  516. names.append(model_name)
  517. return self.ensemble_module.ensemble(predict_result, names)
  518. if use_ensemble and self.ensemble_module is None:
  519. LOGGER.warning(
  520. "Cannot use ensemble because no ensebmle module is given."
  521. " Will use best model instead."
  522. )
  523. if use_best or (use_ensemble and self.ensemble_module is None):
  524. # just return the best model we have found
  525. name = self.leaderboard.get_best_model()
  526. return self._predict_proba_by_name(dataset, name, mask)
  527. if name is not None:
  528. # return model performance by name
  529. return self._predict_proba_by_name(dataset, name, mask)
  530. LOGGER.error(
  531. "No model name is given while ensemble and best arguments are off."
  532. )
  533. raise ValueError(
  534. "You need to specify a model name if you do not want use ensemble and best model."
  535. )
  536. def _predict_proba_by_name(self, dataset, name, mask="test"):
  537. self.trained_models[name].to(self.runtime_device)
  538. predicted = (
  539. self.trained_models[name].predict_proba(convert_dataset(dataset), mask=mask).cpu().numpy()
  540. )
  541. self.trained_models[name].to(torch.device("cpu"))
  542. return predicted
  543. def predict(
  544. self,
  545. dataset=None,
  546. inplaced=False,
  547. inplace=False,
  548. use_ensemble=True,
  549. use_best=True,
  550. name=None,
  551. mask="test",
  552. ) -> np.ndarray:
  553. """
  554. Predict the node class number.
  555. Parameters
  556. ----------
  557. dataset: torch_geometric.data.dataset.Dataset or None
  558. The dataset needed to predict. If ``None``, will use the processed dataset passed
  559. to ``fit()`` instead. Default ``None``.
  560. inplaced: bool
  561. Whether the given dataset is processed. Only be effective when ``dataset``
  562. is not ``None``. If you pass the dataset to ``fit()`` with ``inplace=True``,
  563. and you pass the dataset again to this method, you should set this argument
  564. to ``True``. Otherwise ``False``. Default ``False``.
  565. inplace: bool
  566. Whether we process the given dataset in inplace manner. Default ``False``.
  567. Set it to True if you want to save memory by modifying the given dataset directly.
  568. use_ensemble: bool
  569. Whether to use ensemble to do the predict. Default ``True``.
  570. use_best: bool
  571. Whether to use the best single model to do the predict. Will only be effective
  572. when ``use_ensemble`` is ``False``. Default ``True``.
  573. name: str or None
  574. The name of model used to predict. Will only be effective when ``use_ensemble``
  575. and ``use_best`` both are ``False``. Default ``None``.
  576. mask: str
  577. The data split to give prediction on. Default ``test``.
  578. Returns
  579. -------
  580. result: np.ndarray
  581. An array of shape ``(N,)``, where ``N`` is the number of test nodes.
  582. The prediction on given dataset.
  583. """
  584. proba = self.predict_proba(
  585. dataset, inplaced, inplace, use_ensemble, use_best, name, mask
  586. )
  587. return np.argmax(proba, axis=1)
  588. def evaluate(self, dataset=None,
  589. inplaced=False,
  590. inplace=False,
  591. use_ensemble=True,
  592. use_best=True,
  593. name=None,
  594. mask="test",
  595. label=None,
  596. metric="acc"
  597. ):
  598. """
  599. Evaluate the given dataset.
  600. Parameters
  601. ----------
  602. dataset: torch_geometric.data.dataset.Dataset or None
  603. The dataset needed to predict. If ``None``, will use the processed dataset passed
  604. to ``fit()`` instead. Default ``None``.
  605. inplaced: bool
  606. Whether the given dataset is processed. Only be effective when ``dataset``
  607. is not ``None``. If you pass the dataset to ``fit()`` with ``inplace=True``, and
  608. you pass the dataset again to this method, you should set this argument to ``True``.
  609. Otherwise ``False``. Default ``False``.
  610. inplace: bool
  611. Whether we process the given dataset in inplace manner. Default ``False``. Set it to
  612. True if you want to save memory by modifying the given dataset directly.
  613. use_ensemble: bool
  614. Whether to use ensemble to do the predict. Default ``True``.
  615. use_best: bool
  616. Whether to use the best single model to do the predict. Will only be effective when
  617. ``use_ensemble`` is ``False``. Default ``True``.
  618. name: str or None
  619. The name of model used to predict. Will only be effective when ``use_ensemble`` and
  620. ``use_best`` both are ``False``. Default ``None``.
  621. mask: str
  622. The data split to give prediction on. Default ``test``.
  623. label: torch.Tensor (Optional)
  624. The groud truth label of the given predicted dataset split. If not passed, will extract
  625. labels from the input dataset.
  626. metric: str
  627. The metric to be used for evaluating the model. Default ``acc``.
  628. Returns
  629. -------
  630. score(s): (list of) evaluation scores
  631. the evaluation results according to the evaluator passed.
  632. """
  633. predicted = self.predict_proba(dataset, inplaced, inplace, use_ensemble, use_best, name, mask)
  634. if dataset is None:
  635. dataset = self.dataset
  636. if label is None:
  637. label = get_graph_labels(dataset[0])[get_graph_masks(dataset[0], mask)].cpu().numpy()
  638. evaluator = get_feval(metric)
  639. if isinstance(evaluator, Sequence):
  640. return [evals.evaluate(predicted, label) for evals in evaluator]
  641. return evaluator.evaluate(predicted, label)
  642. @classmethod
  643. def from_config(cls, path_or_dict, filetype="auto") -> "AutoNodeClassifier":
  644. """
  645. Load solver from config file.
  646. You can use this function to directly load a solver from predefined config dict
  647. or config file path. Currently, only support file type of ``json`` or ``yaml``,
  648. if you pass a path.
  649. Parameters
  650. ----------
  651. path_or_dict: str or dict
  652. The path to the config file or the config dictionary object
  653. filetype: str
  654. The filetype the given file if the path is specified. Currently only support
  655. ``json`` or ``yaml``. You can set to ``auto`` to automatically detect the file
  656. type (from file name). Default ``auto``.
  657. Returns
  658. -------
  659. solver: autogl.solver.AutoGraphClassifier
  660. The solver that is created from given file or dictionary.
  661. """
  662. assert filetype in ["auto", "yaml", "json"], (
  663. "currently only support yaml file or json file type, but get type "
  664. + filetype
  665. )
  666. if isinstance(path_or_dict, str):
  667. if filetype == "auto":
  668. if path_or_dict.endswith(".yaml") or path_or_dict.endswith(".yml"):
  669. filetype = "yaml"
  670. elif path_or_dict.endswith(".json"):
  671. filetype = "json"
  672. else:
  673. LOGGER.error(
  674. "cannot parse the type of the given file name, "
  675. "please manually set the file type"
  676. )
  677. raise ValueError(
  678. "cannot parse the type of the given file name, "
  679. "please manually set the file type"
  680. )
  681. if filetype == "yaml":
  682. path_or_dict = yaml.load(
  683. open(path_or_dict, "r").read(), Loader=yaml.FullLoader
  684. )
  685. else:
  686. path_or_dict = json.load(open(path_or_dict, "r"))
  687. path_or_dict = deepcopy(path_or_dict)
  688. solver = cls(None, [], None, None)
  689. fe_list = path_or_dict.pop("feature", None)
  690. if fe_list is not None:
  691. fe_list_ele = []
  692. for feature_engineer in fe_list:
  693. name = feature_engineer.pop("name")
  694. if name is not None:
  695. fe_list_ele.append(FEATURE_DICT[name](**feature_engineer))
  696. if fe_list_ele != []:
  697. solver.set_feature_module(fe_list_ele)
  698. models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}, {"name": "sage"}, {"name": "gin"}])
  699. # models should be a list of model
  700. # with each element in two cases
  701. # * a dict describing a certain model
  702. # * a dict containing {"encoder": encoder, "decoder": decoder}
  703. model_hp_space = [
  704. _parse_model_hp(model) for model in models
  705. ]
  706. model_list = [
  707. _initialize_single_model(model) for model in models
  708. ]
  709. trainer = path_or_dict.pop("trainer", None)
  710. default_trainer = "NodeClassificationFull"
  711. trainer_space = None
  712. if isinstance(trainer, dict):
  713. # global default
  714. default_trainer = trainer.pop("name", "NodeClassificationFull")
  715. trainer_space = _parse_hp_space(trainer.pop("hp_space", None))
  716. default_kwargs = {"num_features": None, "num_classes": None}
  717. default_kwargs.update(trainer)
  718. default_kwargs["init"] = False
  719. for i in range(len(model_list)):
  720. model = model_list[i]
  721. trainer_wrap = TRAINER_DICT[default_trainer](
  722. model=model, **default_kwargs
  723. )
  724. model_list[i] = trainer_wrap
  725. elif isinstance(trainer, list):
  726. # sequential trainer definition
  727. assert len(trainer) == len(
  728. model_list
  729. ), "The number of trainer and model does not match"
  730. trainer_space = []
  731. for i in range(len(model_list)):
  732. train, model = trainer[i], model_list[i]
  733. default_trainer = train.pop("name", "NodeClassificationFull")
  734. trainer_space.append(_parse_hp_space(train.pop("hp_space", None)))
  735. default_kwargs = {"num_features": None, "num_classes": None}
  736. default_kwargs.update(train)
  737. default_kwargs["init"] = False
  738. trainer_wrap = TRAINER_DICT[default_trainer](
  739. model=model, **default_kwargs
  740. )
  741. model_list[i] = trainer_wrap
  742. solver.set_graph_models(
  743. model_list, default_trainer, trainer_space, model_hp_space
  744. )
  745. hpo_dict = path_or_dict.pop("hpo", {"name": "anneal"})
  746. if hpo_dict is not None:
  747. name = hpo_dict.pop("name")
  748. solver.set_hpo_module(name, **hpo_dict)
  749. ensemble_dict = path_or_dict.pop("ensemble", {"name": "voting"})
  750. if ensemble_dict is not None:
  751. name = ensemble_dict.pop("name")
  752. solver.set_ensemble_module(name, **ensemble_dict)
  753. nas_dict = path_or_dict.pop("nas", None)
  754. if nas_dict is not None:
  755. keys: set = set(nas_dict.keys())
  756. needed = {"space", "algorithm", "estimator"}
  757. if keys != needed:
  758. LOGGER.error("Key mismatch, we need %s, you give %s", needed, keys)
  759. raise KeyError("Key mismatch, we need %s, you give %s" % (needed, keys))
  760. spaces, algorithms, estimators = [], [], []
  761. for container, indexer, k in zip(
  762. [spaces, algorithms, estimators],
  763. [NAS_SPACE_DICT, NAS_ALGO_DICT, NAS_ESTIMATOR_DICT],
  764. ["space", "algorithm", "estimator"],
  765. ):
  766. configs = nas_dict[k]
  767. if isinstance(configs, list):
  768. for item in configs:
  769. container.append(indexer[item.pop("name")](**item))
  770. else:
  771. container.append(indexer[configs.pop("name")](**configs))
  772. solver.set_nas_module(algorithms, spaces, estimators)
  773. return solver