{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quick start\n",
    "## Import datasets\n",
    "First, import the library and load a dataset from the given path.\n",
    "The given directory should contain a `data` directory with the different datasets, e.g. `/home/AGL/data/cora`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autogl\n",
    "from autogl.datasets import build_dataset_from_name\n",
    "\n",
    "# Loads Cora from /home/AGL/data/cora\n",
    "cora_dataset = build_dataset_from_name('cora', path='/home/AGL/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decide modules\n",
    "Next, decide which modules to use.\n",
    "Here, we use `deepgl` to pre-process the graph features, then train two GNNs, `GCN` and `GAT`, on the target task.\n",
    "We use the simulated annealing algorithm to tune the hyper-parameters of the two GNNs.\n",
    "After training, a voting ensemble combines the results of the two GNNs.\n",
    "You can also specify which device to run on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from autogl.solver import AutoNodeClassifier\n",
    "\n",
    "# Use the default GPU if one is available, otherwise fall back to the CPU\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "solver = AutoNodeClassifier(\n",
    "    feature_module='deepgl',     # feature pre-processing\n",
    "    graph_models=['gcn', 'gat'], # candidate GNNs\n",
    "    hpo_module='anneal',         # simulated annealing HPO\n",
    "    ensemble_module='voting',    # voting ensemble\n",
    "    device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running\n",
    "Run the whole process within a time limit (in seconds) and show the leaderboard.\n",
    "You can also get the test accuracy by evaluating the predictions."
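   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the toy cell below sketches what an accuracy metric computes from predicted class probabilities: take the most probable class for each node and count how often it matches the true label.\n",
    "It is a minimal sketch with made-up numbers, assuming that `predict_proba()` returns one row of class probabilities per test node and that `Acc` scores the most probable class; it illustrates the metric and is not AutoGL's implementation.\n",
    "The cell after it runs the full pipeline on Cora."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of accuracy from class probabilities (made-up numbers, not Cora results).\n",
    "# Assumption: one row of class probabilities per test node, like the output of predict_proba().\n",
    "import numpy as np\n",
    "\n",
    "proba = np.array([[0.7, 0.2, 0.1],   # node 0 -> class 0\n",
    "                  [0.1, 0.3, 0.6],   # node 1 -> class 2\n",
    "                  [0.2, 0.5, 0.3],   # node 2 -> class 1\n",
    "                  [0.3, 0.5, 0.2]])  # node 3 -> class 1\n",
    "labels = np.array([0, 2, 0, 1])\n",
    "\n",
    "predicted_labels = proba.argmax(axis=1)         # most probable class per node\n",
    "accuracy = (predicted_labels == labels).mean()  # fraction of correctly classified nodes\n",
    "print('Toy accuracy:', accuracy)                # 3 of 4 correct -> 0.75"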
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search and train within a 3600-second budget, then show the results\n",
    "solver.fit(cora_dataset, time_limit=3600)\n",
    "solver.get_leaderboard().show()\n",
    "\n",
    "from autogl.module.train import Acc\n",
    "\n",
    "# Predicted class probabilities for the test nodes\n",
    "predicted = solver.predict_proba()\n",
    "test_labels = cora_dataset.data.y[cora_dataset.data.test_mask].cpu().numpy()\n",
    "print('Test accuracy:', Acc.evaluate(predicted, test_labels))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}