You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ssl_graph_classifier.py 31 kB

3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802
  1. """
  2. Auto Classifier for Self-Supervised Graph Classification
  3. """
  4. import time
  5. import json
  6. from copy import deepcopy
  7. from typing import Sequence
  8. import torch
  9. import numpy as np
  10. import yaml
  11. from .utils import _parse_hp_space, _parse_model_hp, _initialize_single_model
  12. from ..base import BaseClassifier
  13. from ....module.model import EncoderUniversalRegistry, DecoderUniversalRegistry, ModelUniversalRegistry
  14. from ....module.feature import FEATURE_DICT
  15. from ....module.train import TRAINER_DICT, get_feval, BaseGraphClassificationTrainer
  16. from ....module.train.ssl import BaseContrastiveTrainer
  17. from ...utils import LeaderBoard, get_dataset_labels, set_seed, get_graph_from_dataset, get_graph_node_features, convert_dataset
  18. from ....datasets import utils
  19. from ...utils import get_logger
  20. from ....backend import DependentBackend
  21. LOGGER = get_logger("SSLGraphClassifier")
  22. BACKEND = DependentBackend.get_backend_name()
  23. class SSLGraphClassifier(BaseClassifier):
  24. """
  25. Auto Multi-class Graph Classifier.
  26. Used to automatically solve the graph classification problems.
  27. Parameters
  28. ----------
  29. feature_module: autogl.module.feature.BaseFeatureEngineer or str or None
  30. The (name of) auto feature engineer used to process the given dataset.
  31. Disable feature engineer by setting it to ``None``. Default ``None``.
  32. graph_models: list of autogl.module.model.BaseModel or list of str
  33. The (name of) models to be optimized as backbone. Default ``('gin', 'topkpool')``.
  34. hpo_module: autogl.module.hpo.BaseHPOptimizer or str or None
  35. The (name of) hpo module used to search for best hyper parameters.
  36. Disable hpo by setting it to ``None``. Default ``anneal``.
  37. ensemble_module: autogl.module.ensemble.BaseEnsembler or str or None
  38. The (name of) ensemble module used to ensemble the multi-models found.
  39. Disable ensemble by setting it to ``None``. Default ``voting``.
  40. max_evals: int (Optional)
  41. If given, will set the number eval times the hpo module will use.
  42. Only effective when hpo_module is a ``str``. Default ``50``.
  43. default_trainer: str (Optional)
  44. The (name of the) trainer used in this solver. Default ``GraphClassificationFull``.
  45. trainer_hp_space: Iterable[dict] (Optional)
  46. trainer hp space or list of trainer hp spaces configuration.
  47. If a single trainer hp is given, will specify the hp space of trainer for
  48. every model. If a list of trainer hp is given, will specify every model
  49. with corresponding trainer hp space. Default ``None``.
  50. model_hp_spaces: list of list of dict (Optional)
  51. model hp space configuration.
  52. If given, will specify every hp space of every passed model. Default ``None``.
  53. If the encoder(-decoder) is passed, the space should be a dict containing keys "encoder"
  54. and "decoder", specifying the detailed encoder decoder hp spaces.
  55. size: int (Optional)
  56. The maximum number of models the ensemble module will use. Default ``4``.
  57. device: torch.device or str
  58. The device where model will be running on. If set to ``auto``, will use gpu
  59. when available. You can also specify the device by directly giving ``gpu`` or
  60. ``cuda:0``, etc. Default ``auto``.
  61. """
  62. # pylint: disable=W0102
  63. def __init__(
  64. self,
  65. feature_module=None,
  66. graph_models=("gin", "topkpool"),
  67. hpo_module="anneal",
  68. ensemble_module="voting",
  69. max_evals=50,
  70. default_trainer="GraphClassificationFull",
  71. trainer_hp_space=None,
  72. model_hp_spaces=None,
  73. size=4,
  74. device="auto",
  75. ):
  76. super().__init__(
  77. feature_module=feature_module,
  78. graph_models=graph_models,
  79. nas_algorithms=None,
  80. nas_spaces=None,
  81. nas_estimators=None,
  82. hpo_module=hpo_module,
  83. ensemble_module=ensemble_module,
  84. max_evals=max_evals,
  85. default_trainer=default_trainer,
  86. trainer_hp_space=trainer_hp_space,
  87. model_hp_spaces=model_hp_spaces,
  88. size=size,
  89. device=device,
  90. )
  91. self.dataset = None
  92. def _init_graph_module(
  93. self,
  94. graph_models,
  95. num_classes,
  96. num_features,
  97. feval,
  98. device,
  99. loss,
  100. num_graph_features,
  101. ) -> "SSLGraphClassifier":
  102. # load graph network module
  103. self.graph_model_list = []
  104. for i, model in enumerate(graph_models):
  105. # init the trainer
  106. if not isinstance(model, BaseContrastiveTrainer):
  107. trainer = (
  108. self._default_trainer if not isinstance(self._default_trainer, (tuple, list))
  109. else self._default_trainer[i]
  110. )
  111. if isinstance(trainer, str):
  112. if trainer not in TRAINER_DICT:
  113. raise KeyError(f"Does not support trainer {trainer}")
  114. trainer = TRAINER_DICT[trainer]()
  115. if isinstance(model, (tuple, list)):
  116. trainer.encoder = model[0]
  117. trainer.decoder = model[1]
  118. else:
  119. trainer.encoder = model
  120. else:
  121. trainer = model
  122. # set model hp space
  123. if self._model_hp_spaces is not None:
  124. if self._model_hp_spaces[i] is not None:
  125. if isinstance(self._model_hp_spaces[i], dict):
  126. encoder_hp_space = self._model_hp_spaces[i].get('encoder', None)
  127. decoder_hp_space = self._model_hp_spaces[i].get('decoder', None)
  128. prediction_head_hp_space = self._model_hp_spaces[i].get('prediction_head', None)
  129. else:
  130. encoder_hp_space = self._model_hp_spaces[i]
  131. decoder_hp_space = None
  132. if encoder_hp_space is not None:
  133. trainer.encoder.hyper_parameter_space = encoder_hp_space
  134. if decoder_hp_space is not None:
  135. trainer.decoder.hyper_parameter_space = decoder_hp_space
  136. if prediction_head_hp_space is not None:
  137. trainer.prediction_head.hyper_parameter_space = prediction_head_hp_space
  138. # set trainer hp space
  139. if self._trainer_hp_space is not None:
  140. if isinstance(self._trainer_hp_space[0], list):
  141. current_hp_for_trainer = self._trainer_hp_space[i]
  142. else:
  143. current_hp_for_trainer = self._trainer_hp_space
  144. trainer.hyper_parameter_space = current_hp_for_trainer
  145. trainer.num_features = num_features
  146. trainer.num_classes = num_classes
  147. trainer.num_graph_features = num_graph_features
  148. trainer.loss = loss
  149. trainer.feval = feval
  150. trainer.to(device)
  151. self.graph_model_list.append(trainer)
  152. return self
  153. # pylint: disable=arguments-differ
  154. def fit(
  155. self,
  156. dataset,
  157. time_limit=-1,
  158. inplace=False,
  159. train_split=None,
  160. val_split=None,
  161. evaluation_method="infer",
  162. seed=None,
  163. ) -> "AutoGraphClassifier":
  164. """
  165. Fit current solver on given dataset.
  166. Parameters
  167. ----------
  168. dataset: autogl.data.dataset
  169. The multi-graph dataset needed to fit on.
  170. time_limit: int
  171. The time limit of the whole fit process (in seconds). If set below 0, will ignore
  172. time limit. Default ``-1``.
  173. inplace: bool
  174. Whether we process the given dataset in inplace manner. Default ``False``.
  175. Set it to True if you want to save memory by modifying the given dataset directly.
  176. train_split: float or int (Optional)
  177. The train ratio (in ``float``) or number (in ``int``) of dataset. If you want to use
  178. default train/val/test split in dataset, please set this to ``None``.
  179. Default ``None``.
  180. val_split: float or int (Optional)
  181. The validation ratio (in ``float``) or number (in ``int``) of dataset. If you want to
  182. use default train/val/test split in dataset, please set this to ``None``.
  183. Default ``None``.
  184. evaluation_method: (list of) str autogl.module.train.evaluation
  185. A (list of) evaluation method for current solver. If ``infer``, will automatically
  186. determine. Default ``infer``.
  187. seed: int (Optional)
  188. The random seed. If set to ``None``, will run everything at random.
  189. Default ``None``.
  190. Returns
  191. -------
  192. self: autogl.solver.AutoGraphClassifier
  193. A reference of current solver.
  194. """
  195. set_seed(seed)
  196. num_classes = get_dataset_labels(dataset).max().item() + 1
  197. if time_limit < 0:
  198. time_limit = 3600 * 24
  199. time_begin = time.time()
  200. # initialize leaderboard
  201. if evaluation_method == "infer":
  202. if hasattr(dataset, "metric"):
  203. evaluation_method = [dataset.metric]
  204. else:
  205. if num_classes == 2:
  206. evaluation_method = ["auc"]
  207. else:
  208. evaluation_method = ["acc"]
  209. assert isinstance(evaluation_method, list)
  210. evaluator_list = get_feval(evaluation_method)
  211. self.leaderboard = LeaderBoard(
  212. [e.get_eval_name() for e in evaluator_list],
  213. {e.get_eval_name(): e.is_higher_better() for e in evaluator_list},
  214. )
  215. # set up the dataset
  216. if train_split is None and val_split is None:
  217. assert hasattr(dataset, "train_split") and hasattr(dataset, "val_split"), (
  218. "The dataset has no default train/val split! "
  219. "Please manually pass train and val ratio."
  220. )
  221. LOGGER.info("Use the default train/val/test ratio in given dataset")
  222. # if hasattr(dataset.train_split, "n_splits"):
  223. # cross_validation = True
  224. elif train_split is not None and val_split is not None:
  225. utils.graph_random_splits(dataset, train_split, val_split, seed=seed)
  226. else:
  227. LOGGER.error(
  228. "Please set both train_split and val_split explicitly. Detect %s is None.",
  229. "train_split" if train_split is None else "val_split",
  230. )
  231. raise ValueError(
  232. "In consistent setting of train/val split. Detect {} is None.".format(
  233. "train_split" if train_split is None else "val_split"
  234. )
  235. )
  236. # feature engineering
  237. if self.feature_module is not None:
  238. self.feature_module.fit(dataset.train_split)
  239. dataset = self.feature_module.transform(dataset, inplace=inplace)
  240. self.dataset = dataset
  241. # check whether the dataset has features.
  242. # currently we only support graph classification with features.
  243. feat = get_graph_node_features(get_graph_from_dataset(dataset))
  244. assert feat is not None, (
  245. "Does not support fit on non node-feature dataset!"
  246. " Please add node features to dataset or specify feature engineers that generate"
  247. " node features."
  248. )
  249. num_features = feat.size(-1)
  250. # initialize graph networks
  251. self._init_graph_module(
  252. self.gml,
  253. num_features=num_features,
  254. num_classes=num_classes,
  255. feval=evaluator_list,
  256. device=self.runtime_device,
  257. loss="NT_Xent" if not hasattr(dataset, "loss") else dataset.loss,
  258. num_graph_features=(0
  259. if not hasattr(dataset[0], "gf")
  260. else dataset[0].gf.size(1)) if BACKEND == 'pyg' else
  261. (0 if 'gf' not in dataset[0].data else dataset[0].data['gf'].size(1)),
  262. )
  263. # train the models and tune hpo
  264. result_valid = []
  265. names = []
  266. for idx, model in enumerate(self.graph_model_list):
  267. if time_limit < 0:
  268. time_for_each_model = None
  269. else:
  270. time_for_each_model = (time_limit - time.time() + time_begin) / (
  271. len(self.graph_model_list) - idx
  272. )
  273. if self.hpo_module is None:
  274. model.initialize()
  275. model.train(convert_dataset(dataset), True)
  276. optimized = model
  277. else:
  278. optimized, _ = self.hpo_module.optimize(
  279. trainer=model, dataset=convert_dataset(dataset), time_limit=time_for_each_model
  280. )
  281. # to save memory, all the trainer derived will be mapped to cpu
  282. optimized.to(torch.device("cpu"))
  283. name = str(optimized) + "_idx%d" % (idx)
  284. names.append(name)
  285. performance_on_valid, _ = optimized.get_valid_score(return_major=False)
  286. result_valid.append(
  287. optimized.get_valid_predict_proba().detach().cpu().numpy()
  288. )
  289. self.leaderboard.insert_model_performance(
  290. name,
  291. dict(
  292. zip(
  293. [e.get_eval_name() for e in evaluator_list],
  294. performance_on_valid,
  295. )
  296. ),
  297. )
  298. self.trained_models[name] = optimized
  299. # fit the ensemble model
  300. if self.ensemble_module is not None:
  301. performance = self.ensemble_module.fit(
  302. result_valid,
  303. get_dataset_labels(dataset)[dataset.val_index].cpu().numpy(),
  304. names,
  305. evaluator_list,
  306. n_classes=num_classes,
  307. )
  308. self.leaderboard.insert_model_performance(
  309. "ensemble",
  310. dict(zip([e.get_eval_name() for e in evaluator_list], performance)),
  311. )
  312. return self
  313. def fit_predict(
  314. self,
  315. dataset,
  316. time_limit=-1,
  317. inplace=False,
  318. train_split=None,
  319. val_split=None,
  320. evaluation_method="infer",
  321. seed=None,
  322. use_ensemble=True,
  323. use_best=True,
  324. name=None,
  325. ) -> np.ndarray:
  326. """
  327. Fit current solver on given dataset and return the predicted value.
  328. Parameters
  329. ----------
  330. dataset: torch_geometric.data.dataset.Dataset
  331. The dataset needed to fit on. This dataset must have only one graph.
  332. time_limit: int
  333. The time limit of the whole fit process (in seconds). If set below 0, will
  334. ignore time limit. Default ``-1``.
  335. inplace: bool
  336. Whether we process the given dataset in inplace manner. Default ``False``.
  337. Set it to True if you want to save memory by modifying the given dataset directly.
  338. train_split: float or int (Optional)
  339. The train ratio (in ``float``) or number (in ``int``) of dataset. If you want to
  340. use default train/val/test split in dataset, please set this to ``None``.
  341. Default ``None``.
  342. val_split: float or int (Optional)
  343. The validation ratio (in ``float``) or number (in ``int``) of dataset. If you want
  344. to use default train/val/test split in dataset, please set this to ``None``.
  345. Default ``None``.
  346. evaluation_method: (list of) str or autogl.module.train.evaluation
  347. A (list of) evaluation method for current solver. If ``infer``, will automatically
  348. determine. Default ``infer``.
  349. seed: int (Optional)
  350. The random seed. If set to ``None``, will run everything at random.
  351. Default ``None``.
  352. use_ensemble: bool
  353. Whether to use ensemble to do the predict. Default ``True``.
  354. use_best: bool
  355. Whether to use the best single model to do the predict. Will only be effective when
  356. ``use_ensemble`` is ``False``. Default ``True``.
  357. name: str or None
  358. The name of model used to predict. Will only be effective when ``use_ensemble`` and
  359. ``use_best`` both are ``False``. Default ``None``.
  360. Returns
  361. -------
  362. result: np.ndarray
  363. An array of shape ``(N,)``, where ``N`` is the number of test nodes. The prediction
  364. on given dataset.
  365. """
  366. self.fit(
  367. dataset=dataset,
  368. time_limit=time_limit,
  369. inplace=inplace,
  370. train_split=train_split,
  371. val_split=val_split,
  372. evaluation_method=evaluation_method,
  373. seed=seed,
  374. )
  375. return self.predict(
  376. dataset=dataset,
  377. inplaced=inplace,
  378. inplace=inplace,
  379. use_ensemble=use_ensemble,
  380. use_best=use_best,
  381. name=name,
  382. )
  383. def predict_proba(
  384. self,
  385. dataset=None,
  386. inplaced=False,
  387. inplace=False,
  388. use_ensemble=True,
  389. use_best=True,
  390. name=None,
  391. mask="test",
  392. ) -> np.ndarray:
  393. """
  394. Predict the node probability.
  395. Parameters
  396. ----------
  397. dataset: autogl.data.Dataset or None
  398. The dataset needed to predict. If ``None``, will use the processed dataset
  399. passed to ``fit()`` instead. Default ``None``.
  400. inplaced: bool
  401. Whether the given dataset is processed. Only be effective when ``dataset``
  402. is not ``None``. If you pass the dataset to ``fit()`` with ``inplace=True``,
  403. and you pass the dataset again to this method, you should set this argument
  404. to ``True``. Otherwise ``False``. Default ``False``.
  405. inplace: bool
  406. Whether we process the given dataset in inplace manner. Default ``False``.
  407. Set it to True if you want to save memory by modifying the given dataset directly.
  408. use_ensemble: bool
  409. Whether to use ensemble to do the predict. Default ``True``.
  410. use_best: bool
  411. Whether to use the best single model to do the predict. Will only be effective when
  412. ``use_ensemble`` is ``False``. Default ``True``.
  413. name: str or None
  414. The name of model used to predict. Will only be effective when ``use_ensemble`` and
  415. ``use_best`` both are ``False``. Default ``None``.
  416. mask: str
  417. The data split to give prediction on. Default ``test``.
  418. Returns
  419. -------
  420. result: np.ndarray
  421. An array of shape ``(N,C,)``, where ``N`` is the number of test nodes and ``C`` is
  422. the number of classes. The prediction on given dataset.
  423. """
  424. if dataset is None:
  425. dataset = self.dataset
  426. elif not inplaced:
  427. if self.feature_module is not None:
  428. dataset = self.feature_module.transform(dataset, inplace=inplace)
  429. if use_ensemble:
  430. LOGGER.info("Ensemble argument on, will try using ensemble model.")
  431. if not use_ensemble and use_best:
  432. LOGGER.info(
  433. "Ensemble argument off and best argument on, will try using best model."
  434. )
  435. if (use_ensemble and self.ensemble_module is not None) or (
  436. not use_best and name == "ensemble"
  437. ):
  438. # we need to get all the prediction of every model trained
  439. predict_result = []
  440. names = []
  441. for model_name in self.trained_models:
  442. predict_result.append(
  443. self._predict_proba_by_name(dataset, model_name, mask)
  444. )
  445. names.append(model_name)
  446. return self.ensemble_module.ensemble(predict_result, names)
  447. if use_ensemble and self.ensemble_module is None:
  448. LOGGER.warning(
  449. "Cannot use ensemble because no ensebmle module is given. "
  450. "Will use best model instead."
  451. )
  452. if use_best or (use_ensemble and self.ensemble_module is None):
  453. # just return the best model we have found
  454. best_model_name = self.leaderboard.get_best_model()
  455. return self._predict_proba_by_name(dataset, best_model_name, mask)
  456. if name is not None:
  457. # return model performance by name
  458. return self._predict_proba_by_name(dataset, name, mask)
  459. LOGGER.error(
  460. "No model name is given while ensemble and best arguments are off."
  461. )
  462. raise ValueError(
  463. "You need to specify a model name if you do not want use ensemble and best model."
  464. )
  465. def _predict_proba_by_name(self, dataset, name, mask):
  466. self.trained_models[name].to(self.runtime_device)
  467. predicted = (
  468. self.trained_models[name]
  469. .predict_proba(convert_dataset(dataset), mask=mask)
  470. .detach()
  471. .cpu()
  472. .numpy()
  473. )
  474. self.trained_models[name].to(torch.device("cpu"))
  475. return predicted
  476. def predict(
  477. self,
  478. dataset=None,
  479. inplaced=False,
  480. inplace=False,
  481. use_ensemble=True,
  482. use_best=True,
  483. name=None,
  484. mask="test",
  485. ) -> np.ndarray:
  486. """
  487. Predict the node class number.
  488. Parameters
  489. ----------
  490. dataset: autogl.data.Dataset or None
  491. The dataset needed to predict. If ``None``, will use the processed dataset passed
  492. to ``fit()`` instead. Default ``None``.
  493. inplaced: bool
  494. Whether the given dataset is processed. Only be effective when ``dataset``
  495. is not ``None``. If you pass the dataset to ``fit()`` with ``inplace=True``, and
  496. you pass the dataset again to this method, you should set this argument to ``True``.
  497. Otherwise ``False``. Default ``False``.
  498. inplace: bool
  499. Whether we process the given dataset in inplace manner. Default ``False``.
  500. Set it to True if you want to save memory by modifying the given dataset directly.
  501. use_ensemble: bool
  502. Whether to use ensemble to do the predict. Default ``True``.
  503. use_best: bool
  504. Whether to use the best single model to do the predict. Will only be effective
  505. when ``use_ensemble`` is ``False``. Default ``True``.
  506. name: str or None
  507. The name of model used to predict. Will only be effective when ``use_ensemble``
  508. and ``use_best`` both are ``False``. Default ``None``.
  509. Returns
  510. -------
  511. result: np.ndarray
  512. An array of shape ``(N,)``, where ``N`` is the number of test nodes.
  513. The prediction on given dataset.
  514. """
  515. proba = self.predict_proba(
  516. dataset, inplaced, inplace, use_ensemble, use_best, name, mask
  517. )
  518. return np.argmax(proba, axis=1)
  519. def evaluate(self, dataset=None,
  520. inplaced=False,
  521. inplace=False,
  522. use_ensemble=True,
  523. use_best=True,
  524. name=None,
  525. mask="test",
  526. label=None,
  527. metric="acc"
  528. ):
  529. """
  530. Evaluate the given dataset.
  531. Parameters
  532. ----------
  533. dataset: torch_geometric.data.dataset.Dataset or None
  534. The dataset needed to predict. If ``None``, will use the processed dataset passed
  535. to ``fit()`` instead. Default ``None``.
  536. inplaced: bool
  537. Whether the given dataset is processed. Only be effective when ``dataset``
  538. is not ``None``. If you pass the dataset to ``fit()`` with ``inplace=True``, and
  539. you pass the dataset again to this method, you should set this argument to ``True``.
  540. Otherwise ``False``. Default ``False``.
  541. inplace: bool
  542. Whether we process the given dataset in inplace manner. Default ``False``. Set it to
  543. True if you want to save memory by modifying the given dataset directly.
  544. use_ensemble: bool
  545. Whether to use ensemble to do the predict. Default ``True``.
  546. use_best: bool
  547. Whether to use the best single model to do the predict. Will only be effective when
  548. ``use_ensemble`` is ``False``. Default ``True``.
  549. name: str or None
  550. The name of model used to predict. Will only be effective when ``use_ensemble`` and
  551. ``use_best`` both are ``False``. Default ``None``.
  552. mask: str
  553. The data split to give prediction on. Default ``test``.
  554. label: torch.Tensor (Optional)
  555. The groud truth label of the given predicted dataset split. If not passed, will extract
  556. labels from the input dataset.
  557. metric: str
  558. The metric to be used for evaluating the model. Default ``acc``.
  559. Returns
  560. -------
  561. score(s): (list of) evaluation scores
  562. the evaluation results according to the evaluator passed.
  563. """
  564. predicted = self.predict_proba(dataset, inplaced, inplace, use_ensemble, use_best, name, mask)
  565. if dataset is None:
  566. dataset = self.dataset
  567. if label is None:
  568. if mask == "all":
  569. masked_dataset = dataset
  570. else:
  571. masked_dataset = utils.graph_get_split(dataset, mask, False)
  572. label = get_dataset_labels(masked_dataset)
  573. evaluator = get_feval(metric)
  574. if isinstance(evaluator, Sequence):
  575. return [evals.evaluate(predicted, label) for evals in evaluator]
  576. return evaluator.evaluate(predicted, label)
  577. @classmethod
  578. def from_config(cls, path_or_dict, filetype="auto") -> "AutoGraphClassifier":
  579. """
  580. Load solver from config file.
  581. You can use this function to directly load a solver from predefined config dict
  582. or config file path. Currently, only support file type of ``json`` or ``yaml``,
  583. if you pass a path.
  584. Parameters
  585. ----------
  586. path_or_dict: str or dict
  587. The path to the config file or the config dictionary object
  588. filetype: str
  589. The filetype the given file if the path is specified. Currently only support
  590. ``json`` or ``yaml``. You can set to ``auto`` to automatically detect the file
  591. type (from file name). Default ``auto``.
  592. Returns
  593. -------
  594. solver: autogl.solver.AutoGraphClassifier
  595. The solver that is created from given file or dictionary.
  596. """
  597. assert filetype in ["auto", "yaml", "json"], (
  598. "currently only support yaml file or json file type, but get type "
  599. + filetype
  600. )
  601. if isinstance(path_or_dict, str):
  602. if filetype == "auto":
  603. if path_or_dict.endswith(".yaml") or path_or_dict.endswith(".yml"):
  604. filetype = "yaml"
  605. elif path_or_dict.endswith(".json"):
  606. filetype = "json"
  607. else:
  608. LOGGER.error(
  609. "cannot parse the type of the given file name, "
  610. "please manually set the file type"
  611. )
  612. raise ValueError(
  613. "cannot parse the type of the given file name, "
  614. "please manually set the file type"
  615. )
  616. if filetype == "yaml":
  617. path_or_dict = yaml.load(
  618. open(path_or_dict, "r").read(), Loader=yaml.FullLoader
  619. )
  620. else:
  621. path_or_dict = json.load(open(path_or_dict, "r"))
  622. # load the dictionary
  623. path_or_dict = deepcopy(path_or_dict)
  624. solver = cls(None, [], None, None)
  625. fe_list = path_or_dict.pop("feature", None)
  626. if fe_list is not None:
  627. fe_list_ele = []
  628. for feature_engineer in fe_list:
  629. name = feature_engineer.pop("name")
  630. if name is not None:
  631. fe_list_ele.append(FEATURE_DICT[name](**feature_engineer))
  632. if fe_list_ele != []:
  633. solver.set_feature_module(fe_list_ele)
  634. models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}, {"name": "sage"}, {"name": "gin"}])
  635. # models should be a list of model
  636. # with each element in two cases
  637. # * a dict describing a certain model
  638. # * a dict containing {"encoder": encoder, "decoder": decoder}
  639. model_hp_space = [
  640. _parse_model_hp(model) for model in models
  641. ]
  642. model_list = [
  643. _initialize_single_model(model) for model in models
  644. ]
  645. trainer = path_or_dict.pop("trainer", None)
  646. default_trainer = "GraphClassificationFull"
  647. trainer_space = None
  648. if isinstance(trainer, dict):
  649. # global default
  650. default_trainer = trainer.pop("name", "GraphClassificationFull")
  651. trainer_space = _parse_hp_space(trainer.pop("hp_space", None))
  652. default_kwargs = {"num_features": None, "num_classes": None}
  653. default_kwargs.update(trainer)
  654. default_kwargs["init"] = False
  655. for i in range(len(model_list)):
  656. model = model_list[i]
  657. trainer_wrapper = TRAINER_DICT[default_trainer](
  658. model=model, prediction_head=models[i]["prediction_head"].pop("name", None), **default_kwargs
  659. )
  660. model_list[i] = trainer_wrapper
  661. elif isinstance(trainer, list):
  662. # sequential trainer definition
  663. assert len(trainer) == len(
  664. model_list
  665. ), "The number of trainer and model does not match"
  666. trainer_space = []
  667. for i in range(len(model_list)):
  668. train, model = trainer[i], model_list[i]
  669. default_trainer = train.pop("name", "GraphClassificationFull")
  670. trainer_space.append(_parse_hp_space(train.pop("hp_space", None)))
  671. default_kwargs = {"num_features": None, "num_classes": None}
  672. default_kwargs.update(train)
  673. default_kwargs["init"] = False
  674. trainer_wrap = TRAINER_DICT[default_trainer](
  675. model=model, **default_kwargs
  676. )
  677. model_list[i] = trainer_wrap
  678. solver.set_graph_models(
  679. model_list, default_trainer, trainer_space, model_hp_space
  680. )
  681. hpo_dict = path_or_dict.pop("hpo", {"name": "anneal"})
  682. if hpo_dict is not None:
  683. name = hpo_dict.pop("name")
  684. solver.set_hpo_module(name, **hpo_dict)
  685. ensemble_dict = path_or_dict.pop("ensemble", {"name": "voting"})
  686. if ensemble_dict is not None:
  687. name = ensemble_dict.pop("name")
  688. solver.set_ensemble_module(name, **ensemble_dict)
  689. return solver