You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_classification_full.py 17 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. """
  2. Node classification Full Trainer Implementation
  3. """
  4. from . import register_trainer
  5. from .base import BaseNodeClassificationTrainer, EarlyStopping
  6. import torch
  7. from torch.optim.lr_scheduler import (
  8. StepLR,
  9. MultiStepLR,
  10. ExponentialLR,
  11. ReduceLROnPlateau,
  12. )
  13. import torch.nn.functional as F
  14. from ..model import MODEL_DICT, BaseModel
  15. from .evaluation import get_feval, Logloss
  16. from typing import Union
  17. from copy import deepcopy
  18. from ...utils import get_logger
  19. from ...backend import DependentBackend
# Module-level logger used by the trainer below (e.g. to report early stopping).
LOGGER = get_logger("node classification trainer")
  21. @register_trainer("NodeClassificationFull")
  22. class NodeClassificationFullTrainer(BaseNodeClassificationTrainer):
  23. """
  24. The node classification trainer.
  25. Used to automatically train the node classification problem.
  26. Parameters
  27. ----------
  28. model: ``BaseModel`` or ``str``
  29. The (name of) model used to train and predict.
  30. optimizer: ``Optimizer`` of ``str``
  31. The (name of) optimizer used to train and predict.
  32. lr: ``float``
  33. The learning rate of node classification task.
  34. max_epoch: ``int``
  35. The max number of epochs in training.
  36. early_stopping_round: ``int``
  37. The round of early stop.
  38. device: ``torch.device`` or ``str``
  39. The device where model will be running on.
  40. init: ``bool``
  41. If True(False), the model will (not) be initialized.
  42. """
  43. def __init__(
  44. self,
  45. model: Union[BaseModel, str] = None,
  46. num_features=None,
  47. num_classes=None,
  48. optimizer=None,
  49. lr=None,
  50. max_epoch=None,
  51. early_stopping_round=None,
  52. weight_decay=1e-4,
  53. device="auto",
  54. init=True,
  55. feval=[Logloss],
  56. loss="nll_loss",
  57. lr_scheduler_type=None,
  58. *args,
  59. **kwargs
  60. ):
  61. super().__init__(
  62. model,
  63. num_features,
  64. num_classes,
  65. device=device,
  66. init=init,
  67. feval=feval,
  68. loss=loss,
  69. )
  70. self.opt_received = optimizer
  71. if type(optimizer) == str and optimizer.lower() == "adam":
  72. self.optimizer = torch.optim.Adam
  73. elif type(optimizer) == str and optimizer.lower() == "sgd":
  74. self.optimizer = torch.optim.SGD
  75. else:
  76. self.optimizer = torch.optim.Adam
  77. self.lr_scheduler_type = lr_scheduler_type
  78. self.lr = lr if lr is not None else 1e-4
  79. self.max_epoch = max_epoch if max_epoch is not None else 100
  80. self.early_stopping_round = (
  81. early_stopping_round if early_stopping_round is not None else 100
  82. )
  83. self.args = args
  84. self.kwargs = kwargs
  85. self.feval = get_feval(feval)
  86. self.weight_decay = weight_decay
  87. self.early_stopping = EarlyStopping(
  88. patience=early_stopping_round, verbose=False
  89. )
  90. self.valid_result = None
  91. self.valid_result_prob = None
  92. self.valid_score = None
  93. self.initialized = False
  94. self.pyg_dgl = DependentBackend.get_backend_name()
  95. self.space = [
  96. {
  97. "parameterName": "max_epoch",
  98. "type": "INTEGER",
  99. "maxValue": 500,
  100. "minValue": 10,
  101. "scalingType": "LINEAR",
  102. },
  103. {
  104. "parameterName": "early_stopping_round",
  105. "type": "INTEGER",
  106. "maxValue": 30,
  107. "minValue": 10,
  108. "scalingType": "LINEAR",
  109. },
  110. {
  111. "parameterName": "lr",
  112. "type": "DOUBLE",
  113. "maxValue": 1e-1,
  114. "minValue": 1e-4,
  115. "scalingType": "LOG",
  116. },
  117. {
  118. "parameterName": "weight_decay",
  119. "type": "DOUBLE",
  120. "maxValue": 1e-2,
  121. "minValue": 1e-4,
  122. "scalingType": "LOG",
  123. },
  124. ]
  125. self.hyperparams = {
  126. "max_epoch": self.max_epoch,
  127. "early_stopping_round": self.early_stopping_round,
  128. "lr": self.lr,
  129. "weight_decay": self.weight_decay,
  130. }
  131. if init is True:
  132. self.initialize()
  133. def initialize(self):
  134. # Initialize the auto model in trainer.
  135. if self.initialized is True:
  136. return
  137. self.initialized = True
  138. self.model.initialize()
  139. def get_model(self):
  140. # Get auto model used in trainer.
  141. return self.model
  142. @classmethod
  143. def get_task_name(cls):
  144. # Get task name, i.e., `NodeClassification`.
  145. return "NodeClassification"
  146. def train_only(self, data, train_mask=None):
  147. """
  148. The function of training on the given dataset and mask.
  149. Parameters
  150. ----------
  151. data: The node classification dataset used to be trained. It should consist of masks, including train_mask, and etc.
  152. train_mask: The mask used in training stage.
  153. Returns
  154. -------
  155. self: ``autogl.train.NodeClassificationTrainer``
  156. A reference of current trainer.
  157. """
  158. data = data.to(self.device)
  159. if train_mask is None:
  160. if self.pyg_dgl == 'pyg':
  161. mask = data.train_mask
  162. elif self.pyg_dgl == 'dgl':
  163. mask = data.ndata['train_mask']
  164. else:
  165. mask = train_mask
  166. optimizer = self.optimizer(
  167. self.model.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
  168. )
  169. # scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
  170. lr_scheduler_type = self.lr_scheduler_type
  171. if type(lr_scheduler_type) == str and lr_scheduler_type == "steplr":
  172. scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
  173. elif type(lr_scheduler_type) == str and lr_scheduler_type == "multisteplr":
  174. scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
  175. elif type(lr_scheduler_type) == str and lr_scheduler_type == "exponentiallr":
  176. scheduler = ExponentialLR(optimizer, gamma=0.1)
  177. elif (
  178. type(lr_scheduler_type) == str and lr_scheduler_type == "reducelronplateau"
  179. ):
  180. scheduler = ReduceLROnPlateau(optimizer, "min")
  181. else:
  182. scheduler = None
  183. for epoch in range(1, self.max_epoch):
  184. self.model.model.train()
  185. optimizer.zero_grad()
  186. if hasattr(self.model.model, 'cls_forward'):
  187. res = self.model.model.cls_forward(data)
  188. else:
  189. res = self.model.model.forward(data)
  190. if hasattr(F, self.loss):
  191. if self.pyg_dgl == 'pyg':
  192. loss = getattr(F, self.loss)(res[mask], data.y[mask])
  193. elif self.pyg_dgl == 'dgl':
  194. loss = getattr(F, self.loss)(res[mask], data.ndata['label'][mask])
  195. else:
  196. raise TypeError(
  197. "PyTorch does not support loss type {}".format(self.loss)
  198. )
  199. loss.backward()
  200. optimizer.step()
  201. if self.lr_scheduler_type:
  202. scheduler.step()
  203. if self.pyg_dgl == 'pyg' and hasattr(data, "val_mask") and data.val_mask is not None:
  204. val_mask = data.val_mask
  205. elif self.pyg_dgl == 'dgl' and data.ndata.get('val_mask', None) is not None:
  206. val_mask = data.ndata['val_mask']
  207. else:
  208. val_mask = None
  209. if val_mask is not None:
  210. if type(self.feval) is list:
  211. feval = self.feval[0]
  212. else:
  213. feval = self.feval
  214. val_loss = self.evaluate([data], mask=val_mask, feval=feval)
  215. if feval.is_higher_better() is True:
  216. val_loss = -val_loss
  217. self.early_stopping(val_loss, self.model.model)
  218. if self.early_stopping.early_stop:
  219. LOGGER.debug("Early stopping at %d", epoch)
  220. break
  221. if hasattr(data, "val_mask") and data.val_mask is not None:
  222. self.early_stopping.load_checkpoint(self.model.model)
  223. def predict_only(self, data, mask=None):
  224. """
  225. The function of predicting on the given dataset and mask.
  226. Parameters
  227. ----------
  228. data: The node classification dataset used to be predicted.
  229. train_mask: The mask used in training stage.
  230. Returns
  231. -------
  232. res: The result of predicting on the given dataset.
  233. """
  234. if isinstance(mask, str):
  235. if self.pyg_dgl == 'pyg':
  236. mask = getattr(data, f'{mask}_mask')
  237. elif self.pyg_dgl == 'dgl':
  238. mask = data.ndata[f'{mask}_mask']
  239. data = data.to(self.device)
  240. self.model.model.eval()
  241. with torch.no_grad():
  242. if hasattr(self.model.model, 'cls_forward'):
  243. res = self.model.model.cls_forward(data)
  244. else:
  245. res = self.model.model.forward(data)
  246. if mask is None:
  247. return res
  248. else:
  249. return res[mask]
  250. def train(self, dataset, keep_valid_result=True, train_mask=None):
  251. """
  252. The function of training on the given dataset and keeping valid result.
  253. Parameters
  254. ----------
  255. dataset: The node classification dataset used to be trained.
  256. keep_valid_result: ``bool``
  257. If True(False), save the validation result after training.
  258. train_mask: The mask for training data
  259. Returns
  260. -------
  261. self: ``autogl.train.NodeClassificationTrainer``
  262. A reference of current trainer.
  263. """
  264. data = dataset[0]
  265. self.train_only(data, train_mask)
  266. if keep_valid_result:
  267. if self.pyg_dgl == 'pyg':
  268. val_mask = data.val_mask
  269. elif self.pyg_dgl == 'dgl':
  270. val_mask = data.ndata['val_mask']
  271. else:
  272. assert False
  273. self.valid_result = self.predict_only(data)[val_mask].max(1)[1]
  274. self.valid_result_prob = self.predict_only(data)[val_mask]
  275. self.valid_score = self.evaluate(
  276. dataset, mask=val_mask, feval=self.feval
  277. )
  278. # print(self.valid_score)
  279. def predict(self, dataset, mask=None):
  280. """
  281. The function of predicting on the given dataset.
  282. Parameters
  283. ----------
  284. dataset: The node classification dataset used to be predicted.
  285. mask: ``train``, ``val``, or ``test``.
  286. The dataset mask.
  287. Returns
  288. -------
  289. The prediction result of ``predict_proba``.
  290. """
  291. return self.predict_proba(dataset, mask=mask, in_log_format=True).max(1)[1]
  292. def predict_proba(self, dataset, mask=None, in_log_format=False):
  293. """
  294. The function of predicting the probability on the given dataset.
  295. Parameters
  296. ----------
  297. dataset: The node classification dataset used to be predicted.
  298. mask: ``train``, ``val``, ``test``, or ``Tensor``.
  299. The dataset mask.
  300. in_log_format: ``bool``.
  301. If True(False), the probability will (not) be log format.
  302. Returns
  303. -------
  304. The prediction result.
  305. """
  306. data = dataset[0]
  307. data = data.to(self.device)
  308. ret = self.predict_only(data, mask)
  309. if in_log_format is True:
  310. return ret
  311. else:
  312. return torch.exp(ret)
  313. def get_valid_predict(self):
  314. # """Get the valid result."""
  315. return self.valid_result
  316. def get_valid_predict_proba(self):
  317. # """Get the valid result (prediction probability)."""
  318. return self.valid_result_prob
  319. def get_valid_score(self, return_major=True):
  320. """
  321. The function of getting the valid score.
  322. Parameters
  323. ----------
  324. return_major: ``bool``.
  325. If True, the return only consists of the major result.
  326. If False, the return consists of the all results.
  327. Returns
  328. -------
  329. result: The valid score in training stage.
  330. """
  331. if isinstance(self.feval, list):
  332. if return_major:
  333. return self.valid_score[0], self.feval[0].is_higher_better()
  334. else:
  335. return self.valid_score, [f.is_higher_better() for f in self.feval]
  336. else:
  337. return self.valid_score, self.feval.is_higher_better()
  338. def __repr__(self) -> str:
  339. import yaml
  340. return yaml.dump(
  341. {
  342. "trainer_name": self.__class__.__name__,
  343. "optimizer": self.optimizer,
  344. "learning_rate": self.lr,
  345. "max_epoch": self.max_epoch,
  346. "early_stopping_round": self.early_stopping_round,
  347. "model": repr(self.model),
  348. }
  349. )
  350. def evaluate(self, dataset, mask=None, feval=None):
  351. """
  352. The function of training on the given dataset and keeping valid result.
  353. Parameters
  354. ----------
  355. dataset: The node classification dataset used to be evaluated.
  356. mask: ``train``, ``val``, or ``test``.
  357. The dataset mask.
  358. feval: ``str``.
  359. The evaluation method used in this function.
  360. Returns
  361. -------
  362. res: The evaluation result on the given dataset.
  363. """
  364. data = dataset[0]
  365. data = data.to(self.device)
  366. if isinstance(mask, str):
  367. if self.pyg_dgl == 'pyg':
  368. mask = getattr(data, f'{mask}_mask')
  369. elif self.pyg_dgl == 'dgl':
  370. mask = data.ndata[f'{mask}_mask']
  371. if self.pyg_dgl == 'pyg': label = data.y
  372. elif self.pyg_dgl == 'dgl': label = data.ndata['label']
  373. if feval is None:
  374. feval = self.feval
  375. else:
  376. feval = get_feval(feval)
  377. y_pred_prob = self.predict_proba(dataset, mask)
  378. y_true = label[mask] if mask is not None else label
  379. if not isinstance(feval, list):
  380. feval = [feval]
  381. return_signle = True
  382. else:
  383. return_signle = False
  384. res = []
  385. for f in feval:
  386. try:
  387. res.append(f.evaluate(y_pred_prob, y_true))
  388. except:
  389. res.append(f.evaluate(y_pred_prob.cpu().numpy(), y_true.cpu().numpy()))
  390. if return_signle:
  391. return res[0]
  392. return res
  393. def to(self, new_device):
  394. assert isinstance(new_device, torch.device)
  395. self.device = new_device
  396. if self.model is not None:
  397. self.model.to(self.device)
  398. def duplicate_from_hyper_parameter(self, hp: dict, model=None, restricted=True):
  399. """
  400. The function of duplicating a new instance from the given hyperparameter.
  401. Parameters
  402. ----------
  403. hp: ``dict``.
  404. The hyperparameter used in the new instance.
  405. model: The model used in the new instance of trainer.
  406. restricted: ``bool``.
  407. If False(True), the hyperparameter should (not) be updated from origin hyperparameter.
  408. Returns
  409. -------
  410. self: ``autogl.train.NodeClassificationTrainer``
  411. A new instance of trainer.
  412. """
  413. if not restricted:
  414. origin_hp = deepcopy(self.hyperparams)
  415. origin_hp.update(hp)
  416. hp = origin_hp
  417. if model is None:
  418. model = self.model
  419. model = model.from_hyper_parameter(
  420. dict(
  421. [
  422. x
  423. for x in hp.items()
  424. if x[0] in [y["parameterName"] for y in model.space]
  425. ]
  426. )
  427. )
  428. ret = self.__class__(
  429. model=model,
  430. num_features=self.num_features,
  431. num_classes=self.num_classes,
  432. optimizer=self.opt_received,
  433. lr=hp["lr"],
  434. max_epoch=hp["max_epoch"],
  435. early_stopping_round=hp["early_stopping_round"],
  436. device=self.device,
  437. weight_decay=hp["weight_decay"],
  438. feval=self.feval,
  439. loss=self.loss,
  440. lr_scheduler_type=self.lr_scheduler_type,
  441. init=True,
  442. *self.args,
  443. **self.kwargs
  444. )
  445. return ret
  446. @property
  447. def hyper_parameter_space(self):
  448. # """Get the space of hyperparameter."""
  449. return self.space
  450. @hyper_parameter_space.setter
  451. def hyper_parameter_space(self, space):
  452. # """Set the space of hyperparameter."""
  453. self.space = space
  454. def get_hyper_parameter(self):
  455. # """Get the hyperparameter in this trainer."""
  456. return self.hyperparams