You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

node_classification.py 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping
  2. import torch
  3. from torch.optim.lr_scheduler import (
  4. StepLR,
  5. MultiStepLR,
  6. ExponentialLR,
  7. ReduceLROnPlateau,
  8. )
  9. import torch.nn.functional as F
  10. from ..model import MODEL_DICT, BaseModel
  11. from .evaluate import Logloss, Acc, Auc
  12. from typing import Union
  13. from copy import deepcopy
  14. from ...utils import get_logger
# Module-level logger shared by everything in this file; it is obtained from
# the project's ``get_logger`` helper under the name "node classification trainer".
LOGGER = get_logger("node classification trainer")
  16. def get_feval(feval):
  17. if isinstance(feval, str):
  18. return EVALUATE_DICT[feval]
  19. if isinstance(feval, type) and issubclass(feval, Evaluation):
  20. return feval
  21. if isinstance(feval, list):
  22. return [get_feval(f) for f in feval]
  23. raise ValueError("feval argument of type", type(feval), "is not supported!")
  24. @register_trainer("NodeClassification")
  25. class NodeClassificationTrainer(BaseTrainer):
  26. """
  27. The node classification trainer.
  28. Used to automatically train the node classification problem.
  29. Parameters
  30. ----------
  31. model: ``BaseModel`` or ``str``
  32. The (name of) model used to train and predict.
  33. optimizer: ``Optimizer`` of ``str``
  34. The (name of) optimizer used to train and predict.
  35. lr: ``float``
  36. The learning rate of node classification task.
  37. max_epoch: ``int``
  38. The max number of epochs in training.
  39. early_stopping_round: ``int``
  40. The round of early stop.
  41. device: ``torch.device`` or ``str``
  42. The device where model will be running on.
  43. init: ``bool``
  44. If True(False), the model will (not) be initialized.
  45. """
  46. space = None
  47. def __init__(
  48. self,
  49. model: Union[BaseModel, str],
  50. num_features,
  51. num_classes,
  52. optimizer=None,
  53. lr=None,
  54. max_epoch=None,
  55. early_stopping_round=None,
  56. weight_decay=1e-4,
  57. device=None,
  58. init=True,
  59. feval=[Logloss],
  60. loss="nll_loss",
  61. lr_scheduler_type=None,
  62. *args,
  63. **kwargs
  64. ):
  65. super(NodeClassificationTrainer, self).__init__(model)
  66. self.loss_type = loss
  67. if device is None:
  68. device = "cpu"
  69. # init model
  70. if isinstance(model, str):
  71. assert model in MODEL_DICT, "Cannot parse model name " + model
  72. self.model = MODEL_DICT[model](num_features, num_classes, device, init=init)
  73. elif isinstance(model, BaseModel):
  74. self.model = model
  75. self.opt_received = optimizer
  76. if type(optimizer) == str and optimizer.lower() == "adam":
  77. self.optimizer = torch.optim.Adam
  78. elif type(optimizer) == str and optimizer.lower() == "sgd":
  79. self.optimizer = torch.optim.SGD
  80. else:
  81. self.optimizer = torch.optim.Adam
  82. self.lr_scheduler_type = lr_scheduler_type
  83. self.num_features = num_features
  84. self.num_classes = num_classes
  85. self.lr = lr if lr is not None else 1e-4
  86. self.max_epoch = max_epoch if max_epoch is not None else 100
  87. self.early_stopping_round = (
  88. early_stopping_round if early_stopping_round is not None else 100
  89. )
  90. self.device = device
  91. self.args = args
  92. self.kwargs = kwargs
  93. self.feval = get_feval(feval)
  94. self.weight_decay = weight_decay
  95. self.early_stopping = EarlyStopping(
  96. patience=early_stopping_round, verbose=False
  97. )
  98. self.valid_result = None
  99. self.valid_result_prob = None
  100. self.valid_score = None
  101. self.initialized = False
  102. self.num_features = num_features
  103. self.num_classes = num_classes
  104. self.device = device
  105. self.space = [
  106. {
  107. "parameterName": "max_epoch",
  108. "type": "INTEGER",
  109. "maxValue": 500,
  110. "minValue": 10,
  111. "scalingType": "LINEAR",
  112. },
  113. {
  114. "parameterName": "early_stopping_round",
  115. "type": "INTEGER",
  116. "maxValue": 30,
  117. "minValue": 10,
  118. "scalingType": "LINEAR",
  119. },
  120. {
  121. "parameterName": "lr",
  122. "type": "DOUBLE",
  123. "maxValue": 1e-1,
  124. "minValue": 1e-4,
  125. "scalingType": "LOG",
  126. },
  127. {
  128. "parameterName": "weight_decay",
  129. "type": "DOUBLE",
  130. "maxValue": 1e-2,
  131. "minValue": 1e-4,
  132. "scalingType": "LOG",
  133. },
  134. ]
  135. # self.space += self.model.space
  136. NodeClassificationTrainer.space = self.space
  137. self.hyperparams = {
  138. "max_epoch": self.max_epoch,
  139. "early_stopping_round": self.early_stopping_round,
  140. "lr": self.lr,
  141. "weight_decay": self.weight_decay,
  142. }
  143. self.hyperparams = {**self.hyperparams, **self.model.get_hyper_parameter()}
  144. if init is True:
  145. self.initialize()
  146. def initialize(self):
  147. # Initialize the auto model in trainer.
  148. if self.initialized is True:
  149. return
  150. self.initialized = True
  151. self.model.initialize()
  152. def get_model(self):
  153. # Get auto model used in trainer.
  154. return self.model
  155. @classmethod
  156. def get_task_name(cls):
  157. # Get task name, i.e., `NodeClassification`.
  158. return "NodeClassification"
  159. def train_only(self, data, train_mask=None):
  160. """
  161. The function of training on the given dataset and mask.
  162. Parameters
  163. ----------
  164. data: The node classification dataset used to be trained. It should consist of masks, including train_mask, and etc.
  165. train_mask: The mask used in training stage.
  166. Returns
  167. -------
  168. self: ``autogl.train.NodeClassificationTrainer``
  169. A reference of current trainer.
  170. """
  171. data = data.to(self.device)
  172. mask = data.train_mask if train_mask is None else train_mask
  173. optimizer = self.optimizer(
  174. self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
  175. )
  176. # scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
  177. lr_scheduler_type = self.lr_scheduler_type
  178. if type(lr_scheduler_type) == str and lr_scheduler_type == "steplr":
  179. scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
  180. elif type(lr_scheduler_type) == str and lr_scheduler_type == "multisteplr":
  181. scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
  182. elif type(lr_scheduler_type) == str and lr_scheduler_type == "exponentiallr":
  183. scheduler = ExponentialLR(optimizer, gamma=0.1)
  184. elif (
  185. type(lr_scheduler_type) == str and lr_scheduler_type == "reducelronplateau"
  186. ):
  187. scheduler = ReduceLROnPlateau(optimizer, "min")
  188. else:
  189. scheduler = None
  190. for epoch in range(1, self.max_epoch):
  191. self.model.model.train()
  192. optimizer.zero_grad()
  193. res = self.model.model.forward(data)
  194. if hasattr(F, self.loss_type):
  195. loss = getattr(F, self.loss_type)(res[mask], data.y[mask])
  196. else:
  197. raise TypeError(
  198. "PyTorch does not support loss type {}".format(self.loss_type)
  199. )
  200. loss.backward()
  201. optimizer.step()
  202. if self.lr_scheduler_type:
  203. scheduler.step()
  204. if type(self.feval) is list:
  205. feval = self.feval[0]
  206. else:
  207. feval = self.feval
  208. val_loss = self.evaluate([data], mask=data.val_mask, feval=feval)
  209. if feval.is_higher_better() is True:
  210. val_loss = -val_loss
  211. self.early_stopping(val_loss, self.model.model)
  212. if self.early_stopping.early_stop:
  213. LOGGER.debug("Early stopping at %d", epoch)
  214. break
  215. self.early_stopping.load_checkpoint(self.model.model)
  216. def predict_only(self, data, test_mask=None):
  217. """
  218. The function of predicting on the given dataset and mask.
  219. Parameters
  220. ----------
  221. data: The node classification dataset used to be predicted.
  222. train_mask: The mask used in training stage.
  223. Returns
  224. -------
  225. res: The result of predicting on the given dataset.
  226. """
  227. # mask = data.test_mask if test_mask is None else test_mask
  228. data = data.to(self.device)
  229. self.model.model.eval()
  230. with torch.no_grad():
  231. res = self.model.model.forward(data)
  232. return res
  233. def train(self, dataset, keep_valid_result=True):
  234. """
  235. The function of training on the given dataset and keeping valid result.
  236. Parameters
  237. ----------
  238. dataset: The node classification dataset used to be trained.
  239. keep_valid_result: ``bool``
  240. If True(False), save the validation result after training.
  241. Returns
  242. -------
  243. self: ``autogl.train.NodeClassificationTrainer``
  244. A reference of current trainer.
  245. """
  246. data = dataset[0]
  247. self.train_only(data)
  248. if keep_valid_result:
  249. self.valid_result = self.predict_only(data)[data.val_mask].max(1)[1]
  250. self.valid_result_prob = self.predict_only(data)[data.val_mask]
  251. self.valid_score = self.evaluate(
  252. dataset, mask=data.val_mask, feval=self.feval
  253. )
  254. def predict(self, dataset, mask=None):
  255. """
  256. The function of predicting on the given dataset.
  257. Parameters
  258. ----------
  259. dataset: The node classification dataset used to be predicted.
  260. mask: ``train``, ``val``, or ``test``.
  261. The dataset mask.
  262. Returns
  263. -------
  264. The prediction result of ``predict_proba``.
  265. """
  266. return self.predict_proba(dataset, mask=mask, in_log_format=True).max(1)[1]
  267. def predict_proba(self, dataset, mask=None, in_log_format=False):
  268. """
  269. The function of predicting the probability on the given dataset.
  270. Parameters
  271. ----------
  272. dataset: The node classification dataset used to be predicted.
  273. mask: ``train``, ``val``, or ``test``.
  274. The dataset mask.
  275. in_log_format: ``bool``.
  276. If True(False), the probability will (not) be log format.
  277. Returns
  278. -------
  279. The prediction result.
  280. """
  281. data = dataset[0]
  282. data = data.to(self.device)
  283. if mask is not None:
  284. if mask == "val":
  285. mask = data.val_mask
  286. elif mask == "test":
  287. mask = data.test_mask
  288. elif mask == "train":
  289. mask = data.train_mask
  290. else:
  291. mask = data.test_mask
  292. ret = self.predict_only(data, mask)[mask]
  293. if in_log_format is True:
  294. return ret
  295. else:
  296. return torch.exp(ret)
  297. def get_valid_predict(self):
  298. # """Get the valid result."""
  299. return self.valid_result
  300. def get_valid_predict_proba(self):
  301. # """Get the valid result (prediction probability)."""
  302. return self.valid_result_prob
  303. def get_valid_score(self, return_major=True):
  304. """
  305. The function of getting the valid score.
  306. Parameters
  307. ----------
  308. return_major: ``bool``.
  309. If True, the return only consists of the major result.
  310. If False, the return consists of the all results.
  311. Returns
  312. -------
  313. result: The valid score in training stage.
  314. """
  315. if isinstance(self.feval, list):
  316. if return_major:
  317. return self.valid_score[0], self.feval[0].is_higher_better()
  318. else:
  319. return self.valid_score, [f.is_higher_better() for f in self.feval]
  320. else:
  321. return self.valid_score, self.feval.is_higher_better()
  322. def get_name_with_hp(self):
  323. # """Get the name of hyperparameter."""
  324. name = "-".join(
  325. [
  326. str(self.optimizer),
  327. str(self.lr),
  328. str(self.max_epoch),
  329. str(self.early_stopping_round),
  330. str(self.model),
  331. str(self.device),
  332. ]
  333. )
  334. name = (
  335. name
  336. + "|"
  337. + "-".join(
  338. [
  339. str(x[0]) + "-" + str(x[1])
  340. for x in self.model.get_hyper_parameter().items()
  341. ]
  342. )
  343. )
  344. return name
  345. def evaluate(self, dataset, mask=None, feval=None):
  346. """
  347. The function of training on the given dataset and keeping valid result.
  348. Parameters
  349. ----------
  350. dataset: The node classification dataset used to be evaluated.
  351. mask: ``train``, ``val``, or ``test``.
  352. The dataset mask.
  353. feval: ``str``.
  354. The evaluation method used in this function.
  355. Returns
  356. -------
  357. res: The evaluation result on the given dataset.
  358. """
  359. data = dataset[0]
  360. data = data.to(self.device)
  361. test_mask = mask
  362. if feval is None:
  363. feval = self.feval
  364. else:
  365. feval = get_feval(feval)
  366. if test_mask is None:
  367. test_mask = data.test_mask
  368. elif test_mask == "test":
  369. test_mask = data.test_mask
  370. elif test_mask == "val":
  371. test_mask = data.val_mask
  372. elif test_mask == "train":
  373. test_mask = data.train_mask
  374. y_pred_prob = self.predict_proba(dataset, mask)
  375. y_pred = y_pred_prob.max(1)[1]
  376. y_true = data.y[test_mask]
  377. if not isinstance(feval, list):
  378. feval = [feval]
  379. return_signle = True
  380. else:
  381. return_signle = False
  382. res = []
  383. for f in feval:
  384. try:
  385. res.append(f.evaluate(y_pred_prob, y_true))
  386. except:
  387. res.append(f.evaluate(y_pred_prob.cpu().numpy(), y_true.cpu().numpy()))
  388. if return_signle:
  389. return res[0]
  390. return res
  391. def to(self, new_device):
  392. assert isinstance(new_device, torch.device)
  393. self.device = new_device
  394. if self.model is not None:
  395. self.model.to(self.device)
  396. def duplicate_from_hyper_parameter(self, hp: dict, model=None, restricted=True):
  397. """
  398. The function of duplicating a new instance from the given hyperparameter.
  399. Parameters
  400. ----------
  401. hp: ``dict``.
  402. The hyperparameter used in the new instance.
  403. model: The model used in the new instance of trainer.
  404. restricted: ``bool``.
  405. If False(True), the hyperparameter should (not) be updated from origin hyperparameter.
  406. Returns
  407. -------
  408. self: ``autogl.train.NodeClassificationTrainer``
  409. A new instance of trainer.
  410. """
  411. if not restricted:
  412. origin_hp = deepcopy(self.hyperparams)
  413. origin_hp.update(hp)
  414. hp = origin_hp
  415. if model is None:
  416. model = self.model
  417. model = model.from_hyper_parameter(
  418. dict(
  419. [
  420. x
  421. for x in hp.items()
  422. if x[0] in [y["parameterName"] for y in model.space]
  423. ]
  424. )
  425. )
  426. ret = self.__class__(
  427. model=model,
  428. num_features=self.num_features,
  429. num_classes=self.num_classes,
  430. optimizer=self.opt_received,
  431. lr=hp["lr"],
  432. max_epoch=hp["max_epoch"],
  433. early_stopping_round=hp["early_stopping_round"],
  434. device=self.device,
  435. weight_decay=hp["weight_decay"],
  436. feval=self.feval,
  437. loss=self.loss_type,
  438. lr_scheduler_type=self.lr_scheduler_type,
  439. init=True,
  440. *self.args,
  441. **self.kwargs
  442. )
  443. return ret
  444. def set_feval(self, feval):
  445. # """Set the evaluation metrics."""
  446. self.feval = get_feval(feval)
  447. @property
  448. def hyper_parameter_space(self):
  449. # """Get the space of hyperparameter."""
  450. return self.space
  451. @hyper_parameter_space.setter
  452. def hyper_parameter_space(self, space):
  453. # """Set the space of hyperparameter."""
  454. self.space = space
  455. NodeClassificationTrainer.space = space
  456. def get_hyper_parameter(self):
  457. # """Get the hyperparameter in this trainer."""
  458. return self.hyperparams