You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

node_classification.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. """
  2. Test file for nas on node classification
  3. AUTOGL_BACKEND=pyg python test/nas/node_classification.py
  4. AUTOGL_BACKEND=dgl python test/nas/node_classification.py
  5. TODO: make it a unit test file to test all the possible combinations
  6. """
  7. import os
  8. import logging
  9. logging.basicConfig(level=logging.INFO)
  10. from autogl.backend import DependentBackend
  11. if DependentBackend.is_dgl():
  12. from autogl.module.model.dgl import BaseModel
  13. from dgl.data import CoraGraphDataset
  14. elif DependentBackend.is_pyg():
  15. from torch_geometric.datasets import Planetoid
  16. from autogl.module.model.pyg import BaseModel
  17. import torch
  18. import torch.nn.functional as F
  19. from autogl.module.nas.space.single_path import SinglePathNodeClassificationSpace
  20. from autogl.module.nas.space.graph_nas import GraphNasNodeClassificationSpace
  21. from autogl.module.nas.space.graph_nas_macro import GraphNasMacroNodeClassificationSpace
  22. from autogl.module.nas.estimator.one_shot import OneShotEstimator
  23. from autogl.module.nas.backend import bk_feat, bk_label
  24. from autogl.module.nas.algorithm import Darts, RL, GraphNasRL, Enas, RandomSearch,Spos
  25. import numpy as np
  26. from autogl.solver.utils import set_seed
  27. set_seed(202106)
  28. def test_model(model, data=None, check_children=False):
  29. """
  30. Test model interface.
  31. Interface
  32. ---------
  33. - model.from_hyper_parameter()
  34. - model.model.forward()
  35. - model.to()
  36. - model.initialize()
  37. """
  38. assert isinstance(model, BaseModel)
  39. assert hasattr(model, "to")
  40. assert hasattr(model, "initialize")
  41. model.initialize()
  42. model.to("cuda")
  43. if data is not None:
  44. data = data.to("cuda")
  45. assert hasattr(model, "model")
  46. __model = model.model
  47. assert isinstance(__model, torch.nn.Module)
  48. if data is not None:
  49. __model.forward(data)
  50. # FIXME: we can only perform tests when hyper_parameter_space is []
  51. if len(model.hyper_parameter_space) == 0:
  52. model_2 = model.from_hyper_parameter({})
  53. if check_children:
  54. test_model(model_2, data)
  55. if __name__ == "__main__":
  56. print("Testing backend: {}".format("dgl" if DependentBackend.is_dgl() else "pyg"))
  57. if DependentBackend.is_dgl():
  58. dataset = CoraGraphDataset()
  59. else:
  60. dataset = Planetoid(os.path.expanduser("~/.cache-autogl"), "Cora")
  61. data = dataset[0]
  62. di = bk_feat(data).shape[1]
  63. do = len(np.unique(bk_label(data)))
  64. print("evolutionary + singlepath ")
  65. space=SinglePathNodeClassificationSpace().cuda()
  66. space.instantiate(input_dim=di,output_dim=do)
  67. esti=OneShotEstimator()
  68. algo=Spos(cycles=200)
  69. model = algo.search(space, dataset, esti)
  70. test_model(model, data, True)
  71. print("evolutionary + graphnas ")
  72. space=GraphNasNodeClassificationSpace().cuda()
  73. space.instantiate(input_dim=di,output_dim=do)
  74. esti=OneShotEstimator()
  75. algo=Spos(cycles=200)
  76. model = algo.search(space, dataset, esti)
  77. test_model(model, data, True)
  78. print("Random search + graphnas ")
  79. space = GraphNasNodeClassificationSpace().cuda()
  80. space.instantiate(input_dim=di, output_dim=do)
  81. esti = OneShotEstimator()
  82. algo = RandomSearch(num_epochs=10)
  83. model = algo.search(space, dataset, esti)
  84. test_model(model, data, True)
  85. print("Random search + singlepath ")
  86. space = SinglePathNodeClassificationSpace().cuda()
  87. space.instantiate(input_dim=di, output_dim=do)
  88. esti = OneShotEstimator()
  89. algo = RandomSearch(num_epochs=10)
  90. model = algo.search(space, dataset, esti)
  91. test_model(model, data, True)
  92. print("rl + graphnas ")
  93. space = GraphNasNodeClassificationSpace().cuda()
  94. space.instantiate(input_dim=di, output_dim=do)
  95. esti = OneShotEstimator()
  96. algo = RL(num_epochs=10)
  97. model = algo.search(space, dataset, esti)
  98. test_model(model, data, True)
  99. print("graphnasrl + graphnas ")
  100. space = GraphNasNodeClassificationSpace().cuda()
  101. space.instantiate(input_dim=di, output_dim=do)
  102. esti = OneShotEstimator()
  103. algo = GraphNasRL(num_epochs=10)
  104. model = algo.search(space, dataset, esti)
  105. test_model(model, data, True)
  106. print("enas + graphnas ")
  107. space = GraphNasNodeClassificationSpace().cuda()
  108. space.instantiate(input_dim=di, output_dim=do)
  109. esti = OneShotEstimator()
  110. algo = Enas(num_epochs=10)
  111. model = algo.search(space, dataset, esti)
  112. test_model(model, data, True)
  113. print("darts + graphnas ")
  114. space = GraphNasNodeClassificationSpace(con_ops=['concat']).cuda()
  115. space.instantiate(input_dim=di, output_dim=do)
  116. esti = OneShotEstimator()
  117. algo = Darts(num_epochs=10)
  118. model = algo.search(space, dataset, esti)
  119. test_model(model, data, True)
  120. print("darts + singlepath ")
  121. space = SinglePathNodeClassificationSpace().cuda()
  122. space.instantiate(input_dim=di, output_dim=do)
  123. esti = OneShotEstimator()
  124. algo = Darts(num_epochs=10)
  125. model = algo.search(space, dataset, esti)
  126. test_model(model, data, True)
  127. print("Random search + graphnas macro")
  128. space = GraphNasMacroNodeClassificationSpace().cuda()
  129. space.instantiate(input_dim=di, output_dim=do)
  130. esti = OneShotEstimator()
  131. algo = RandomSearch(num_epochs=10)
  132. model = algo.search(space, dataset, esti)
  133. test_model(model, data, True)
  134. print("RL + graphnas macro")
  135. space = GraphNasMacroNodeClassificationSpace().cuda()
  136. space.instantiate(input_dim=di, output_dim=do)
  137. esti = OneShotEstimator()
  138. algo = RL(num_epochs=10)
  139. model = algo.search(space, dataset, esti)
  140. test_model(model, data, True)