You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_classifier.py 31 kB

4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. """
  2. Auto Classifier for Graph Classification
  3. """
  4. import time
  5. import json
  6. from copy import deepcopy
  7. import torch
  8. import numpy as np
  9. import yaml
  10. from .base import BaseClassifier
  11. from ...module.feature import FEATURE_DICT
  12. from ...module.model import BaseModel, MODEL_DICT
  13. from ...module.train import TRAINER_DICT, get_feval, BaseGraphClassificationTrainer
  14. from ..base import _initialize_single_model, _parse_hp_space
  15. from ..utils import LeaderBoard, get_dataset_labels, set_seed, get_graph_from_dataset, get_graph_node_features, convert_dataset
  16. from ...datasets import utils
  17. from ..utils import get_logger
  18. from ...backend import DependentBackend
  19. LOGGER = get_logger("GraphClassifier")
  20. BACKEND = DependentBackend.get_backend_name()
  21. class AutoGraphClassifier(BaseClassifier):
  22. """
  23. Auto Multi-class Graph Classifier.
  24. Used to automatically solve the graph classification problems.
  25. Parameters
  26. ----------
  27. feature_module: autogl.module.feature.BaseFeatureEngineer or str or None
  28. The (name of) auto feature engineer used to process the given dataset.
  29. Disable feature engineer by setting it to ``None``. Default ``deepgl``.
  30. graph_models: list of autogl.module.model.BaseModel or list of str
  31. The (name of) models to be optimized as backbone. Default ``['gat', 'gcn']``.
  32. hpo_module: autogl.module.hpo.BaseHPOptimizer or str or None
  33. The (name of) hpo module used to search for best hyper parameters.
  34. Disable hpo by setting it to ``None``. Default ``anneal``.
  35. ensemble_module: autogl.module.ensemble.BaseEnsembler or str or None
  36. The (name of) ensemble module used to ensemble the multi-models found.
  37. Disable ensemble by setting it to ``None``. Default ``voting``.
  38. max_evals: int (Optional)
  39. If given, will set the number eval times the hpo module will use.
  40. Only be effective when hpo_module is ``str``. Default ``None``.
  41. trainer_hp_space: Iterable[dict] (Optional)
  42. trainer hp space or list of trainer hp spaces configuration.
  43. If a single trainer hp is given, will specify the hp space of trainer for
  44. every model. If a list of trainer hp is given, will specify every model
  45. with corrsponding trainer hp space. Default ``None``.
  46. model_hp_spaces: Iterable[Iterable[dict]] (Optional)
  47. model hp space configuration.
  48. If given, will specify every hp space of every passed model. Default ``None``.
  49. size: int (Optional)
  50. The max models ensemble module will use. Default ``None``.
  51. device: torch.device or str
  52. The device where model will be running on. If set to ``auto``, will use gpu
  53. when available. You can also specify the device by directly giving ``gpu`` or
  54. ``cuda:0``, etc. Default ``auto``.
  55. """
  56. # pylint: disable=W0102
  57. def __init__(
  58. self,
  59. feature_module=None,
  60. graph_models=["gin", "topkpool"],
  61. # nas_algorithms=None,
  62. # nas_spaces=None,
  63. # nas_estimators=None,
  64. hpo_module="anneal",
  65. ensemble_module="voting",
  66. max_evals=50,
  67. default_trainer=None,
  68. trainer_hp_space=None,
  69. model_hp_spaces=None,
  70. size=4,
  71. device="auto",
  72. ):
  73. super().__init__(
  74. feature_module=feature_module,
  75. graph_models=graph_models,
  76. nas_algorithms=None, # nas_algorithms,
  77. nas_spaces=None, # nas_spaces,
  78. nas_estimators=None, # nas_estimators,
  79. hpo_module=hpo_module,
  80. ensemble_module=ensemble_module,
  81. max_evals=max_evals,
  82. default_trainer=default_trainer or "GraphClassificationFull",
  83. trainer_hp_space=trainer_hp_space,
  84. model_hp_spaces=model_hp_spaces,
  85. size=size,
  86. device=device,
  87. )
  88. self.dataset = None
  89. def _init_graph_module(
  90. self,
  91. graph_models,
  92. num_classes,
  93. num_features,
  94. feval,
  95. device,
  96. loss,
  97. num_graph_features,
  98. ) -> "AutoGraphClassifier":
  99. # load graph network module
  100. self.graph_model_list = []
  101. if isinstance(graph_models, (list, tuple)):
  102. for model in graph_models:
  103. if isinstance(model, str):
  104. if model in MODEL_DICT:
  105. self.graph_model_list.append(
  106. MODEL_DICT[model](
  107. num_classes=num_classes,
  108. num_features=num_features,
  109. num_graph_features=num_graph_features,
  110. device=device,
  111. init=False,
  112. )
  113. )
  114. else:
  115. raise KeyError("cannot find model %s" % (model))
  116. elif isinstance(model, type) and issubclass(model, BaseModel):
  117. self.graph_model_list.append(
  118. model(
  119. num_classes=num_classes,
  120. num_features=num_features,
  121. num_graph_features=num_graph_features,
  122. device=device,
  123. init=False,
  124. )
  125. )
  126. elif isinstance(model, BaseModel):
  127. # setup the hp of num_classes and num_features
  128. model.set_num_classes(num_classes)
  129. model.set_num_features(num_features)
  130. model.set_num_graph_features(num_graph_features)
  131. self.graph_model_list.append(model.to(device))
  132. elif isinstance(model, BaseGraphClassificationTrainer):
  133. # receive a trainer list, put trainer to list
  134. assert (
  135. model.get_model() is not None
  136. ), "Passed trainer should contain a model"
  137. model.model.set_num_classes(num_classes)
  138. model.model.set_num_features(num_features)
  139. model.model.set_num_graph_features(num_graph_features)
  140. model.update_parameters(
  141. num_classes=num_classes,
  142. num_features=num_features,
  143. num_graph_features=num_graph_features,
  144. loss=loss,
  145. feval=feval,
  146. device=device,
  147. )
  148. self.graph_model_list.append(model)
  149. else:
  150. raise KeyError("cannot find graph network %s." % (model))
  151. else:
  152. raise ValueError(
  153. "need graph network to be (list of) str or a BaseModel class/instance, get",
  154. graph_models,
  155. "instead.",
  156. )
  157. # wrap all model_cls with specified trainer
  158. for i, model in enumerate(self.graph_model_list):
  159. # set model hp space
  160. if self._model_hp_spaces is not None:
  161. if self._model_hp_spaces[i] is not None:
  162. if isinstance(model, BaseGraphClassificationTrainer):
  163. model.model.hyper_parameter_space = self._model_hp_spaces[i]
  164. else:
  165. model.hyper_parameter_space = self._model_hp_spaces[i]
  166. # initialize trainer if needed
  167. if isinstance(model, BaseModel):
  168. name = (
  169. self._default_trainer
  170. if isinstance(self._default_trainer, str)
  171. else self._default_trainer[i]
  172. )
  173. model = TRAINER_DICT[name](
  174. model=model,
  175. num_features=num_features,
  176. num_classes=num_classes,
  177. loss=loss,
  178. feval=feval,
  179. device=device,
  180. num_graph_features=num_graph_features,
  181. init=False,
  182. )
  183. # set trainer hp space
  184. if self._trainer_hp_space is not None:
  185. if isinstance(self._trainer_hp_space[0], list):
  186. current_hp_for_trainer = self._trainer_hp_space[i]
  187. else:
  188. current_hp_for_trainer = self._trainer_hp_space
  189. model.hyper_parameter_space = current_hp_for_trainer
  190. self.graph_model_list[i] = model
  191. return self
  192. """
  193. # currently disabled
  194. def _init_nas_module(
  195. self, num_features, num_classes, num_graph_features, feval, device, loss
  196. ):
  197. for algo, space, estimator in zip(
  198. self.nas_algorithms, self.nas_spaces, self.nas_estimators
  199. ):
  200. # TODO: initialize important parameters
  201. pass
  202. """
  203. # pylint: disable=arguments-differ
  204. def fit(
  205. self,
  206. dataset,
  207. time_limit=-1,
  208. inplace=False,
  209. train_split=None,
  210. val_split=None,
  211. evaluation_method="infer",
  212. seed=None,
  213. ) -> "AutoGraphClassifier":
  214. """
  215. Fit current solver on given dataset.
  216. Parameters
  217. ----------
  218. dataset: autogl.data.dataset
  219. The multi-graph dataset needed to fit on.
  220. time_limit: int
  221. The time limit of the whole fit process (in seconds). If set below 0, will ignore
  222. time limit. Default ``-1``.
  223. inplace: bool
  224. Whether we process the given dataset in inplace manner. Default ``False``.
  225. Set it to True if you want to save memory by modifying the given dataset directly.
  226. train_split: float or int (Optional)
  227. The train ratio (in ``float``) or number (in ``int``) of dataset. If you want to use
  228. default train/val/test split in dataset, please set this to ``None``.
  229. Default ``None``.
  230. val_split: float or int (Optional)
  231. The validation ratio (in ``float``) or number (in ``int``) of dataset. If you want to
  232. use default train/val/test split in dataset, please set this to ``None``.
  233. Default ``None``.
  234. evaluation_method: (list of) str autogl.module.train.evaluation
  235. A (list of) evaluation method for current solver. If ``infer``, will automatically
  236. determine. Default ``infer``.
  237. seed: int (Optional)
  238. The random seed. If set to ``None``, will run everything at random.
  239. Default ``None``.
  240. Returns
  241. -------
  242. self: autogl.solver.AutoGraphClassifier
  243. A reference of current solver.
  244. """
  245. set_seed(seed)
  246. num_classes = max(get_dataset_labels(dataset)) + 1
  247. if time_limit < 0:
  248. time_limit = 3600 * 24
  249. time_begin = time.time()
  250. # initialize leaderboard
  251. if evaluation_method == "infer":
  252. if hasattr(dataset, "metric"):
  253. evaluation_method = [dataset.metric]
  254. else:
  255. if num_classes == 2:
  256. evaluation_method = ["auc"]
  257. else:
  258. evaluation_method = ["acc"]
  259. assert isinstance(evaluation_method, list)
  260. evaluator_list = get_feval(evaluation_method)
  261. self.leaderboard = LeaderBoard(
  262. [e.get_eval_name() for e in evaluator_list],
  263. {e.get_eval_name(): e.is_higher_better() for e in evaluator_list},
  264. )
  265. # set up the dataset
  266. if train_split is None and val_split is None:
  267. assert hasattr(dataset, "train_split") and hasattr(dataset, "val_split"), (
  268. "The dataset has no default train/val split! "
  269. "Please manually pass train and val ratio."
  270. )
  271. LOGGER.info("Use the default train/val/test ratio in given dataset")
  272. # if hasattr(dataset.train_split, "n_splits"):
  273. # cross_validation = True
  274. elif train_split is not None and val_split is not None:
  275. utils.graph_random_splits(dataset, train_split, val_split, seed=seed)
  276. else:
  277. LOGGER.error(
  278. "Please set both train_split and val_split explicitly. Detect %s is None.",
  279. "train_split" if train_split is None else "val_split",
  280. )
  281. raise ValueError(
  282. "In consistent setting of train/val split. Detect {} is None.".format(
  283. "train_split" if train_split is None else "val_split"
  284. )
  285. )
  286. # feature engineering
  287. if self.feature_module is not None:
  288. self.feature_module.fit(dataset.train_split)
  289. dataset = self.feature_module.transform(dataset, inplace=inplace)
  290. self.dataset = dataset
  291. # check whether the dataset has features.
  292. # currently we only support graph classification with features.
  293. feat = get_graph_node_features(get_graph_from_dataset(dataset))
  294. assert feat is not None, (
  295. "Does not support fit on non node-feature dataset!"
  296. " Please add node features to dataset or specify feature engineers that generate"
  297. " node features."
  298. )
  299. num_features = feat.size(-1)
  300. # initialize graph networks
  301. self._init_graph_module(
  302. self.gml,
  303. # TODO: what should we use to get feature dimension?
  304. num_features=num_features,
  305. num_classes=num_classes,
  306. feval=evaluator_list,
  307. device=self.runtime_device,
  308. loss="cross_entropy" if not hasattr(dataset, "loss") else dataset.loss,
  309. num_graph_features=(0
  310. if not hasattr(dataset[0], "gf")
  311. else dataset[0].gf.size(1)) if BACKEND == 'pyg' else
  312. (0 if 'gf' not in dataset[0].data else dataset[0].data['gf'].size(1)),
  313. )
  314. # currently disabled
  315. """
  316. self._init_nas_module(
  317. num_features=dataset.num_node_features,
  318. num_classes=dataset.num_classes,
  319. feval=evaluator_list,
  320. device=self.runtime_device,
  321. loss="cross_entropy" if not hasattr(dataset, "loss") else dataset.loss,
  322. num_graph_features=0
  323. if not hasattr(dataset.data, "gf")
  324. else dataset.data.gf.size(1),
  325. )
  326. # neural architecture search
  327. if self.nas_algorithms is not None:
  328. # perform nas and add them to trainer list
  329. for algo, space, estimator in zip(
  330. self.nas_algorithms, self.nas_spaces, self.nas_estimators
  331. ):
  332. trainer = algo.search(space, self.dataset, estimator)
  333. self.graph_model_list.append(trainer)
  334. """
  335. # train the models and tune hpo
  336. result_valid = []
  337. names = []
  338. for idx, model in enumerate(self.graph_model_list):
  339. if time_limit < 0:
  340. time_for_each_model = None
  341. else:
  342. time_for_each_model = (time_limit - time.time() + time_begin) / (
  343. len(self.graph_model_list) - idx
  344. )
  345. if self.hpo_module is None:
  346. model.initialize()
  347. model.train(convert_dataset(dataset), True)
  348. optimized = model
  349. else:
  350. optimized, _ = self.hpo_module.optimize(
  351. trainer=model, dataset=convert_dataset(dataset), time_limit=time_for_each_model
  352. )
  353. # to save memory, all the trainer derived will be mapped to cpu
  354. optimized.to(torch.device("cpu"))
  355. name = str(optimized)
  356. names.append(name)
  357. performance_on_valid, _ = optimized.get_valid_score(return_major=False)
  358. result_valid.append(
  359. optimized.get_valid_predict_proba().detach().cpu().numpy()
  360. )
  361. self.leaderboard.insert_model_performance(
  362. name,
  363. dict(
  364. zip(
  365. [e.get_eval_name() for e in evaluator_list],
  366. performance_on_valid,
  367. )
  368. ),
  369. )
  370. self.trained_models[name] = optimized
  371. # fit the ensemble model
  372. if self.ensemble_module is not None:
  373. performance = self.ensemble_module.fit(
  374. result_valid,
  375. get_dataset_labels(dataset)[dataset.val_index].cpu().numpy(),
  376. names,
  377. evaluator_list,
  378. n_classes=dataset.num_classes,
  379. )
  380. self.leaderboard.insert_model_performance(
  381. "ensemble",
  382. dict(zip([e.get_eval_name() for e in evaluator_list], performance)),
  383. )
  384. return self
  385. def fit_predict(
  386. self,
  387. dataset,
  388. time_limit=-1,
  389. inplace=False,
  390. train_split=None,
  391. val_split=None,
  392. evaluation_method="infer",
  393. seed=None,
  394. use_ensemble=True,
  395. use_best=True,
  396. name=None,
  397. ) -> np.ndarray:
  398. """
  399. Fit current solver on given dataset and return the predicted value.
  400. Parameters
  401. ----------
  402. dataset: torch_geometric.data.dataset.Dataset
  403. The dataset needed to fit on. This dataset must have only one graph.
  404. time_limit: int
  405. The time limit of the whole fit process (in seconds). If set below 0, will
  406. ignore time limit. Default ``-1``.
  407. inplace: bool
  408. Whether we process the given dataset in inplace manner. Default ``False``.
  409. Set it to True if you want to save memory by modifying the given dataset directly.
  410. train_split: float or int (Optional)
  411. The train ratio (in ``float``) or number (in ``int``) of dataset. If you want to
  412. use default train/val/test split in dataset, please set this to ``None``.
  413. Default ``None``.
  414. val_split: float or int (Optional)
  415. The validation ratio (in ``float``) or number (in ``int``) of dataset. If you want
  416. to use default train/val/test split in dataset, please set this to ``None``.
  417. Default ``None``.
  418. evaluation_method: (list of) str or autogl.module.train.evaluation
  419. A (list of) evaluation method for current solver. If ``infer``, will automatically
  420. determine. Default ``infer``.
  421. seed: int (Optional)
  422. The random seed. If set to ``None``, will run everything at random.
  423. Default ``None``.
  424. use_ensemble: bool
  425. Whether to use ensemble to do the predict. Default ``True``.
  426. use_best: bool
  427. Whether to use the best single model to do the predict. Will only be effective when
  428. ``use_ensemble`` is ``False``. Default ``True``.
  429. name: str or None
  430. The name of model used to predict. Will only be effective when ``use_ensemble`` and
  431. ``use_best`` both are ``False``. Default ``None``.
  432. Returns
  433. -------
  434. result: np.ndarray
  435. An array of shape ``(N,)``, where ``N`` is the number of test nodes. The prediction
  436. on given dataset.
  437. """
  438. self.fit(
  439. dataset=dataset,
  440. time_limit=time_limit,
  441. inplace=inplace,
  442. train_split=train_split,
  443. val_split=val_split,
  444. evaluation_method=evaluation_method,
  445. seed=seed,
  446. )
  447. return self.predict(
  448. dataset=dataset,
  449. inplaced=inplace,
  450. inplace=inplace,
  451. use_ensemble=use_ensemble,
  452. use_best=use_best,
  453. name=name,
  454. )
  455. def predict_proba(
  456. self,
  457. dataset=None,
  458. inplaced=False,
  459. inplace=False,
  460. use_ensemble=True,
  461. use_best=True,
  462. name=None,
  463. mask="test",
  464. ) -> np.ndarray:
  465. """
  466. Predict the node probability.
  467. Parameters
  468. ----------
  469. dataset: autogl.data.Dataset or None
  470. The dataset needed to predict. If ``None``, will use the processed dataset
  471. passed to ``fit()`` instead. Default ``None``.
  472. inplaced: bool
  473. Whether the given dataset is processed. Only be effective when ``dataset``
  474. is not ``None``. If you pass the dataset to ``fit()`` with ``inplace=True``,
  475. and you pass the dataset again to this method, you should set this argument
  476. to ``True``. Otherwise ``False``. Default ``False``.
  477. inplace: bool
  478. Whether we process the given dataset in inplace manner. Default ``False``.
  479. Set it to True if you want to save memory by modifying the given dataset directly.
  480. use_ensemble: bool
  481. Whether to use ensemble to do the predict. Default ``True``.
  482. use_best: bool
  483. Whether to use the best single model to do the predict. Will only be effective when
  484. ``use_ensemble`` is ``False``. Default ``True``.
  485. name: str or None
  486. The name of model used to predict. Will only be effective when ``use_ensemble`` and
  487. ``use_best`` both are ``False``. Default ``None``.
  488. mask: str
  489. The data split to give prediction on. Default ``test``.
  490. Returns
  491. -------
  492. result: np.ndarray
  493. An array of shape ``(N,C,)``, where ``N`` is the number of test nodes and ``C`` is
  494. the number of classes. The prediction on given dataset.
  495. """
  496. if dataset is None:
  497. dataset = self.dataset
  498. elif not inplaced:
  499. if self.feature_module is not None:
  500. dataset = self.feature_module.transform(dataset, inplace=inplace)
  501. if use_ensemble:
  502. LOGGER.info("Ensemble argument on, will try using ensemble model.")
  503. if not use_ensemble and use_best:
  504. LOGGER.info(
  505. "Ensemble argument off and best argument on, will try using best model."
  506. )
  507. if (use_ensemble and self.ensemble_module is not None) or (
  508. not use_best and name == "ensemble"
  509. ):
  510. # we need to get all the prediction of every model trained
  511. predict_result = []
  512. names = []
  513. for model_name in self.trained_models:
  514. predict_result.append(
  515. self._predict_proba_by_name(dataset, model_name, mask)
  516. )
  517. names.append(model_name)
  518. return self.ensemble_module.ensemble(predict_result, names)
  519. if use_ensemble and self.ensemble_module is None:
  520. LOGGER.warning(
  521. "Cannot use ensemble because no ensebmle module is given. "
  522. "Will use best model instead."
  523. )
  524. if use_best or (use_ensemble and self.ensemble_module is None):
  525. # just return the best model we have found
  526. best_model_name = self.leaderboard.get_best_model()
  527. return self._predict_proba_by_name(dataset, best_model_name, mask)
  528. if name is not None:
  529. # return model performance by name
  530. return self._predict_proba_by_name(dataset, name, mask)
  531. LOGGER.error(
  532. "No model name is given while ensemble and best arguments are off."
  533. )
  534. raise ValueError(
  535. "You need to specify a model name if you do not want use ensemble and best model."
  536. )
  537. def _predict_proba_by_name(self, dataset, name, mask):
  538. self.trained_models[name].to(self.runtime_device)
  539. predicted = (
  540. self.trained_models[name]
  541. .predict_proba(convert_dataset(dataset), mask=mask)
  542. .detach()
  543. .cpu()
  544. .numpy()
  545. )
  546. self.trained_models[name].to(torch.device("cpu"))
  547. return predicted
  548. def predict(
  549. self,
  550. dataset=None,
  551. inplaced=False,
  552. inplace=False,
  553. use_ensemble=True,
  554. use_best=True,
  555. name=None,
  556. mask="test",
  557. ) -> np.ndarray:
  558. """
  559. Predict the node class number.
  560. Parameters
  561. ----------
  562. dataset: autogl.data.Dataset or None
  563. The dataset needed to predict. If ``None``, will use the processed dataset passed
  564. to ``fit()`` instead. Default ``None``.
  565. inplaced: bool
  566. Whether the given dataset is processed. Only be effective when ``dataset``
  567. is not ``None``. If you pass the dataset to ``fit()`` with ``inplace=True``, and
  568. you pass the dataset again to this method, you should set this argument to ``True``.
  569. Otherwise ``False``. Default ``False``.
  570. inplace: bool
  571. Whether we process the given dataset in inplace manner. Default ``False``.
  572. Set it to True if you want to save memory by modifying the given dataset directly.
  573. use_ensemble: bool
  574. Whether to use ensemble to do the predict. Default ``True``.
  575. use_best: bool
  576. Whether to use the best single model to do the predict. Will only be effective
  577. when ``use_ensemble`` is ``False``. Default ``True``.
  578. name: str or None
  579. The name of model used to predict. Will only be effective when ``use_ensemble``
  580. and ``use_best`` both are ``False``. Default ``None``.
  581. Returns
  582. -------
  583. result: np.ndarray
  584. An array of shape ``(N,)``, where ``N`` is the number of test nodes.
  585. The prediction on given dataset.
  586. """
  587. proba = self.predict_proba(
  588. dataset, inplaced, inplace, use_ensemble, use_best, name, mask
  589. )
  590. return np.argmax(proba, axis=1)
  591. @classmethod
  592. def from_config(cls, path_or_dict, filetype="auto") -> "AutoGraphClassifier":
  593. """
  594. Load solver from config file.
  595. You can use this function to directly load a solver from predefined config dict
  596. or config file path. Currently, only support file type of ``json`` or ``yaml``,
  597. if you pass a path.
  598. Parameters
  599. ----------
  600. path_or_dict: str or dict
  601. The path to the config file or the config dictionary object
  602. filetype: str
  603. The filetype the given file if the path is specified. Currently only support
  604. ``json`` or ``yaml``. You can set to ``auto`` to automatically detect the file
  605. type (from file name). Default ``auto``.
  606. Returns
  607. -------
  608. solver: autogl.solver.AutoGraphClassifier
  609. The solver that is created from given file or dictionary.
  610. """
  611. assert filetype in ["auto", "yaml", "json"], (
  612. "currently only support yaml file or json file type, but get type "
  613. + filetype
  614. )
  615. if isinstance(path_or_dict, str):
  616. if filetype == "auto":
  617. if path_or_dict.endswith(".yaml") or path_or_dict.endswith(".yml"):
  618. filetype = "yaml"
  619. elif path_or_dict.endswith(".json"):
  620. filetype = "json"
  621. else:
  622. LOGGER.error(
  623. "cannot parse the type of the given file name, "
  624. "please manually set the file type"
  625. )
  626. raise ValueError(
  627. "cannot parse the type of the given file name, "
  628. "please manually set the file type"
  629. )
  630. if filetype == "yaml":
  631. path_or_dict = yaml.load(
  632. open(path_or_dict, "r").read(), Loader=yaml.FullLoader
  633. )
  634. else:
  635. path_or_dict = json.load(open(path_or_dict, "r"))
  636. # load the dictionary
  637. path_or_dict = deepcopy(path_or_dict)
  638. solver = cls(None, [], None, None)
  639. fe_list = path_or_dict.pop("feature", None)
  640. if fe_list is not None:
  641. fe_list_ele = []
  642. for feature_engineer in fe_list:
  643. name = feature_engineer.pop("name")
  644. if name is not None:
  645. fe_list_ele.append(FEATURE_DICT[name](**feature_engineer))
  646. if fe_list_ele != []:
  647. solver.set_feature_module(fe_list_ele)
  648. models = path_or_dict.pop("models", [{"name": "gin"}, {"name": "topkpool"}])
  649. model_hp_space = [
  650. _parse_hp_space(model.pop("hp_space", None)) for model in models
  651. ]
  652. model_list = [
  653. _initialize_single_model(model.pop("name"), model) for model in models
  654. ]
  655. trainer = path_or_dict.pop("trainer", None)
  656. default_trainer = "GraphClassificationFull"
  657. trainer_space = None
  658. if isinstance(trainer, dict):
  659. # global default
  660. default_trainer = trainer.pop("name", "GraphClassificationFull")
  661. trainer_space = _parse_hp_space(trainer.pop("hp_space", None))
  662. default_kwargs = {"num_features": None, "num_classes": None}
  663. default_kwargs.update(trainer)
  664. default_kwargs["init"] = False
  665. for i in range(len(model_list)):
  666. model = model_list[i]
  667. trainer_wrapper = TRAINER_DICT[default_trainer](
  668. model=model, **default_kwargs
  669. )
  670. model_list[i] = trainer_wrapper
  671. elif isinstance(trainer, list):
  672. # sequential trainer definition
  673. assert len(trainer) == len(
  674. model_list
  675. ), "The number of trainer and model does not match"
  676. trainer_space = []
  677. for i in range(len(model_list)):
  678. train, model = trainer[i], model_list[i]
  679. default_trainer = train.pop("name", "GraphClassificationFull")
  680. trainer_space.append(_parse_hp_space(train.pop("hp_space", None)))
  681. default_kwargs = {"num_features": None, "num_classes": None}
  682. default_kwargs.update(train)
  683. default_kwargs["init"] = False
  684. trainer_wrap = TRAINER_DICT[default_trainer](
  685. model=model, **default_kwargs
  686. )
  687. model_list[i] = trainer_wrap
  688. solver.set_graph_models(
  689. model_list, default_trainer, trainer_space, model_hp_space
  690. )
  691. hpo_dict = path_or_dict.pop("hpo", {"name": "anneal"})
  692. if hpo_dict is not None:
  693. name = hpo_dict.pop("name")
  694. solver.set_hpo_module(name, **hpo_dict)
  695. ensemble_dict = path_or_dict.pop("ensemble", {"name": "voting"})
  696. if ensemble_dict is not None:
  697. name = ensemble_dict.pop("name")
  698. solver.set_ensemble_module(name, **ensemble_dict)
  699. return solver