You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_classification.py 6.2 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. """
  2. Test file for nas on node classification
  3. AUTOGL_BACKEND=pyg python test/nas/node_classification.py
  4. AUTOGL_BACKEND=dgl python test/nas/node_classification.py
  5. TODO: make it a unit test file to test all the possible combinations
  6. """
  7. import os
  8. import logging
  9. logging.basicConfig(level=logging.INFO)
  10. from autogl.backend import DependentBackend
  11. if DependentBackend.is_dgl():
  12. from autogl.module.model.dgl import BaseAutoModel
  13. from dgl.data import CoraGraphDataset
  14. elif DependentBackend.is_pyg():
  15. from torch_geometric.datasets import Planetoid
  16. from autogl.module.model.pyg import BaseAutoModel
  17. import torch
  18. import torch.nn.functional as F
  19. from autogl.module.nas.space.single_path import SinglePathNodeClassificationSpace
  20. from autogl.module.nas.space.graph_nas import GraphNasNodeClassificationSpace
  21. from autogl.module.nas.space.graph_nas_macro import GraphNasMacroNodeClassificationSpace
  22. from autogl.module.nas.estimator.one_shot import OneShotEstimator
  23. from autogl.module.nas.space.autoattend import AutoAttendNodeClassificationSpace
  24. from autogl.module.nas.backend import bk_feat, bk_label
  25. from autogl.module.nas.algorithm import Darts, RL, GraphNasRL, Enas, RandomSearch,Spos
  26. import numpy as np
  27. from autogl.solver.utils import set_seed
  28. set_seed(202106)
  def test_model(model, data=None, check_children=False, device="cuda"):
      """
      Smoke-test the AutoGL model interface of a searched model.

      Interface exercised
      -------------------
      - model.initialize()
      - model.to()
      - model.model.forward()
      - model.from_hyper_parameter()

      Parameters
      ----------
      model : BaseAutoModel
          Model to check, e.g. the result of ``algo.search(...)``.
      data : optional
          Graph passed to one forward pass of ``model.model``. It is moved to
          ``device`` first. When ``None`` the forward pass is skipped.
      check_children : bool
          When True and ``model.hyper_parameter_space`` is empty, also run this
          check on ``model.from_hyper_parameter({})``. The child is checked
          with ``check_children=False``, so recursion stops after one level.
      device : str or torch.device
          Target device for the model and data. Defaults to ``"cuda"`` (the
          previously hard-coded value); pass ``"cpu"`` on machines without a GPU.

      Raises
      ------
      AssertionError
          If the model lacks a required attribute or has the wrong type.
      """
      assert isinstance(model, BaseAutoModel)
      assert hasattr(model, "to")
      assert hasattr(model, "initialize")
      model.initialize()
      model.to(device)
      if data is not None:
          # Rebind the result: ``.to`` may return a new object instead of
          # moving the graph in place (backend-dependent).
          data = data.to(device)
      # ``model.model`` is checked only after initialize(), which presumably
      # builds it.
      assert hasattr(model, "model")
      inner = model.model
      assert isinstance(inner, torch.nn.Module)
      if data is not None:
          # Pure smoke test of forward(); no gradients are needed, so skip
          # recording the autograd graph.
          with torch.no_grad():
              inner.forward(data)
      # FIXME: the rebuild check only runs when hyper_parameter_space is empty,
      # presumably because an empty dict can only satisfy an empty space.
      if not model.hyper_parameter_space:
          model_2 = model.from_hyper_parameter({})
          if check_children:
              test_model(model_2, data, device=device)
  if __name__ == "__main__":
      print("Testing backend: {}".format("dgl" if DependentBackend.is_dgl() else "pyg"))
      if DependentBackend.is_dgl():
          dataset = CoraGraphDataset()
      else:
          dataset = Planetoid(os.path.expanduser("~/.cache-autogl"), "Cora")
      data = dataset[0]
      di = bk_feat(data).shape[1]
      do = len(np.unique(bk_label(data)))

      # Each case is (label printed before the run, space factory, algorithm
      # factory, whether to print the searched model). Factories are callables
      # so every object is built lazily inside the loop, in the same order a
      # hand-written sequence of stanzas would build it (this keeps the seeded
      # RNG consumption order unchanged).
      cases = [
          ("evolutionary + singlepath ",
           SinglePathNodeClassificationSpace,
           lambda: Spos(cycles=200), False),
          ("evolutionary + graphnas ",
           GraphNasNodeClassificationSpace,
           lambda: Spos(cycles=200), False),
          # NOTE: random search on the GraphNas space runs twice in this list,
          # once with 100 epochs (here) and once with 10 epochs (below).
          ("Random search + graphnas ",
           GraphNasNodeClassificationSpace,
           lambda: RandomSearch(num_epochs=100), False),
          ("Random search + AutoAttend ",
           AutoAttendNodeClassificationSpace,
           lambda: RandomSearch(num_epochs=10), True),
          ("rl + AutoAttend ",
           AutoAttendNodeClassificationSpace,
           lambda: RL(num_epochs=10), False),
          # Previously mislabelled "darts + graphnas" although it searches the
          # AutoAttend space.
          ("darts + AutoAttend ",
           AutoAttendNodeClassificationSpace,
           lambda: Darts(num_epochs=10), False),
          ("Random search + graphnas ",
           GraphNasNodeClassificationSpace,
           lambda: RandomSearch(num_epochs=10), False),
          ("rl + graphnas ",
           GraphNasNodeClassificationSpace,
           lambda: RL(num_epochs=10), False),
          ("graphnasrl + graphnas ",
           GraphNasNodeClassificationSpace,
           lambda: GraphNasRL(num_epochs=10), False),
          ("darts + graphnas ",
           lambda: GraphNasNodeClassificationSpace(con_ops=['concat']),
           lambda: Darts(num_epochs=10), False),
          ("darts + singlepath ",
           SinglePathNodeClassificationSpace,
           lambda: Darts(num_epochs=10), False),
          ("Random search + graphnas macro",
           GraphNasMacroNodeClassificationSpace,
           lambda: RandomSearch(num_epochs=10), False),
          ("RL + graphnas macro",
           GraphNasMacroNodeClassificationSpace,
           lambda: RL(num_epochs=10), False),
      ]

      for label, make_space, make_algo, show_model in cases:
          print(label)
          space = make_space().cuda()
          space.instantiate(input_dim=di, output_dim=do)
          esti = OneShotEstimator()
          algo = make_algo()
          model = algo.search(space, dataset, esti)
          if show_model:
              print(model)
          test_model(model, data, True)