You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_classification.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import sys
  2. sys.path.append('../')
  3. from autogl.datasets import build_dataset_from_name
  4. from autogl.solver import AutoNodeClassifier
  5. from autogl.module import Acc
  6. import yaml
  7. import random
  8. import torch
  9. import numpy as np
  10. import logging
  11. logging.basicConfig(level=logging.INFO)
  12. if __name__ == '__main__':
  13. from argparse import ArgumentParser
  14. parser = ArgumentParser()
  15. parser.add_argument('--dataset', default='cora', type=str)
  16. parser.add_argument('--configs', type=str, default='../configs/nodeclf_gcn_benchmark_small.yml')
  17. # following arguments will override parameters in the config file
  18. parser.add_argument('--hpo', type=str, default='random')
  19. parser.add_argument('--max_eval', type=int, default=5)
  20. parser.add_argument('--seed', type=int, default=0)
  21. parser.add_argument('--device', default=0, type=int)
  22. args = parser.parse_args()
  23. if torch.cuda.is_available():
  24. torch.cuda.set_device(args.device)
  25. seed = args.seed
  26. # set random seed
  27. random.seed(seed)
  28. np.random.seed(seed)
  29. torch.manual_seed(seed)
  30. if torch.cuda.is_available():
  31. torch.cuda.manual_seed(seed)
  32. torch.backends.cudnn.deterministic = True
  33. torch.backends.cudnn.benchmark = False
  34. dataset = build_dataset_from_name(args.dataset)
  35. configs = yaml.load(open(args.configs, 'r').read(), Loader=yaml.FullLoader)
  36. configs['hpo']['name'] = args.hpo
  37. configs['hpo']['max_evals'] = args.max_eval
  38. autoClassifier = AutoNodeClassifier.from_config(configs)
  39. # train
  40. if args.dataset in ['cora', 'citeseer', 'pubmed']:
  41. autoClassifier.fit(dataset, time_limit=3600, evaluation_method=[Acc])
  42. else:
  43. autoClassifier.fit(dataset, time_limit=3600, evaluation_method=[Acc], seed=seed, train_split=20*dataset.num_classes, val_split=30*dataset.num_classes, balanced=False)
  44. val = autoClassifier.get_model_by_performance(0)[0].get_valid_score()[0]
  45. print('val acc: ', val)
  46. # test
  47. predict_result = autoClassifier.predict_proba(use_best=True, use_ensemble=False)
  48. print('test acc: ', Acc.evaluate(predict_result, dataset.data.y[dataset.data.test_mask].numpy()))