You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

node_classification.py 6.2 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
"""
Test script for neural architecture search (NAS) on node classification.

Usage (pick the backend through the AUTOGL_BACKEND environment variable):

    AUTOGL_BACKEND=pyg python test/nas/node_classification.py
    AUTOGL_BACKEND=dgl python test/nas/node_classification.py

TODO: turn this into a unit test file that covers all possible combinations
"""
  7. import os
  8. import logging
  9. logging.basicConfig(level=logging.INFO)
  10. from autogl.backend import DependentBackend
  11. if DependentBackend.is_dgl():
  12. from autogl.module.model.dgl import BaseAutoModel
  13. from dgl.data import CoraGraphDataset
  14. elif DependentBackend.is_pyg():
  15. from torch_geometric.datasets import Planetoid
  16. from autogl.module.model.pyg import BaseAutoModel
  17. from autogl.datasets import build_dataset_from_name
  18. import torch
  19. import torch.nn.functional as F
  20. from autogl.module.nas.space.single_path import SinglePathNodeClassificationSpace
  21. from autogl.module.nas.space.graph_nas import GraphNasNodeClassificationSpace
  22. from autogl.module.nas.space.graph_nas_macro import GraphNasMacroNodeClassificationSpace
  23. from autogl.module.nas.estimator.one_shot import OneShotEstimator
  24. from autogl.module.nas.space.autoattend import AutoAttendNodeClassificationSpace
  25. from autogl.module.nas.backend import bk_feat, bk_label
  26. from autogl.module.nas.algorithm import Darts, RL, GraphNasRL, Enas, RandomSearch,Spos
  27. import numpy as np
  28. from autogl.solver.utils import set_seed
  29. set_seed(202106)
  30. def test_model(model, data=None, check_children=False):
  31. """
  32. Test model interface.
  33. Interface
  34. ---------
  35. - model.from_hyper_parameter()
  36. - model.model.forward()
  37. - model.to()
  38. - model.initialize()
  39. """
  40. assert isinstance(model, BaseAutoModel)
  41. assert hasattr(model, "to")
  42. assert hasattr(model, "initialize")
  43. model.initialize()
  44. model.to("cuda")
  45. if data is not None:
  46. data = data.to("cuda")
  47. assert hasattr(model, "model")
  48. __model = model.model
  49. assert isinstance(__model, torch.nn.Module)
  50. if data is not None:
  51. __model.forward(data)
  52. # FIXME: we can only perform tests when hyper_parameter_space is []
  53. if len(model.hyper_parameter_space) == 0:
  54. model_2 = model.from_hyper_parameter({})
  55. if check_children:
  56. test_model(model_2, data)
  57. if __name__ == "__main__":
  58. print("Testing backend: {}".format("dgl" if DependentBackend.is_dgl() else "pyg"))
  59. if DependentBackend.is_dgl():
  60. from autogl.datasets.utils.conversion._to_dgl_dataset import to_dgl_dataset as convert_dataset
  61. else:
  62. from autogl.datasets.utils.conversion._to_pyg_dataset import to_pyg_dataset as convert_dataset
  63. dataset = build_dataset_from_name('cora')
  64. dataset = convert_dataset(dataset)
  65. data = dataset[0]
  66. di = bk_feat(data).shape[1]
  67. do = len(np.unique(bk_label(data)))
  68. print("evolutionary + singlepath ")
  69. space=SinglePathNodeClassificationSpace().cuda()
  70. space.instantiate(input_dim=di,output_dim=do)
  71. esti=OneShotEstimator()
  72. algo=Spos(cycles=200)
  73. model = algo.search(space, dataset, esti)
  74. test_model(model, data, True)
  75. print("evolutionary + graphnas ")
  76. space=GraphNasNodeClassificationSpace().cuda()
  77. space.instantiate(input_dim=di,output_dim=do)
  78. esti=OneShotEstimator()
  79. algo=Spos(cycles=200)
  80. model = algo.search(space, dataset, esti)
  81. test_model(model, data, True)
  82. print("Random search + graphnas ")
  83. space = GraphNasNodeClassificationSpace().cuda()
  84. space.instantiate(input_dim=di, output_dim=do)
  85. esti = OneShotEstimator()
  86. algo = RandomSearch(num_epochs=100)
  87. model = algo.search(space, dataset, esti)
  88. test_model(model, data, True)
  89. print("Random search + AutoAttend ")
  90. space = AutoAttendNodeClassificationSpace().cuda()
  91. space.instantiate(input_dim=di, output_dim=do)
  92. esti = OneShotEstimator()
  93. algo = RandomSearch(num_epochs=10)
  94. model = algo.search(space, dataset, esti)
  95. print(model)
  96. test_model(model, data, True)
  97. print("rl + AutoAttend ")
  98. space = AutoAttendNodeClassificationSpace().cuda()
  99. space.instantiate(input_dim=di, output_dim=do)
  100. esti = OneShotEstimator()
  101. algo = RL(num_epochs=10)
  102. model = algo.search(space, dataset, esti)
  103. test_model(model, data, True)
  104. print("Random search + graphnas ")
  105. space = GraphNasNodeClassificationSpace().cuda()
  106. space.instantiate(input_dim=di, output_dim=do)
  107. esti = OneShotEstimator()
  108. algo = RandomSearch(num_epochs=10)
  109. model = algo.search(space, dataset, esti)
  110. test_model(model, data, True)
  111. print("rl + graphnas ")
  112. space = GraphNasNodeClassificationSpace().cuda()
  113. space.instantiate(input_dim=di, output_dim=do)
  114. esti = OneShotEstimator()
  115. algo = RL(num_epochs=10)
  116. model = algo.search(space, dataset, esti)
  117. test_model(model, data, True)
  118. print("graphnasrl + graphnas ")
  119. space = GraphNasNodeClassificationSpace().cuda()
  120. space.instantiate(input_dim=di, output_dim=do)
  121. esti = OneShotEstimator()
  122. algo = GraphNasRL(num_epochs=10)
  123. model = algo.search(space, dataset, esti)
  124. test_model(model, data, True)
  125. print("darts + graphnas ")
  126. space = GraphNasNodeClassificationSpace(con_ops=['concat']).cuda()
  127. space.instantiate(input_dim=di, output_dim=do)
  128. esti = OneShotEstimator()
  129. algo = Darts(num_epochs=10)
  130. model = algo.search(space, dataset, esti)
  131. test_model(model, data, True)
  132. print("darts + singlepath ")
  133. space = SinglePathNodeClassificationSpace().cuda()
  134. space.instantiate(input_dim=di, output_dim=do)
  135. esti = OneShotEstimator()
  136. algo = Darts(num_epochs=10)
  137. model = algo.search(space, dataset, esti)
  138. test_model(model, data, True)
  139. print("Random search + graphnas macro")
  140. space = GraphNasMacroNodeClassificationSpace().cuda()
  141. space.instantiate(input_dim=di, output_dim=do)
  142. esti = OneShotEstimator()
  143. algo = RandomSearch(num_epochs=10)
  144. model = algo.search(space, dataset, esti)
  145. test_model(model, data, True)
  146. print("RL + graphnas macro")
  147. space = GraphNasMacroNodeClassificationSpace().cuda()
  148. space.instantiate(input_dim=di, output_dim=do)
  149. esti = OneShotEstimator()
  150. algo = RL(num_epochs=10)
  151. model = algo.search(space, dataset, esti)
  152. test_model(model, data, True)