.. _trainer_ssl:

AutoGL SSL Trainer
==================

The AutoGL project uses ``trainer`` to implement graph self-supervised
methods. Currently, we only support
`GraphCL <https://proceedings.neurips.cc/paper/2020/hash/3fe230348e9a12c13120749e3f9fa4cd-Abstract.html>`__
with semi-supervised downstream tasks:

- ``GraphCLSemisupervisedTrainer`` uses the GraphCL algorithm for
  semi-supervised downstream tasks. Its main interfaces are listed
  below (a combined usage sketch follows the list):

  - ``train(self, dataset, keep_valid_result=True)``: train on the
    given dataset and optionally keep the validation results.

    - ``dataset``: the graph dataset to train on.
    - ``keep_valid_result``: if ``True``, save the validation result
      after training. Only if ``keep_valid_result`` is ``True`` can the
      methods ``get_valid_score``, ``get_valid_predict_proba`` and
      ``get_valid_predict`` output meaningful results after training.

  - ``predict(self, dataset, mask="test")``: predict on the given
    dataset.

    - ``dataset``: the graph dataset to predict on.
    - ``mask``: ``"train"``, ``"val"`` or ``"test"``, the dataset mask.

  - ``predict_proba(self, dataset, mask="test", in_log_format=False)``:
    predict class probabilities on the given dataset.

    - ``dataset``: the graph dataset to predict on.
    - ``mask``: ``"train"``, ``"val"`` or ``"test"``, the dataset mask.
    - ``in_log_format``: if ``True``, the probabilities are returned in
      log format.

  - ``evaluate(self, dataset, mask="val", feval=None)``: evaluate the
    model on the given dataset.

    - ``dataset``: the graph dataset to evaluate on.
    - ``mask``: ``"train"``, ``"val"`` or ``"test"``, the dataset mask.
    - ``feval``: the evaluation method used by this function. If
      ``feval`` is ``None``, the ``feval`` given at initialization is
      used.

  - ``get_valid_score(self, return_major=True)``: get the validation
    scores after training.

    - ``return_major``: if ``True``, return only the major result.

  - ``get_valid_predict_proba(self)``: get the prediction probabilities
    on the validation set after training.
  - ``get_valid_predict(self)``: get the validation predictions after
    training.
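
As a quick illustration, these interfaces chain together as in the
following minimal sketch (it assumes ``trainer`` is an already
initialized ``GraphCLSemisupervisedTrainer`` and ``dataset`` has
train/val/test splits; the full setup is shown in the sections below):

.. code:: python

   trainer.train(dataset, keep_valid_result=True)
   val_score = trainer.get_valid_score()                # major validation metric
   proba = trainer.predict_proba(dataset, mask="test")  # class probabilities
   pred = trainer.predict(dataset, mask="test")         # hard predictions
   result = trainer.evaluate(dataset, mask="test")      # list of metric values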

Lazy Initialization
-------------------

For reasons similar to :ref:`model`, we also use lazy initialization
for all trainers. Only (part of) the hyper-parameters are set when
``__init__()`` is called. The ``trainer`` gets its core ``model`` only
after ``initialize()`` is explicitly called, which is done
automatically in ``solver`` and in ``duplicate_from_hyper_parameter()``
after all the hyper-parameters are set properly.

For example, if you want to use ``gcn`` as the encoder, a simple
``mlp`` as the decoder, and an ``mlp`` classifier to solve a graph
classification problem, there are three steps you need to take.

- First, import everything you need:

  .. code:: python

     from autogl.module.train.ssl import GraphCLSemisupervisedTrainer
     from autogl.datasets import build_dataset_from_name, utils
     from autogl.datasets.utils.conversion import to_pyg_dataset as convert_dataset

- Secondly, set up the hyper-parameters of the encoder, the decoder and
  the classifier:

  .. code:: python

     trainer_hp = {
         'batch_size': 128,
         'p_lr': 0.0001,                 # learning rate of the pretraining stage
         'p_weight_decay': 0,            # weight decay of the pretraining stage
         'p_epoch': 100,                 # max epochs of the pretraining stage
         'p_early_stopping_round': 100,  # early stopping round of the pretraining stage
         'f_lr': 0.0001,                 # learning rate of the fine-tuning stage
         'f_weight_decay': 0,            # weight decay of the fine-tuning stage
         'f_epoch': 100,                 # max epochs of the fine-tuning stage
         'f_early_stopping_round': 100,  # early stopping round of the fine-tuning stage
     }

     encoder_hp = {
         'num_layers': 3,
         'hidden': [64, 128],  # hidden dimensions; the final layer's dimension need not be set
         'dropout': 0.5,
         'act': 'relu',
         'eps': 'false'
     }

     decoder_hp = {
         'hidden': 64,
         'act': 'relu',
         'dropout': 0.5
     }

     prediction_head_hp = {
         'hidden': 64,
         'act': 'relu',
         'dropout': 0.5
     }

- Thirdly, use ``duplicate_from_hyper_parameter()``:

  .. code:: python

     dataset = build_dataset_from_name('proteins')
     dataset = convert_dataset(dataset)
     # split the dataset
     utils.graph_random_splits(dataset, train_ratio=0.1, val_ratio=0.1, seed=2022)

     # generate a trainer; it cannot be used before you
     # call `duplicate_from_hyper_parameter`
     trainer = GraphCLSemisupervisedTrainer(
         model=('gcn', 'sumpoolmlp'),
         prediction_model_head='sumpoolmlp',
         views_fn=['random2', 'random2'],
         num_features=dataset[0].x.size(1),
         num_classes=max([data.y.item() for data in dataset]) + 1,
         z_dim=128,  # the embedding dimension
         init=False
     )

     # call duplicate_from_hyper_parameter to set the model
     # architecture and the learning hyper-parameters
     trainer.initialize()
     trainer = trainer.duplicate_from_hyper_parameter(
         {
             'trainer': trainer_hp,
             'encoder': encoder_hp,
             'decoder': decoder_hp,
             'prediction_head': prediction_head_hp
         }
     )

Train and Predict
-----------------

After initializing a trainer, you can train it on the given datasets.

We provide training and testing functions for graph classification
tasks. You can also create your own tasks following patterns similar
to ours.

We provide some interfaces that you can easily use to train or test on
the given datasets.

- Training: ``train()``

  .. code:: python

     trainer.train(dataset, keep_valid_result=False)

  ``train()`` trains on the given dataset and optionally keeps the
  validation results.

  It has two parameters. The first parameter is ``dataset``, the graph
  dataset to train on. The second parameter is ``keep_valid_result``, a
  bool value: if ``True``, the trainer saves the validation result
  after training, provided the dataset has a validation set.

- Testing: ``predict()``

  .. code:: python

     trainer.predict(dataset, 'test').detach().cpu().numpy()

  ``predict()`` predicts on the given dataset and returns the
  prediction results.

  It has two parameters. The first parameter is ``dataset``, the graph
  dataset to predict on. The second parameter is ``mask``, a string
  that can be ``'train'``, ``'val'`` or ``'test'``.
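
  ``predict_proba()`` follows the same pattern but returns class
  probabilities; a minimal sketch (whether the result is a tensor or a
  NumPy array may depend on your AutoGL version):

  .. code:: python

     # probabilities over classes for the test split; pass
     # in_log_format=True to get log-probabilities instead
     proba = trainer.predict_proba(dataset, mask='test', in_log_format=False)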

- Evaluation: ``evaluate()``

  .. code:: python

     # returns a list of metrics; the default metric is accuracy
     result = trainer.evaluate(dataset, 'test')

  ``evaluate()`` evaluates the model on the given dataset.

  It has three parameters. The first parameter is ``dataset``, the
  graph dataset to evaluate on. The second parameter is ``mask``, a
  string that can be ``'train'``, ``'val'`` or ``'test'``. The last
  parameter is ``feval``, which can be a string, a tuple of strings or
  ``None``, and denotes the evaluation methods to use, such as
  ``Acc``.

You can also write your own evaluation metrics and methods. Here is a
simple example:

.. code:: python

   import numpy as np
   from sklearn.metrics import accuracy_score
   from autogl.module.train.evaluation import Evaluation, register_evaluate

   # use register_evaluate; you can then refer to this class
   # by its registered name 'my_acc'
   @register_evaluate("my_acc")
   class MyAcc(Evaluation):
       @staticmethod
       def get_eval_name():
           '''
           define the name; it does not need to match the registered name
           '''
           return "my_acc"

       @staticmethod
       def is_higher_better():
           '''
           return whether a higher value of this metric is better (bool)
           '''
           return True

       @staticmethod
       def evaluate(predict, label):
           '''
           return the evaluation result (float)
           '''
           if len(predict.shape) == 2:
               # multi-class: take the argmax over class probabilities
               predict = np.argmax(predict, axis=1)
           else:
               # binary: threshold the probabilities at 0.5
               predict = [1 if p > 0.5 else 0 for p in predict]
           return accuracy_score(label, predict)
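
Once registered, the new metric should be usable anywhere a ``feval``
is accepted, referred to by its registered name (a sketch, reusing the
``trainer`` and ``dataset`` from the sections above and assuming
registered names are looked up wherever ``feval`` strings are
resolved):

.. code:: python

   # evaluate with the custom metric registered above
   result = trainer.evaluate(dataset, mask='val', feval='my_acc')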

Implement SSL Trainer
---------------------

Next, we will show how to implement your own SSL trainer. Implementing
a trainer is harder than using one: you need to implement three main
functions, ``_train_only()``, ``_predict_only()`` and
``duplicate_from_hyper_parameter()``. We will now implement GraphCL
with unsupervised downstream tasks step by step.

- Initialize your trainer

  First, we need to import some classes and methods, define a basic
  ``__init__()`` method, and register our trainer.

  .. code:: python

     import multiprocessing as mp
     from typing import Tuple

     import torch
     from torch.optim.lr_scheduler import StepLR

     from autogl.module.train import register_trainer
     from autogl.module.train.ssl.base import BaseContrastiveTrainer
     from autogl.module.model import BaseAutoModel
     from autogl.datasets import utils

     @register_trainer("GraphCLUnsupervisedTrainer")
     class GraphCLUnsupervisedTrainer(BaseContrastiveTrainer):
         def __init__(
             self,
             model,
             prediction_model_head,
             num_features,
             num_classes,
             num_graph_features,
             device,
             feval,
             views_fn,
             aug_ratio,
             z_dim,
             num_workers,
             batch_size,
             eval_interval,
             init,
             *args,
             **kwargs,
         ):
             # set up the encoder and decoder
             if isinstance(model, Tuple):
                 encoder, decoder = model
             elif isinstance(model, BaseAutoModel):
                 raise ValueError(
                     "The GraphCL trainer needs both an encoder and a decoder, "
                     "so `model` shouldn't be an instance of `BaseAutoModel`"
                 )
             else:
                 encoder, decoder = model, "sumpoolmlp"
             self.eval_interval = eval_interval
             # initialize the contrastive learning machinery
             super().__init__(
                 encoder=encoder,
                 decoder=decoder,
                 decoder_node=None,
                 num_features=num_features,
                 num_graph_features=num_graph_features,
                 views_fn=views_fn,
                 graph_level=True,   # has graph-level features
                 node_level=False,   # has no node-level features
                 device=device,
                 feval=feval,
                 z_dim=z_dim,        # the dimension of the embedding output by the encoder
                 z_node_dim=None,
                 *args,
                 **kwargs,
             )
             # initialize anything specific to your own method
             self.views_fn = views_fn
             self.aug_ratio = aug_ratio
             self._prediction_model_head = None
             self.num_classes = num_classes
             self.prediction_model_head = prediction_model_head
             self.batch_size = batch_size
             self.num_workers = num_workers
             if self.num_workers > 0:
                 mp.set_start_method("fork", force=True)
             # set up the hyper-parameters at initialization; the p_*/f_*
             # attributes are assumed to be set from kwargs by the base class
             self.hyper_parameters = {
                 "batch_size": self.batch_size,
                 "p_epoch": self.p_epoch,
                 "p_early_stopping_round": self.p_early_stopping_round,
                 "p_lr": self.p_lr,
                 "p_weight_decay": self.p_weight_decay,
                 "f_epoch": self.f_epoch,
                 "f_early_stopping_round": self.f_early_stopping_round,
                 "f_lr": self.f_lr,
                 "f_weight_decay": self.f_weight_decay,
             }
             self.args = args
             self.kwargs = kwargs
             if init:
                 self.initialize()
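
  As a quick check, the new trainer can be constructed in the same
  style as the built-in ``GraphCLSemisupervisedTrainer`` shown earlier
  (a sketch; it assumes ``dataset`` is the converted ``proteins``
  dataset from above, and values such as ``feval='acc'``,
  ``num_graph_features=0`` and ``aug_ratio=0.2`` are illustrative
  assumptions):

  .. code:: python

     trainer = GraphCLUnsupervisedTrainer(
         model=('gcn', 'sumpoolmlp'),
         prediction_model_head='sumpoolmlp',
         num_features=dataset[0].x.size(1),
         num_classes=max([data.y.item() for data in dataset]) + 1,
         num_graph_features=0,   # assumption: no extra graph-level features
         device='cuda' if torch.cuda.is_available() else 'cpu',
         feval='acc',
         views_fn=['random2', 'random2'],
         aug_ratio=0.2,          # assumption: a typical GraphCL augmentation ratio
         z_dim=128,
         num_workers=0,
         batch_size=128,
         eval_interval=10,       # fine-tune every 10 pretraining epochs
         init=False,
     )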

- ``_train_only(self, dataset)``

  In this method, the trainer trains the model on the given dataset.
  You can define several different methods for different training
  stages; a sketch assembling them into ``_train_only()`` follows the
  fragments below.

  - Set the model on the specified device:

    .. code:: python

       def _set_model_device(self, dataset):
           self.encoder.encoder.to(self.device)
           self.decoder.decoder.to(self.device)

  - For pretraining, you can simply call
    ``super()._train_pretraining_only(dataset, per_epoch)`` to train
    the encoder:

    .. code:: python

       for i, epoch in enumerate(super()._train_pretraining_only(dataset, per_epoch=True)):
           # you can define your own training process if you want;
           # for example, here we fine-tune every `eval_interval` epochs
           if (i + 1) % self.eval_interval == 0:
               # fine-tuning: get the data loaders
               train_loader = utils.graph_get_split(dataset, "train", batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
               val_loader = utils.graph_get_split(dataset, "val", batch_size=self.batch_size, num_workers=self.num_workers)
               # set up the model: freeze the encoder and
               # fine-tune only the prediction head
               self.encoder.encoder.eval()
               self.prediction_model_head.initialize(self.encoder)
               model = self.prediction_model_head.decoder
               # set up the optimizer and scheduler
               optimizer = self.f_optimizer(model.parameters(), lr=self.f_lr, weight_decay=self.f_weight_decay)
               scheduler = self._get_scheduler('finetune', optimizer)
               for epoch in range(self.f_epoch):
                   model.train()
                   for data in train_loader:
                       optimizer.zero_grad()
                       data = data.to(self.device)
                       embeds = self.encoder.encoder(data)
                       out = model(embeds, data)
                       loss = self.f_loss(out, data.y)
                       loss.backward()
                       optimizer.step()
                   if self.f_lr_scheduler_type:
                       scheduler.step()
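
  A minimal sketch that ties these fragments together (``_finetune`` is
  a hypothetical helper wrapping the fine-tuning loop shown above):

  .. code:: python

     def _train_only(self, dataset):
         # move the encoder and decoder to the target device first
         self._set_model_device(dataset)
         # run pretraining epoch by epoch, fine-tuning periodically
         for i, epoch in enumerate(super()._train_pretraining_only(dataset, per_epoch=True)):
             if (i + 1) % self.eval_interval == 0:
                 self._finetune(dataset)  # hypothetical helper: the fine-tuning loop above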

- To complete the model, we also need to implement the
  ``_predict_only()`` function, which is used to evaluate the model.

  .. code:: python

     def _predict_only(self, loader, return_label=False):
         # compose the full model (encoder + prediction head)
         model = self._compose_model()
         model.eval()
         pred = []
         label = []
         for data in loader:
             data = data.to(self.device)
             out = model(data)
             pred.append(out)
             label.append(data.y)
         ret = torch.cat(pred, 0)
         label = torch.cat(label, 0)
         if return_label:
             return ret, label
         else:
             return ret
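
  ``_predict_only()`` consumes a data loader, so the caller is expected
  to build one first. A sketch of the calling side (not part of the
  required interface), reusing ``utils.graph_get_split`` from the
  training code:

  .. code:: python

     # inside the trainer, e.g. when serving predict(dataset, "test")
     loader = utils.graph_get_split(dataset, "test", batch_size=self.batch_size,
                                    num_workers=self.num_workers)
     pred = self._predict_only(loader)                  # raw model outputs
     proba = torch.nn.functional.softmax(pred, dim=1)   # probabilities, if needed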

- ``duplicate_from_hyper_parameter()`` is a method that re-generates
  the trainer from a given set of hyper-parameters. However, if you
  don't want to use a solver to search for good hyper-parameters
  automatically, you don't actually need to implement it.

  .. code:: python

     from copy import deepcopy
     # BaseEncoderMaintainer and BaseDecoderMaintainer are assumed to be
     # importable from autogl.module.model (adjust to your AutoGL version)

     def duplicate_from_hyper_parameter(self, hp, encoder="same", decoder="same", prediction_head="same", restricted=True):
         # split the hyper-parameter dict into its four parts
         hp_trainer = hp.get("trainer", {})
         hp_encoder = hp.get("encoder", {})
         hp_decoder = hp.get("decoder", {})
         hp_phead = hp.get("prediction_head", {})
         if not restricted:
             # merge the given trainer hyper-parameters into the current ones
             origin_hp = deepcopy(self.hyper_parameters)
             origin_hp.update(hp_trainer)
             hp = origin_hp
         else:
             hp = hp_trainer
         # reuse the current components unless new ones are given
         encoder = encoder if encoder != "same" else self.encoder
         decoder = decoder if decoder != "same" else self.decoder
         prediction_head = prediction_head if prediction_head != "same" else self.prediction_model_head
         encoder = encoder.from_hyper_parameter(hp_encoder)
         decoder.output_dimension = tuple(encoder.get_output_dimensions())[-1]
         if isinstance(encoder, BaseEncoderMaintainer) and isinstance(decoder, BaseDecoderMaintainer):
             decoder = decoder.from_hyper_parameter_and_encoder(hp_decoder, encoder)
         if isinstance(encoder, BaseEncoderMaintainer) and isinstance(prediction_head, BaseDecoderMaintainer):
             prediction_head = prediction_head.from_hyper_parameter_and_encoder(hp_phead, encoder)
         ret = self.__class__(
             model=(encoder, decoder),
             prediction_model_head=prediction_head,
             num_features=self.num_features,
             num_classes=self.num_classes,
             num_graph_features=self.num_graph_features,
             device=self.device,
             feval=self.feval,
             loss=self.loss,
             f_loss=self.f_loss,
             views_fn=self.views_fn_opt,
             aug_ratio=self.aug_ratio,
             z_dim=self.last_dim,
             neg_by_crpt=self.neg_by_crpt,
             tau=self.tau,
             model_path=self.model_path,
             num_workers=self.num_workers,
             batch_size=hp["batch_size"],
             eval_interval=self.eval_interval,
             p_optim=self.p_opt_received,
             p_lr=hp["p_lr"],
             p_lr_scheduler_type=self.p_lr_scheduler_type,
             p_epoch=hp["p_epoch"],
             p_early_stopping_round=hp["p_early_stopping_round"],
             p_weight_decay=hp["p_weight_decay"],
             f_optim=self.f_opt_received,
             f_lr=hp["f_lr"],
             f_lr_scheduler_type=self.f_lr_scheduler_type,
             f_epoch=hp["f_epoch"],
             f_early_stopping_round=hp["f_early_stopping_round"],
             f_weight_decay=hp["f_weight_decay"],
             init=True,
             *self.args,
             **self.kwargs
         )
         return ret
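
With the three methods in place, the custom trainer can be driven just
like the built-in one from the earlier sections (a sketch, reusing
``dataset`` and the ``trainer_hp``, ``encoder_hp``, ``decoder_hp`` and
``prediction_head_hp`` dictionaries defined above):

.. code:: python

   trainer.initialize()
   trainer = trainer.duplicate_from_hyper_parameter(
       {
           'trainer': trainer_hp,
           'encoder': encoder_hp,
           'decoder': decoder_hp,
           'prediction_head': prediction_head_hp,
       }
   )
   trainer.train(dataset, keep_valid_result=True)
   print(trainer.evaluate(dataset, mask='test'))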