You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grna.py 2.8 kB

3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # "Adversarially Robust Neural Architecture Search for Graph Neural Networks"
  2. import logging
  3. import torch
  4. import torch.optim
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from . import register_nas_algo
  8. from .base import BaseNAS
  9. from ..estimator.base import BaseEstimator
  10. from ..space import BaseSpace
  11. from ..utils import replace_layer_choice, replace_input_choice
  12. from ...model.base import BaseAutoModel
  13. from .spos import Evolution, UniformSampler, Spos
  14. from ..utils import (
  15. AverageMeterGroup,
  16. replace_layer_choice,
  17. replace_input_choice,
  18. get_module_order,
  19. sort_replaced_module,
  20. PathSamplingLayerChoice,
  21. PathSamplingInputChoice,
  22. )
  23. from tqdm import tqdm, trange
  24. from torch.autograd import Variable
  25. import numpy as np
  26. import time
  27. import copy
  28. import torch.optim as optim
  29. import scipy.sparse as sp
  30. _logger = logging.getLogger(__name__)
  31. @register_nas_algo("grna")
  32. class GRNA(Spos):
  33. """
  34. GRNA trainer.
  35. Parameters
  36. ----------
  37. n_warmup : int
  38. Number of epochs for training super network.
  39. model_lr : float
  40. Learning rate for super network.
  41. model_wd : float
  42. Weight decay for super network.
  43. Other parameters see Evolution
  44. """
  def __init__(
      self,
      n_warmup=1000,
      grad_clip=5.0,
      disable_progress=False,
      optimize_mode='maximize',
      population_size=100,
      sample_size=25,
      cycles=20000,
      mutation_prob=0.05,
      device="cuda",
  ):
      """Create a GRNA trainer by forwarding every setting to ``Spos``.

      GRNA defines no constructor state of its own: all arguments are passed
      unchanged, positionally, to the parent ``Spos.__init__``.

      Parameters
      ----------
      n_warmup : int
          Number of epochs for training the super network.
      grad_clip : float
          Forwarded to ``Spos``; presumably a gradient-clipping threshold
          used during super-network training (confirm in ``Spos``).
      disable_progress : bool
          Forwarded to ``Spos``; ``_prepare`` later reads
          ``self.disable_progress`` when building ``Evolution``.
      optimize_mode : str
          Forwarded to ``Spos``.
      population_size, sample_size, cycles, mutation_prob
          Forwarded to ``Spos``; ``_prepare`` later reads the attributes of
          the same names when building ``Evolution``.
      device : str
          Forwarded to ``Spos``; ``_prepare`` moves the model to
          ``self.device``.

      Notes
      -----
      ``model_lr`` and ``model_wd`` are NOT constructor arguments here, even
      though ``_prepare`` reads ``self.model_lr`` / ``self.model_wd``.
      NOTE(review): presumably ``Spos.__init__`` sets them; confirm there.
      """
      # Positional order must match Spos.__init__ exactly; keep in sync.
      forwarded = (
          n_warmup,
          grad_clip,
          disable_progress,
          optimize_mode,
          population_size,
          sample_size,
          cycles,
          mutation_prob,
          device,
      )
      super().__init__(*forwarded)
  def _prepare(self):
      """Set up the super network, optimizer, sampler and evolution search.

      Steps:
      1. Record the module order of ``self.model``, swap its layer/input
         choices for path-sampling variants (collected into
         ``self.nas_modules``), then sort those modules by the recorded order.
      2. Move the model to ``self.device`` and create an Adam optimizer over
         its parameters using ``self.model_lr`` and ``self.model_wd``.
      3. Create a ``UniformSampler`` controller over ``self.nas_modules``.
      4. Create the ``Evolution`` searcher from the stored search settings.
      """
      # replace choice
      self.nas_modules = []
      # Capture the order before the replace_* calls below, which are called
      # for their side effects on self.model / self.nas_modules.
      k2o = get_module_order(self.model)
      replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
      replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
      self.nas_modules = sort_replaced_module(k2o, self.nas_modules)

      # to device
      self.model = self.model.to(self.device)
      self.model_optim = torch.optim.Adam(
          self.model.parameters(), lr=self.model_lr, weight_decay=self.model_wd
      )

      # controller
      self.controller = UniformSampler(self.nas_modules)

      # Evolution: honor the optimize_mode passed to the constructor instead
      # of a fixed 'maximize'. Fall back to 'maximize' (the constructor
      # default) in case the parent class did not store the attribute.
      self.evolve = Evolution(
          optimize_mode=getattr(self, "optimize_mode", "maximize"),
          population_size=self.population_size,
          sample_size=self.sample_size,
          cycles=self.cycles,
          mutation_prob=self.mutation_prob,
          disable_progress=self.disable_progress,
      )
  def _infer(self, mask="train"):
      """Evaluate the current model on the split selected by ``mask``.

      Delegates to ``self.estimator.infer(self.model, self.dataset, mask=mask)``,
      which must return a two-item result ``(metrics, loss)``.

      Returns
      -------
      tuple
          ``(metrics[0], loss)``: the first entry of the returned metric
          sequence, and the returned loss.
      """
      outcome = self.estimator.infer(self.model, self.dataset, mask=mask)
      metrics, loss = outcome
      first_metric = metrics[0]
      return first_metric, loss