
quickstart.ipynb

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quick start\n",
    "## Import datasets\n",
    "First, import the library and load a dataset from the given path.\n",
    "The given directory should contain a `data` directory holding the different datasets, for example `/home/AGL/data/cora`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autogl\n",
    "from autogl.datasets import build_dataset_from_name\n",
    "\n",
    "# Load the Cora dataset from the given path\n",
    "cora_dataset = build_dataset_from_name('cora', path='~/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose modules\n",
    "Next, decide which modules to use.\n",
    "Here, we use `deepgl` to pre-process the graph features and then train two GNNs, `GCN` and `GAT`, on the target task.\n",
    "The hyperparameters of both GNNs are tuned with the simulated annealing algorithm (`anneal`).\n",
    "After training, the results of the two GNNs are ensembled by voting.\n",
    "You can also specify which device to run on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from autogl.solver import AutoNodeClassifier\n",
    "\n",
    "# Use the first GPU if one is available, otherwise fall back to the CPU\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "solver = AutoNodeClassifier(\n",
    "    feature_module='deepgl',      # feature engineering\n",
    "    graph_models=['gcn', 'gat'],  # candidate GNNs\n",
    "    hpo_module='anneal',          # simulated-annealing hyperparameter search\n",
    "    ensemble_module='voting',     # ensemble the trained models by voting\n",
    "    device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running\n",
    "Run the whole process under a time limit and show the leaderboard.\n",
    "You can also compute the test accuracy by evaluating the predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogl.module.train import Acc\n",
    "from autogl.solver.utils import get_graph_labels, get_graph_masks\n",
    "\n",
    "# Search, train and ensemble within the time limit, then show the leaderboard\n",
    "solver.fit(cora_dataset, time_limit=3600)\n",
    "solver.get_leaderboard().show()\n",
    "\n",
    "# Predict class probabilities and compare them with the labels of the test nodes\n",
    "predicted = solver.predict_proba()\n",
    "label = get_graph_labels(cora_dataset[0])[get_graph_masks(cora_dataset[0], 'test')].cpu().numpy()\n",
    "print('Test accuracy:', Acc.evaluate(predicted, label))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "agl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.15 (main, Nov 24 2022, 14:31:59) \n[GCC 11.2.0]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f5df920585aabffd7c100033fb0a19d10b4cb3343e2b2347d38d24b7cc162540"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
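The notebook passes `feature_module='deepgl'` to pre-process the graph features. As a rough, hypothetical illustration of the idea behind DeepGL-style feature learning (composing relational operators over node neighborhoods, starting from a base feature), the sketch below starts from node degree and repeatedly applies mean and max aggregations over each node's neighbors. It is a simplified sketch, not AutoGL's `deepgl` module: it omits DeepGL's pruning of redundant features, and the function name `relational_features` and the adjacency-list input format are assumptions.

```python
from typing import Dict, List

# Hypothetical input format: adjacency list mapping each node to its neighbors.
# Every neighbor must also appear as a key.
Graph = Dict[int, List[int]]


def relational_features(graph: Graph, layers: int = 2) -> Dict[int, List[float]]:
    """Simplified DeepGL-style features: degree plus stacked mean/max neighbor aggregations."""
    nodes = sorted(graph)
    # Base feature: node degree.
    feats = {v: [float(len(graph[v]))] for v in nodes}
    # Features produced by the most recent layer; the next layer aggregates only these.
    frontier = {v: list(f) for v, f in feats.items()}
    for _ in range(layers):
        new_layer = {}
        for v in nodes:
            row = []
            for j in range(len(frontier[v])):
                vals = [frontier[u][j] for u in graph[v]]
                row.append(sum(vals) / len(vals) if vals else 0.0)  # mean operator
                row.append(max(vals) if vals else 0.0)  # max operator
            new_layer[v] = row
        for v in nodes:
            feats[v].extend(new_layer[v])
        frontier = new_layer
    return feats


# Usage: a small star graph centered on node 1.
star = {0: [1], 1: [0, 2, 3], 2: [1], 3: [1]}
features = relational_features(star, layers=2)
print(features[0])  # [1.0, 3.0, 3.0, 1.0, 1.0, 1.0, 1.0]
```

Each layer doubles the number of newly generated features, which is why pruning redundant features matters on real graphs.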
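In the notebook, `hpo_module='anneal'` selects simulated annealing for the hyperparameter search. The minimal sketch below shows the general technique on a toy problem; the search space, the `toy_objective` stand-in for validation accuracy, the `neighbor` move, and the geometric cooling schedule are all hypothetical choices, not AutoGL's internals. Improvements are always accepted, and worse configurations are accepted with probability exp(Δ/T), which lets the search leave poor regions while the temperature T is still high.

```python
import math
import random

# Hypothetical search space: each hyperparameter has a small list of candidate values.
SEARCH_SPACE = {
    "lr": [0.001, 0.005, 0.01, 0.05],
    "hidden": [16, 32, 64, 128],
    "dropout": [0.2, 0.5, 0.8],
}


def toy_objective(config):
    """Hypothetical stand-in for validation accuracy (higher is better, maximum 1.0)."""
    score = 1.0
    score -= abs(math.log10(config["lr"]) - math.log10(0.01))  # best at lr=0.01
    score -= abs(config["hidden"] - 64) / 256  # best at hidden=64
    score -= abs(config["dropout"] - 0.5)  # best at dropout=0.5
    return score


def neighbor(config, rng):
    """Change one randomly chosen hyperparameter to a different candidate value."""
    new = dict(config)
    key = rng.choice(list(SEARCH_SPACE))
    new[key] = rng.choice([v for v in SEARCH_SPACE[key] if v != config[key]])
    return new


def simulated_annealing(objective, steps=200, t_start=1.0, t_end=0.01, seed=0):
    rng = random.Random(seed)
    current = {k: rng.choice(v) for k, v in SEARCH_SPACE.items()}
    current_score = objective(current)
    best, best_score = current, current_score
    for step in range(steps):
        # Geometric cooling from t_start down to t_end.
        t = t_start * (t_end / t_start) ** (step / max(steps - 1, 1))
        candidate = neighbor(current, rng)
        cand_score = objective(candidate)
        delta = cand_score - current_score
        # Metropolis rule: accept improvements, accept worse moves with probability exp(delta / t).
        if delta >= 0 or rng.random() < math.exp(delta / t):
            current, current_score = candidate, cand_score
            if current_score > best_score:
                best, best_score = current, current_score
    return best, best_score


best_config, best_score = simulated_annealing(toy_objective)
print(best_config, round(best_score, 3))
```

In practice, `toy_objective` would train a model with the given configuration and return its validation score, so each step is expensive.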
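After fitting, the notebook ensembles the GCN and GAT results with `ensemble_module='voting'` and scores `solver.predict_proba()` against the test labels with `Acc.evaluate`. The pure-Python sketch below illustrates both ideas on made-up numbers: a weighted soft vote (a weighted average of each model's class probabilities) and accuracy as the fraction of test nodes whose highest-probability class matches the label. The function names, weights, and data are hypothetical, and AutoGL's voting ensemble may compute its weights differently.

```python
from typing import List, Sequence

Matrix = List[List[float]]  # rows: nodes, columns: class probabilities


def soft_vote(probas: Sequence[Matrix], weights: Sequence[float]) -> Matrix:
    """Weighted average of per-model class-probability matrices."""
    total = sum(weights)
    n_nodes, n_classes = len(probas[0]), len(probas[0][0])
    combined = [[0.0] * n_classes for _ in range(n_nodes)]
    for proba, w in zip(probas, weights):
        for i in range(n_nodes):
            for c in range(n_classes):
                combined[i][c] += w * proba[i][c] / total
    return combined


def accuracy(proba: Matrix, labels: Sequence[int]) -> float:
    """Fraction of nodes whose highest-probability class equals the true label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in proba]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


# Hypothetical outputs of two models (e.g. GCN and GAT) on 4 test nodes with 3 classes.
gcn_proba = [[0.7, 0.2, 0.1], [0.4, 0.5, 0.1], [0.1, 0.3, 0.6], [0.5, 0.4, 0.1]]
gat_proba = [[0.4, 0.5, 0.1], [0.2, 0.7, 0.1], [0.3, 0.3, 0.4], [0.2, 0.7, 0.1]]
labels = [0, 1, 2, 1]

ensembled = soft_vote([gcn_proba, gat_proba], weights=[1.0, 1.0])
print(accuracy(gcn_proba, labels), accuracy(gat_proba, labels), accuracy(ensembled, labels))
# 0.75 0.75 1.0
```

Here each model alone misclassifies one of the four nodes, but averaging their probabilities corrects both mistakes, which is the usual motivation for ensembling models whose errors differ.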