diff --git a/examples/test_graph_nas.py b/examples/test_graph_nas.py index 194c739..d1e297c 100644 --- a/examples/test_graph_nas.py +++ b/examples/test_graph_nas.py @@ -30,7 +30,7 @@ if __name__ == '__main__': lr_scheduler_type=None,), nas_algorithms=[Enas(num_epochs=10)], #nas_algorithms=[Darts(num_epochs=200)], - nas_spaces=[GraphNasNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GCNConv])], + nas_spaces=[GraphNasNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GCNConv],search_act_con=True)], nas_estimators=[OneShotEstimator()] ) solver.fit(dataset)