from autogl.module.train import GraphClassificationFullTrainer from autogl.datasets import build_dataset_from_name, utils from autogl.datasets.utils.conversion._to_pyg_dataset import to_pyg_dataset def test_graph_trainer(): dataset = build_dataset_from_name("mutag") utils.graph_random_splits(dataset, 0.8, 0.1) dataset = to_pyg_dataset(dataset) lp_trainer = GraphClassificationFullTrainer(model='gin', init=False) lp_trainer.num_features = dataset[0].x.size(1) lp_trainer.num_classes = max([d.y for d in dataset]).item() + 1 lp_trainer.num_graph_features = 0 lp_trainer.initialize() print(lp_trainer.encoder.encoder) print(lp_trainer.decoder.decoder) lp_trainer.train(dataset, True) result = lp_trainer.evaluate(dataset, "test", "acc") print("Acc:", result) test_graph_trainer()