node_classification_full.py

  1. """
  2. Node classification Full Trainer Implementation
  3. """
  4. from . import register_trainer
  5. from .base import BaseNodeClassificationTrainer, EarlyStopping
  6. import torch
  7. from torch.optim.lr_scheduler import (
  8. StepLR,
  9. MultiStepLR,
  10. ExponentialLR,
  11. ReduceLROnPlateau,
  12. )
  13. import torch.nn.functional as F
  14. from ..model import BaseEncoderMaintainer, BaseDecoderMaintainer, BaseAutoModel
  15. from .evaluation import Evaluation, get_feval, Logloss
  16. from typing import Callable, Iterable, Optional, Tuple, Type, Union
  17. from copy import deepcopy
  18. from ...utils import get_logger
  19. from ...backend import DependentBackend
  20. LOGGER = get_logger("node classification trainer")
@register_trainer("NodeClassificationFull")
class NodeClassificationFullTrainer(BaseNodeClassificationTrainer):
    """
    The node classification trainer.

    Used to automatically train on node classification problems.

    Parameters
    ----------
    model:
        The model can be ``str``, ``autogl.module.model.BaseAutoModel``,
        ``autogl.module.model.encoders.BaseEncoderMaintainer`` or a tuple of
        (encoder, decoder) if you need to specify both encoder and decoder.
        The encoder can be ``str`` or
        ``autogl.module.model.encoders.BaseEncoderMaintainer``, and the
        decoder can be ``str`` or
        ``autogl.module.model.decoders.BaseDecoderMaintainer``.
        If only an encoder is specified, the decoder defaults to "logsoftmax".
    num_features: ``int`` (Optional)
        The number of features in the dataset. Default ``None``.
    num_classes: ``int`` (Optional)
        The number of classes. Default ``None``.
    optimizer: ``Optimizer`` or ``str``
        The (name of the) optimizer used to train and predict.
        Default ``torch.optim.Adam``.
    lr: ``float``
        The learning rate of the node classification task. Default ``1e-4``.
    max_epoch: ``int``
        The maximum number of epochs in training. Default ``100``.
    early_stopping_round: ``int``
        The patience (in rounds) of early stopping. Default ``100``.
    weight_decay: ``float``
        The weight decay ratio. Default ``1e-4``.
    device: ``torch.device`` or ``str``
        The device on which the model will run.
    init: ``bool``
        Whether to initialize the model after construction. Default ``False``.
    feval: (Sequence of) ``Evaluation`` or ``str``
        The evaluation functions. Default ``[Logloss]``.
    loss: ``str``
        The loss used. Default ``nll_loss``.
    lr_scheduler_type: ``str`` (Optional)
        The learning rate scheduler type. Default ``None``.
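
    Examples
    --------
    A minimal construction sketch; the model name and the feature/class
    counts below are illustrative rather than tied to a specific dataset::

        trainer = NodeClassificationFullTrainer(
            model="gcn",
            num_features=1433,
            num_classes=7,
            lr=1e-2,
            max_epoch=200,
            early_stopping_round=30,
            init=True,
        )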
  59. """
  60. def __init__(
  61. self,
  62. model: Union[Tuple[BaseEncoderMaintainer, BaseDecoderMaintainer], BaseEncoderMaintainer, BaseAutoModel, str] = None,
  63. num_features: Optional[int] = None,
  64. num_classes: Optional[int] = None,
  65. optimizer: Union[str, Type[torch.optim.Optimizer]] = torch.optim.Adam,
  66. lr: float = 1e-4,
  67. max_epoch: int = 100,
  68. early_stopping_round: int = 100,
  69. weight_decay: float = 1e-4,
  70. device: Union[torch.device, str] = "auto",
  71. init: bool = False,
  72. feval: Iterable[Type[Evaluation]] =[Logloss],
  73. loss: Union[Callable, str] = "nll_loss",
  74. lr_scheduler_type: Optional[str] = None,
  75. **kwargs
  76. ):
  77. if isinstance(model, Tuple):
  78. encoder, decoder = model
  79. elif isinstance(model, BaseAutoModel):
  80. encoder, decoder = model, None
  81. else:
  82. encoder, decoder = model, "logsoftmax"
  83. super().__init__(
  84. encoder=encoder,
  85. decoder=decoder,
  86. num_features=num_features,
  87. num_classes=num_classes,
  88. device=device,
  89. feval=feval,
  90. loss=loss,
  91. )
  92. self.opt_received = optimizer
  93. if isinstance(optimizer, str):
  94. if optimizer.lower() == "adam": self.optimizer = torch.optim.Adam
  95. elif optimizer.lower() == "sgd": self.optimizer = torch.optim.SGD
  96. else: raise ValueError("Currently not support optimizer {}".format(optimizer))
  97. elif isinstance(optimizer, type) and issubclass(optimizer, torch.optim.Optimizer):
  98. self.optimizer = optimizer
  99. else:
  100. raise ValueError("Currently not support optimizer {}".format(optimizer))
  101. self.lr_scheduler_type = lr_scheduler_type
  102. self.lr = lr
  103. self.max_epoch = max_epoch
  104. self.early_stopping_round = early_stopping_round
  105. self.kwargs = kwargs
  106. self.weight_decay = weight_decay
  107. self.early_stopping = EarlyStopping(
  108. patience=early_stopping_round, verbose=False
  109. )
  110. self.valid_result = None
  111. self.valid_result_prob = None
  112. self.valid_score = None
  113. self.pyg_dgl = DependentBackend.get_backend_name()
  114. self.hyper_parameter_space = [
  115. {
  116. "parameterName": "max_epoch",
  117. "type": "INTEGER",
  118. "maxValue": 500,
  119. "minValue": 10,
  120. "scalingType": "LINEAR",
  121. },
  122. {
  123. "parameterName": "early_stopping_round",
  124. "type": "INTEGER",
  125. "maxValue": 30,
  126. "minValue": 10,
  127. "scalingType": "LINEAR",
  128. },
  129. {
  130. "parameterName": "lr",
  131. "type": "DOUBLE",
  132. "maxValue": 1e-1,
  133. "minValue": 1e-4,
  134. "scalingType": "LOG",
  135. },
  136. {
  137. "parameterName": "weight_decay",
  138. "type": "DOUBLE",
  139. "maxValue": 1e-2,
  140. "minValue": 1e-4,
  141. "scalingType": "LOG",
  142. },
  143. ]
  144. self.hyper_parameters = {
  145. "max_epoch": self.max_epoch,
  146. "early_stopping_round": self.early_stopping_round,
  147. "lr": self.lr,
  148. "weight_decay": self.weight_decay,
  149. }
  150. if init is True:
  151. self.initialize()

    @classmethod
    def get_task_name(cls):
        """
        Derive the task name. (NodeClassification)
        """
        return "NodeClassification"

    def __train_only(self, data, train_mask=None):
        data = data.to(self.device)
        model = self._compose_model()
        if train_mask is None:
            if self.pyg_dgl == 'pyg':
                mask = data.train_mask
            elif self.pyg_dgl == 'dgl':
                mask = data.ndata['train_mask']
        else:
            mask = train_mask
        optimizer = self.optimizer(
            model.parameters(),
            lr=self.lr, weight_decay=self.weight_decay
        )
        lr_scheduler_type = self.lr_scheduler_type
        if lr_scheduler_type == "steplr":
            scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
        elif lr_scheduler_type == "multisteplr":
            scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
        elif lr_scheduler_type == "exponentiallr":
            scheduler = ExponentialLR(optimizer, gamma=0.1)
        elif lr_scheduler_type == "reducelronplateau":
            scheduler = ReduceLROnPlateau(optimizer, "min")
        else:
            scheduler = None
        for epoch in range(1, self.max_epoch + 1):
            model.train()
            optimizer.zero_grad()
            res = model(data)
            if hasattr(F, self.loss):
                if self.pyg_dgl == 'pyg':
                    loss = getattr(F, self.loss)(res[mask], data.y[mask])
                elif self.pyg_dgl == 'dgl':
                    loss = getattr(F, self.loss)(res[mask], data.ndata['label'][mask])
            else:
                raise TypeError(
                    "PyTorch does not support loss type {}".format(self.loss)
                )
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                # ReduceLROnPlateau needs a metric to decide when to decay.
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss)
                else:
                    scheduler.step()
            # TODO: move this to autogl.backend.utils
            if self.pyg_dgl == 'pyg' and hasattr(data, "val_mask") and data.val_mask is not None:
                val_mask = data.val_mask
            elif self.pyg_dgl == 'dgl' and data.ndata.get('val_mask', None) is not None:
                val_mask = data.ndata['val_mask']
            else:
                val_mask = None
            if val_mask is not None:
                feval = self.feval[0] if isinstance(self.feval, list) else self.feval
                val_loss = self.evaluate([data], mask=val_mask, feval=feval)
                # Early stopping tracks a quantity to minimize, so flip the
                # sign of higher-is-better metrics.
                if feval.is_higher_better():
                    val_loss = -val_loss
                self.early_stopping(val_loss, model)
                if self.early_stopping.early_stop:
                    LOGGER.debug("Early stopping at %d", epoch)
                    break
        if self.pyg_dgl == 'pyg' and hasattr(data, "val_mask") and data.val_mask is not None:
            self.early_stopping.load_checkpoint(model)
        elif self.pyg_dgl == 'dgl' and data.ndata.get('val_mask', None) is not None:
            self.early_stopping.load_checkpoint(model)

    @torch.no_grad()
    def __predict_only(self, data, mask=None):
        if isinstance(mask, str):
            # Resolve a split name ("train" / "val" / "test") to a mask tensor.
            if self.pyg_dgl == 'pyg':
                mask = getattr(data, f'{mask}_mask')
            elif self.pyg_dgl == 'dgl':
                mask = data.ndata[f'{mask}_mask']
        model = self._compose_model()
        model.to(self.device)
        data = data.to(self.device)
        model.eval()
        res = model(data)
        return res if mask is None else res[mask]

    def train(self, dataset, keep_valid_result=True, train_mask=None):
        """
        Train on the given dataset.

        Parameters
        ----------
        dataset: The node classification dataset to train on.
        keep_valid_result: ``bool``
            Whether to save the validation result after training.
        train_mask: The mask for the training data.

        Returns
        -------
        self: ``autogl.train.NodeClassificationTrainer``
            A reference to the current trainer.
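
        Examples
        --------
        A typical call, assuming ``dataset`` is a node classification
        dataset whose first graph carries the usual split masks::

            trainer.train(dataset, keep_valid_result=True)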
  255. """
  256. data = dataset[0]
  257. self.__train_only(data, train_mask)
  258. if keep_valid_result:
  259. if self.pyg_dgl == 'pyg':
  260. val_mask = data.val_mask
  261. elif self.pyg_dgl == 'dgl':
  262. val_mask = data.ndata['val_mask']
  263. else:
  264. assert False
  265. self.valid_result = self.__predict_only(data)[val_mask].max(1)[1]
  266. self.valid_result_prob = self.__predict_only(data)[val_mask]
  267. self.valid_score = self.evaluate(
  268. dataset, mask=val_mask, feval=self.feval
  269. )

    def predict(self, dataset, mask=None):
        """
        Predict on the given dataset using the specified mask.

        Parameters
        ----------
        dataset: The node classification dataset to predict on.
        mask: ``train``, ``val``, or ``test``.
            The dataset mask.

        Returns
        -------
        The prediction result.
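
        Examples
        --------
        A sketch of predicting labels for the test split (``dataset`` is
        assumed to carry a ``test_mask``)::

            pred = trainer.predict(dataset, mask="test")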
  281. """
  282. return self.predict_proba(dataset, mask=mask, in_log_format=True).max(1)[1]
  283. def predict_proba(self, dataset, mask=None, in_log_format=False):
  284. """
  285. Predict the probability on the given dataset using specified mask.
  286. Parameters
  287. ----------
  288. dataset: The node classification dataset used to be predicted.
  289. mask: ``train``, ``val``, ``test``, or ``Tensor``.
  290. The dataset mask.
  291. in_log_format: ``bool``.
  292. If True(False), the probability will (not) be log format.
  293. Returns
  294. -------
  295. The prediction result.
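
        Examples
        --------
        A sketch of retrieving class probabilities for the test split;
        ``dataset`` is again assumed to carry a ``test_mask``::

            prob = trainer.predict_proba(dataset, mask="test")
            log_prob = trainer.predict_proba(dataset, mask="test", in_log_format=True)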
  296. """
  297. data = dataset[0]
  298. data = data.to(self.device)
  299. ret = self.__predict_only(data, mask)
  300. if in_log_format is True:
  301. return ret
  302. else:
  303. return torch.exp(ret)
  304. def get_valid_predict(self):
  305. # """Get the valid result."""
  306. return self.valid_result
  307. def get_valid_predict_proba(self):
  308. # """Get the valid result (prediction probability)."""
  309. return self.valid_result_prob

    def get_valid_score(self, return_major=True):
        """
        Get the validation score.

        Parameters
        ----------
        return_major: ``bool``.
            If ``True``, the return only consists of the major result.
            If ``False``, the return consists of all results.

        Returns
        -------
        result: The validation score obtained in the training stage.
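
        Examples
        --------
        A sketch of reading the major validation score; the second returned
        value says whether higher is better for that metric::

            score, higher_better = trainer.get_valid_score()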
  321. """
  322. if isinstance(self.feval, list):
  323. if return_major:
  324. return self.valid_score[0], self.feval[0].is_higher_better()
  325. else:
  326. return self.valid_score, [f.is_higher_better() for f in self.feval]
  327. else:
  328. return self.valid_score, self.feval.is_higher_better()
  329. def __repr__(self) -> str:
  330. import yaml
  331. return yaml.dump(
  332. {
  333. "trainer_name": self.__class__.__name__,
  334. "optimizer": self.optimizer,
  335. "learning_rate": self.lr,
  336. "max_epoch": self.max_epoch,
  337. "early_stopping_round": self.early_stopping_round,
  338. "encoder": repr(self.encoder),
  339. "decoder": repr(self.decoder)
  340. }
  341. )

    def evaluate(self, dataset, mask=None, feval=None):
        """
        Evaluate on the given dataset.

        Parameters
        ----------
        dataset: The node classification dataset to evaluate on.
        mask: ``train``, ``val``, or ``test``.
            The dataset mask.
        feval: ``str``.
            The evaluation method used in this function.

        Returns
        -------
        res: The evaluation result on the given dataset.
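
        Examples
        --------
        A sketch of computing accuracy on the validation split; ``"acc"``
        is assumed to be a registered evaluation name::

            val_acc = trainer.evaluate(dataset, mask="val", feval="acc")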
  355. """
  356. data = dataset[0]
  357. data = data.to(self.device)
  358. if isinstance(mask, str):
  359. if self.pyg_dgl == 'pyg':
  360. mask = getattr(data, f'{mask}_mask')
  361. elif self.pyg_dgl == 'dgl':
  362. mask = data.ndata[f'{mask}_mask']
  363. if self.pyg_dgl == 'pyg': label = data.y
  364. elif self.pyg_dgl == 'dgl': label = data.ndata['label']
  365. if feval is None:
  366. feval = self.feval
  367. else:
  368. feval = get_feval(feval)
  369. y_pred_prob = self.predict_proba(dataset, mask)
  370. y_true = label[mask] if mask is not None else label
  371. if not isinstance(feval, list):
  372. feval = [feval]
  373. return_signle = True
  374. else:
  375. return_signle = False
  376. res = []
  377. for f in feval:
  378. try:
  379. res.append(f.evaluate(y_pred_prob, y_true))
  380. except:
  381. res.append(f.evaluate(y_pred_prob.cpu().numpy(), y_true.cpu().numpy()))
  382. if return_signle:
  383. return res[0]
  384. return res

    def duplicate_from_hyper_parameter(self, hp: dict, model=None, restricted=True):
        """
        Duplicate a new instance from the given hyperparameters.

        Parameters
        ----------
        hp: ``dict``.
            The hyperparameters used in the new instance. Should contain the
            three keys "trainer", "encoder" and "decoder", with corresponding
            hyperparameters as values.
        model:
            The model can be ``str``, ``autogl.module.model.BaseAutoModel``,
            ``autogl.module.model.encoders.BaseEncoderMaintainer`` or a tuple
            of (encoder, decoder) if you need to specify both encoder and
            decoder. The encoder can be ``str`` or
            ``autogl.module.model.encoders.BaseEncoderMaintainer``, and the
            decoder can be ``str`` or
            ``autogl.module.model.decoders.BaseDecoderMaintainer``.
        restricted: ``bool``.
            If ``False``, the given trainer hyperparameters will be merged
            into the original ones; if ``True``, only the given
            hyperparameters are used.

        Returns
        -------
        self: ``autogl.train.NodeClassificationTrainer``
            A new instance of the trainer.
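
        Examples
        --------
        A sketch of the expected ``hp`` layout; the encoder/decoder keys
        depend on the search space of the underlying model and are purely
        illustrative here::

            new_trainer = trainer.duplicate_from_hyper_parameter({
                "trainer": {
                    "lr": 1e-2,
                    "max_epoch": 200,
                    "early_stopping_round": 20,
                    "weight_decay": 5e-4,
                },
                "encoder": {"num_layers": 2},
                "decoder": {},
            })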
  405. """
  406. if isinstance(model, Tuple):
  407. encoder, decoder = model
  408. elif isinstance(model, BaseAutoModel):
  409. encoder, decoder = model, None
  410. elif isinstance(model, BaseEncoderMaintainer):
  411. encoder, decoder = model, self.decoder
  412. elif model is None:
  413. encoder, decoder = self.encoder, self.decoder
  414. else:
  415. raise TypeError("Cannot parse model with type", type(model))
  416. hp_trainer = hp.get("trainer", {})
  417. hp_encoder = hp.get("encoder", {})
  418. hp_decoder = hp.get("decoder", {})
  419. if not restricted:
  420. origin_hp = deepcopy(self.hyper_parameters)
  421. origin_hp.update(hp_trainer)
  422. hp = origin_hp
  423. else:
  424. hp = hp_trainer
  425. encoder = encoder.from_hyper_parameter(hp_encoder)
  426. if isinstance(encoder, BaseEncoderMaintainer) and isinstance(decoder, BaseDecoderMaintainer):
  427. decoder = decoder.from_hyper_parameter_and_encoder(hp_decoder, encoder)
  428. ret = self.__class__(
  429. model=(encoder, decoder),
  430. num_features=self.num_features,
  431. num_classes=self.num_classes,
  432. optimizer=self.opt_received,
  433. lr=hp["lr"],
  434. max_epoch=hp["max_epoch"],
  435. early_stopping_round=hp["early_stopping_round"],
  436. device=self.device,
  437. weight_decay=hp["weight_decay"],
  438. feval=self.feval,
  439. loss=self.loss,
  440. lr_scheduler_type=self.lr_scheduler_type,
  441. init=True,
  442. **self.kwargs
  443. )
  444. return ret
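

# A minimal end-to-end sketch of how this trainer is typically driven.
# It is kept as a comment because this module is only importable as part of
# the ``autogl`` package; the dataset helper and the feature/class counts
# below follow AutoGL's public Cora example and are assumptions, not
# guarantees:
#
#     from autogl.datasets import build_dataset_from_name
#
#     dataset = build_dataset_from_name("cora")
#     trainer = NodeClassificationFullTrainer(
#         model="gcn", num_features=1433, num_classes=7, device="auto", init=True
#     )
#     trainer.train(dataset)
#     test_acc = trainer.evaluate(dataset, mask="test", feval="acc")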