|
|
|
@@ -82,8 +82,8 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
if args.model == 'gat': |
|
|
|
model = AutoGAT( |
|
|
|
num_features=num_features, |
|
|
|
num_classes=num_classes, |
|
|
|
input_dimension=num_features, |
|
|
|
output_dimension=num_classes, |
|
|
|
device=args.device |
|
|
|
).from_hyper_parameter({ |
|
|
|
# hp from model |
|
|
|
@@ -96,8 +96,8 @@ if __name__ == '__main__': |
|
|
|
}).model |
|
|
|
elif args.model == 'gcn': |
|
|
|
model = AutoGCN( |
|
|
|
num_features=num_features, |
|
|
|
num_classes=num_classes, |
|
|
|
input_dimension=num_features, |
|
|
|
output_dimension=num_classes, |
|
|
|
device=args.device |
|
|
|
).from_hyper_parameter({ |
|
|
|
"num_layers": 2, |
|
|
|
@@ -107,8 +107,8 @@ if __name__ == '__main__': |
|
|
|
}).model |
|
|
|
elif args.model == 'sage': |
|
|
|
model = AutoSAGE( |
|
|
|
num_features=num_features, |
|
|
|
num_classes=num_classes, |
|
|
|
input_dimension=num_features, |
|
|
|
output_dimension=num_classes, |
|
|
|
device=args.device |
|
|
|
).from_hyper_parameter({ |
|
|
|
"num_layers": 2, |
|
|
|
|