{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Quick start\n", "## Import datasets\n", "First, import our library and load a dataset from the given path.\n", "The given directory should contain a `data` directory that holds the datasets, e.g. `/home/AGL/data/cora`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import autogl\n", "from autogl.datasets import build_dataset_from_name\n", "cora_dataset = build_dataset_from_name('cora', path='~/')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decide modules\n", "Then, decide which modules to use.\n", "Here, we use `deepgl` to pre-process the graph features, then two GNNs, `GCN` and `GAT`, to learn the target task.\n", "We use the simulated annealing algorithm to tune the hyper-parameters of the two GNNs.\n", "After training, we use the voting method to ensemble the results of the two GNNs.\n", "You can also specify which device to run on." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')\n", "from autogl.solver import AutoNodeClassifier\n", "solver = AutoNodeClassifier(\n", " feature_module='deepgl',\n", " graph_models=['gcn', 'gat'],\n", " hpo_module='anneal',\n", " ensemble_module='voting',\n", " device=device\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Running\n", "Run the whole process within a given time limit and show the leaderboard.\n", "You can also compute the accuracy by evaluating the predictions."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "solver.fit(cora_dataset, time_limit=3600)\n", "solver.get_leaderboard().show()\n", "\n", "from autogl.module.train import Acc\n", "from autogl.solver.utils import get_graph_labels, get_graph_masks\n", "\n", "predicted = solver.predict_proba()\n", "label = get_graph_labels(cora_dataset[0])[get_graph_masks(cora_dataset[0], 'test')].cpu().numpy()\n", "print('Test accuracy: ', Acc.evaluate(predicted, label))" ] } ], "metadata": { "kernelspec": { "display_name": "agl", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.9.15 (main, Nov 24 2022, 14:31:59) \n[GCC 11.2.0]" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "f5df920585aabffd7c100033fb0a19d10b4cb3343e2b2347d38d24b7cc162540" } } }, "nbformat": 4, "nbformat_minor": 2 }