
node_classification.ipynb 7.8 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Node Classification\n",
    "AutoGL supports multiple graph-related tasks, including node classification.\n",
    "\n",
    "This notebook gives a simple example of how to use AutoGL for the node classification task.\n",
    "\n",
    "## Import libraries\n",
    "First, import the required libraries and set the random seeds before splitting the dataset and training the model, so that results are as reproducible as possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "import random\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from autogl.datasets import build_dataset_from_name\n",
    "from autogl.solver import AutoNodeClassifier\n",
    "from autogl.module import Acc\n",
    "from autogl.backend import DependentBackend\n",
    "\n",
    "# set random seeds for Python, NumPy, and PyTorch\n",
    "random.seed(2022)\n",
    "np.random.seed(2022)\n",
    "torch.manual_seed(2022)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(2022)\n",
    "    # make cuDNN use deterministic algorithms and disable its autotuner\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False"
   ]
  },
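  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Seeding matters because a generator started from the same seed reproduces the same sequence of random draws. The cell below is a minimal sketch that uses Python's standard `random` module rather than AutoGL. `toy_random_split` is a hypothetical helper that shuffles node IDs with a private generator, so two calls with seed 2022 return the same training nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def toy_random_split(num_nodes, num_train, seed):\n",
    "    # hypothetical helper: a private generator makes the result depend only on `seed`\n",
    "    rng = random.Random(seed)\n",
    "    nodes = list(range(num_nodes))\n",
    "    rng.shuffle(nodes)\n",
    "    return sorted(nodes[:num_train])\n",
    "\n",
    "\n",
    "split_a = toy_random_split(100, 10, seed=2022)\n",
    "split_b = toy_random_split(100, 10, seed=2022)\n",
    "print(split_a == split_b)  # True"
   ]
  },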
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Dataset\n",
    "AutoGL provides a convenient interface for obtaining and splitting common datasets, such as `cora`, `citeseer`, and `amazon_computers`.\n",
    "\n",
    "Pass the name of the dataset you want to `build_dataset_from_name`, and AutoGL returns the dataset.\n",
    "\n",
    "In this example, we evaluate the model on the Cora dataset in the semi-supervised node classification setting. As the output below shows, only 140 of the 2708 nodes are labeled for training, while 500 nodes are used for validation and 1000 for testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " NumNodes: 2708\n",
      " NumEdges: 10556\n",
      " NumFeats: 1433\n",
      " NumClasses: 7\n",
      " NumTrainingSamples: 140\n",
      " NumValidationSamples: 500\n",
      " NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "dataset = build_dataset_from_name('cora')"
   ]
  },
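  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sizes above match the widely used Planetoid split of Cora: 20 labeled nodes per class × 7 classes = 140 training nodes, plus 500 validation nodes and 1000 test nodes. That public split is fixed. The cell below is only a sketch of how a random split with the same sizes could be drawn. It uses a hypothetical helper `toy_per_class_split` and synthetic labels, not the real Cora labels, and it is not how AutoGL builds its split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def toy_per_class_split(labels, per_class=20, num_val=500, num_test=1000, seed=2022):\n",
    "    # hypothetical helper: sample `per_class` training nodes from each class,\n",
    "    # then draw disjoint validation and test sets from the remaining nodes\n",
    "    rng = random.Random(seed)\n",
    "    nodes_by_class = defaultdict(list)\n",
    "    for node, y in enumerate(labels):\n",
    "        nodes_by_class[y].append(node)\n",
    "    train = []\n",
    "    for y in sorted(nodes_by_class):\n",
    "        train.extend(rng.sample(nodes_by_class[y], per_class))\n",
    "    taken = set(train)\n",
    "    rest = [node for node in range(len(labels)) if node not in taken]\n",
    "    rng.shuffle(rest)\n",
    "    return train, rest[:num_val], rest[num_val:num_val + num_test]\n",
    "\n",
    "\n",
    "# synthetic labels: 2708 nodes cycling through 7 classes\n",
    "toy_labels = [i % 7 for i in range(2708)]\n",
    "toy_train, toy_val, toy_test = toy_per_class_split(toy_labels)\n",
    "print(len(toy_train), len(toy_val), len(toy_test))  # 140 500 1000"
   ]
  },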
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Solver\n",
    "After obtaining the dataset, we need to initialize the model.\n",
    "\n",
    "Instead of tuning the model by hand, we train it through AutoGL's solver class, which uses hyperparameter optimization (HPO) to find a better configuration.\n",
    "\n",
    "A solver in AutoGL is usually initialized lazily from a config file. Example config files are in the `../configs` folder, and the AutoGL tutorial describes the format in more detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the label field is 'y' with the PyG backend and 'label' with the DGL backend\n",
    "label = dataset[0].nodes.data['y' if DependentBackend.is_pyg() else 'label']\n",
    "num_classes = len(np.unique(label.numpy()))\n",
    "\n",
    "# load the solver configuration (closing the file afterwards) and build the solver from it\n",
    "with open('../configs/nodeclf_gcn_benchmark_small.yml', 'r') as f:\n",
    "    configs = yaml.load(f, Loader=yaml.FullLoader)\n",
    "autoClassifier = AutoNodeClassifier.from_config(configs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train\n",
    "Once the solver is initialized, call `fit` to optimize the model's hyperparameters with HPO and train it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2022-10-24 18:55:27] INFO (NodeClassifier/MainThread) Use the default train/val/test ratio in given dataset\n",
      "HPO Search Phase:\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:09<00:00, 1.40s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2022-10-24 18:56:37] INFO (HPO/MainThread) Best Parameter:\n",
      "[2022-10-24 18:56:37] INFO (HPO/MainThread) Parameter: {'trainer': {'max_epoch': 165, 'early_stopping_round': 18, 'lr': 0.014545893271287733, 'weight_decay': 0.0001682578213292401}, 'encoder': {'num_layers': 2, 'hidden': [42], 'dropout': 0.6019468841551312, 'act': 'tanh'}, 'decoder': {}} acc: 0.806 higher_better\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<autogl.solver.classifier.node_classifier.AutoNodeClassifier at 0x7fa319a65cd0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# time_limit: time budget for training, in seconds\n",
    "# evaluation_method: metric(s) used to evaluate performance\n",
    "autoClassifier.fit(dataset, time_limit=3600, evaluation_method=[Acc])"
   ]
  },
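  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The log above shows the HPO module running 50 trials and keeping the configuration with the best accuracy (`higher_better`). The cell below is a minimal sketch of one such strategy, random search under a time budget, written with the standard library only. It is not AutoGL's implementation. The search space and `toy_objective` are hypothetical stand-ins for training a GCN and returning its accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "import time\n",
    "\n",
    "# hypothetical search space, loosely modeled on the parameters in the log above\n",
    "LR_RANGE = (1e-3, 5e-2)            # sampled log-uniformly\n",
    "WEIGHT_DECAY_RANGE = (1e-5, 1e-3)  # sampled log-uniformly\n",
    "DROPOUT_RANGE = (0.2, 0.8)         # sampled uniformly\n",
    "HIDDEN_CHOICES = [16, 32, 64]\n",
    "\n",
    "\n",
    "def log_uniform(rng, low, high):\n",
    "    return math.exp(rng.uniform(math.log(low), math.log(high)))\n",
    "\n",
    "\n",
    "def sample_params(rng):\n",
    "    return {\n",
    "        'lr': log_uniform(rng, *LR_RANGE),\n",
    "        'weight_decay': log_uniform(rng, *WEIGHT_DECAY_RANGE),\n",
    "        'dropout': rng.uniform(*DROPOUT_RANGE),\n",
    "        'hidden': rng.choice(HIDDEN_CHOICES),\n",
    "    }\n",
    "\n",
    "\n",
    "def random_search(objective, max_trials=50, time_limit=3600.0, seed=2022):\n",
    "    # keep the highest-scoring configuration; stop after max_trials\n",
    "    # or once time_limit seconds have elapsed\n",
    "    rng = random.Random(seed)\n",
    "    start = time.monotonic()\n",
    "    best_params, best_score = None, -math.inf\n",
    "    for _ in range(max_trials):\n",
    "        if time.monotonic() - start > time_limit:\n",
    "            break\n",
    "        params = sample_params(rng)\n",
    "        score = objective(params)\n",
    "        if score > best_score:\n",
    "            best_params, best_score = params, score\n",
    "    return best_params, best_score\n",
    "\n",
    "\n",
    "def toy_objective(params):\n",
    "    # hypothetical score that peaks near lr = 1e-2 and dropout = 0.5\n",
    "    return 0.8 - 0.05 * abs(math.log10(params['lr']) + 2) - 0.1 * abs(params['dropout'] - 0.5)\n",
    "\n",
    "\n",
    "toy_best, toy_score = random_search(toy_objective)\n",
    "print(toy_best, round(toy_score, 3))"
   ]
  },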
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "After training, you can show the leaderboard of trained models and evaluate the best model on the test set."
   ]
  },
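  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Accuracy is the fraction of evaluated nodes whose predicted class matches the true label. With Cora's 1000 test nodes, the test accuracy of 0.806 reported below corresponds to 806 correctly classified nodes. The cell below is a minimal plain-Python sketch of the metric that uses a hypothetical helper `toy_accuracy` and made-up predictions. AutoGL's `Acc` module may compute it differently internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def toy_accuracy(pred, label, mask):\n",
    "    # hypothetical helper: accuracy over the nodes where mask is True\n",
    "    idx = [i for i, selected in enumerate(mask) if selected]\n",
    "    correct = sum(pred[i] == label[i] for i in idx)\n",
    "    return correct / len(idx)\n",
    "\n",
    "\n",
    "# made-up example: 4 of the 5 masked nodes are predicted correctly\n",
    "toy_pred = [0, 1, 2, 2, 1, 0]\n",
    "toy_label = [0, 1, 2, 0, 1, 1]\n",
    "toy_mask = [True, True, True, True, True, False]\n",
    "print(toy_accuracy(toy_pred, toy_label, toy_mask))  # 0.8"
   ]
  },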
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------------------------------------------------------------+-------+\n",
      "| name | acc |\n",
      "+=========================================================================+=======+\n",
      "| decoder: None | 0.806 |\n",
      "| early_stopping_round: 18 | |\n",
      "| encoder: <autogl.module.model.dgl.gcn.AutoGCN object at 0x7fa3186d0b50> | |\n",
      "| learning_rate: 0.014545893271287733 | |\n",
      "| max_epoch: 165 | |\n",
      "| optimizer: !!python/name:torch.optim.adam.Adam '' | |\n",
      "| trainer_name: NodeClassificationFullTrainer | |\n",
      "| _idx0 | |\n",
      "+-------------------------------------------------------------------------+-------+\n",
      "[2022-10-24 18:56:50] INFO (NodeClassifier/MainThread) Ensemble argument on, will try using ensemble model.\n",
      "[2022-10-24 18:56:50] WARNING (NodeClassifier/MainThread) Cannot use ensemble because no ensebmle module is given. Will use best model instead.\n",
      "test acc: 0.8060\n"
     ]
    }
   ],
   "source": [
    "autoClassifier.get_leaderboard().show()\n",
    "# you can also specify the metric here\n",
    "acc = autoClassifier.evaluate(metric=\"acc\")\n",
    "print(\"test acc: {:.4f}\".format(acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "vscode": {
   "interpreter": {
    "hash": "ceaf47f872914ebc119c31eaf5650b5ee907a61565d128a6607ed80bbe5b2670"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}