You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

node_classification.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping
  2. import torch
  3. from torch.optim.lr_scheduler import StepLR
  4. import torch.nn.functional as F
  5. from ..model import MODEL_DICT, BaseModel
  6. from .evaluate import Logloss, Acc, Auc
  7. from typing import Union
  8. from copy import deepcopy
  9. from ...utils import get_logger
  10. LOGGER = get_logger("node classification trainer")
  11. def get_feval(feval):
  12. if isinstance(feval, str):
  13. return EVALUATE_DICT[feval]
  14. if isinstance(feval, type) and issubclass(feval, Evaluation):
  15. return feval
  16. if isinstance(feval, list):
  17. return [get_feval(f) for f in feval]
  18. raise ValueError("feval argument of type", type(feval), "is not supported!")
  19. @register_trainer("NodeClassification")
  20. class NodeClassificationTrainer(BaseTrainer):
  21. """
  22. The node classification trainer.
  23. Used to automatically train the node classification problem.
  24. Parameters
  25. ----------
  26. model: ``BaseModel`` or ``str``
  27. The (name of) model used to train and predict.
  28. optimizer: ``Optimizer`` of ``str``
  29. The (name of) optimizer used to train and predict.
  30. lr: ``float``
  31. The learning rate of node classification task.
  32. max_epoch: ``int``
  33. The max number of epochs in training.
  34. early_stopping_round: ``int``
  35. The round of early stop.
  36. device: ``torch.device`` or ``str``
  37. The device where model will be running on.
  38. init: ``bool``
  39. If True(False), the model will (not) be initialized.
  40. """
  41. space = None
  42. def __init__(
  43. self,
  44. model: Union[BaseModel, str],
  45. num_features,
  46. num_classes,
  47. optimizer=None,
  48. lr=None,
  49. max_epoch=None,
  50. early_stopping_round=None,
  51. weight_decay=1e-4,
  52. device=None,
  53. init=True,
  54. feval=[Logloss],
  55. loss="nll_loss",
  56. *args,
  57. **kwargs
  58. ):
  59. super(NodeClassificationTrainer, self).__init__(model)
  60. self.loss_type = loss
  61. if device is None:
  62. device = "cpu"
  63. # init model
  64. if isinstance(model, str):
  65. assert model in MODEL_DICT, "Cannot parse model name " + model
  66. self.model = MODEL_DICT[model](num_features, num_classes, device, init=init)
  67. elif isinstance(model, BaseModel):
  68. self.model = model
  69. if type(optimizer) == str and optimizer.lower() == "adam":
  70. self.optimizer = torch.optim.Adam
  71. elif type(optimizer) == str and optimizer.lower() == "sgd":
  72. self.optimizer = torch.optim.SGD
  73. else:
  74. self.optimizer = torch.optim.Adam
  75. self.num_features = num_features
  76. self.num_classes = num_classes
  77. self.lr = lr if lr is not None else 1e-4
  78. self.max_epoch = max_epoch if max_epoch is not None else 100
  79. self.early_stopping_round = (
  80. early_stopping_round if early_stopping_round is not None else 100
  81. )
  82. self.device = device
  83. self.args = args
  84. self.kwargs = kwargs
  85. self.feval = get_feval(feval)
  86. self.weight_decay = weight_decay
  87. self.early_stopping = EarlyStopping(
  88. patience=early_stopping_round, verbose=False
  89. )
  90. self.valid_result = None
  91. self.valid_result_prob = None
  92. self.valid_score = None
  93. self.initialized = False
  94. self.num_features = num_features
  95. self.num_classes = num_classes
  96. self.device = device
  97. self.space = [
  98. {
  99. "parameterName": "max_epoch",
  100. "type": "INTEGER",
  101. "maxValue": 500,
  102. "minValue": 10,
  103. "scalingType": "LINEAR",
  104. },
  105. {
  106. "parameterName": "early_stopping_round",
  107. "type": "INTEGER",
  108. "maxValue": 30,
  109. "minValue": 10,
  110. "scalingType": "LINEAR",
  111. },
  112. {
  113. "parameterName": "lr",
  114. "type": "DOUBLE",
  115. "maxValue": 1e-1,
  116. "minValue": 1e-4,
  117. "scalingType": "LOG",
  118. },
  119. {
  120. "parameterName": "weight_decay",
  121. "type": "DOUBLE",
  122. "maxValue": 1e-2,
  123. "minValue": 1e-4,
  124. "scalingType": "LOG",
  125. },
  126. ]
  127. self.space += self.model.space
  128. NodeClassificationTrainer.space = self.space
  129. self.hyperparams = {
  130. "max_epoch": self.max_epoch,
  131. "early_stopping_round": self.early_stopping_round,
  132. "lr": self.lr,
  133. "weight_decay": self.weight_decay,
  134. }
  135. self.hyperparams = {**self.hyperparams, **self.model.get_hyper_parameter()}
  136. if init is True:
  137. self.initialize()
  138. def initialize(self):
  139. # Initialize the auto model in trainer.
  140. if self.initialized is True:
  141. return
  142. self.initialized = True
  143. self.model.initialize()
  144. def get_model(self):
  145. # Get auto model used in trainer.
  146. return self.model
  147. @classmethod
  148. def get_task_name(cls):
  149. # Get task name, i.e., `NodeClassification`.
  150. return "NodeClassification"
  151. def train_only(self, data, train_mask=None):
  152. """
  153. The function of training on the given dataset and mask.
  154. Parameters
  155. ----------
  156. data: The node classification dataset used to be trained. It should consist of masks, including train_mask, and etc.
  157. train_mask: The mask used in training stage.
  158. Returns
  159. -------
  160. self: ``autogl.train.NodeClassificationTrainer``
  161. A reference of current trainer.
  162. """
  163. data = data.to(self.device)
  164. mask = data.train_mask if train_mask is None else train_mask
  165. optimizer = self.optimizer(
  166. self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
  167. )
  168. scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
  169. for epoch in range(1, self.max_epoch):
  170. self.model.model.train()
  171. optimizer.zero_grad()
  172. res = self.model.model.forward(data)
  173. if hasattr(F, self.loss_type):
  174. loss = getattr(F, self.loss_type)(res[mask], data.y[mask])
  175. else:
  176. raise TypeError("PyTorch does not support loss type {}".format(self.loss_type))
  177. loss.backward()
  178. optimizer.step()
  179. scheduler.step()
  180. if type(self.feval) is list:
  181. feval = self.feval[0]
  182. else:
  183. feval = self.feval
  184. val_loss = self.evaluate([data], mask=data.val_mask, feval=feval)
  185. if feval.is_higher_better() is True:
  186. val_loss = -val_loss
  187. self.early_stopping(val_loss, self.model.model)
  188. if self.early_stopping.early_stop:
  189. LOGGER.debug("Early stopping at %d", epoch)
  190. self.early_stopping.load_checkpoint(self.model.model)
  191. break
  192. def predict_only(self, data, test_mask=None):
  193. """
  194. The function of predicting on the given dataset and mask.
  195. Parameters
  196. ----------
  197. data: The node classification dataset used to be predicted.
  198. train_mask: The mask used in training stage.
  199. Returns
  200. -------
  201. res: The result of predicting on the given dataset.
  202. """
  203. # mask = data.test_mask if test_mask is None else test_mask
  204. data = data.to(self.device)
  205. self.model.model.eval()
  206. with torch.no_grad():
  207. res = self.model.model.forward(data)
  208. return res
  209. def train(self, dataset, keep_valid_result=True):
  210. """
  211. The function of training on the given dataset and keeping valid result.
  212. Parameters
  213. ----------
  214. dataset: The node classification dataset used to be trained.
  215. keep_valid_result: ``bool``
  216. If True(False), save the validation result after training.
  217. Returns
  218. -------
  219. self: ``autogl.train.NodeClassificationTrainer``
  220. A reference of current trainer.
  221. """
  222. data = dataset[0]
  223. self.train_only(data)
  224. if keep_valid_result:
  225. self.valid_result = self.predict_only(data)[data.val_mask].max(1)[1]
  226. self.valid_result_prob = self.predict_only(data)[data.val_mask]
  227. self.valid_score = self.evaluate(
  228. dataset, mask=data.val_mask, feval=self.feval
  229. )
  230. def predict(self, dataset, mask=None):
  231. """
  232. The function of predicting on the given dataset.
  233. Parameters
  234. ----------
  235. dataset: The node classification dataset used to be predicted.
  236. mask: ``train``, ``val``, or ``test``.
  237. The dataset mask.
  238. Returns
  239. -------
  240. The prediction result of ``predict_proba``.
  241. """
  242. return self.predict_proba(dataset, mask=mask, in_log_format=True).max(1)[1]
  243. def predict_proba(self, dataset, mask=None, in_log_format=False):
  244. """
  245. The function of predicting the probability on the given dataset.
  246. Parameters
  247. ----------
  248. dataset: The node classification dataset used to be predicted.
  249. mask: ``train``, ``val``, or ``test``.
  250. The dataset mask.
  251. in_log_format: ``bool``.
  252. If True(False), the probability will (not) be log format.
  253. Returns
  254. -------
  255. The prediction result.
  256. """
  257. data = dataset[0]
  258. data = data.to(self.device)
  259. if mask is not None:
  260. if mask == "val":
  261. mask = data.val_mask
  262. elif mask == "test":
  263. mask = data.test_mask
  264. elif mask == "train":
  265. mask = data.train_mask
  266. else:
  267. mask = data.test_mask
  268. ret = self.predict_only(data, mask)[mask]
  269. if in_log_format is True:
  270. return ret
  271. else:
  272. return torch.exp(ret)
  273. def get_valid_predict(self):
  274. # """Get the valid result."""
  275. return self.valid_result
  276. def get_valid_predict_proba(self):
  277. # """Get the valid result (prediction probability)."""
  278. return self.valid_result_prob
  279. def get_valid_score(self, return_major=True):
  280. """
  281. The function of getting the valid score.
  282. Parameters
  283. ----------
  284. return_major: ``bool``.
  285. If True, the return only consists of the major result.
  286. If False, the return consists of the all results.
  287. Returns
  288. -------
  289. result: The valid score in training stage.
  290. """
  291. if isinstance(self.feval, list):
  292. if return_major:
  293. return self.valid_score[0], self.feval[0].is_higher_better()
  294. else:
  295. return self.valid_score, [f.is_higher_better() for f in self.feval]
  296. else:
  297. return self.valid_score, self.feval.is_higher_better()
  298. def get_name_with_hp(self):
  299. # """Get the name of hyperparameter."""
  300. name = "-".join(
  301. [
  302. str(self.optimizer),
  303. str(self.lr),
  304. str(self.max_epoch),
  305. str(self.early_stopping_round),
  306. str(self.model),
  307. str(self.device),
  308. ]
  309. )
  310. name = (
  311. name
  312. + "|"
  313. + "-".join(
  314. [
  315. str(x[0]) + "-" + str(x[1])
  316. for x in self.model.get_hyper_parameter().items()
  317. ]
  318. )
  319. )
  320. return name
  321. def evaluate(self, dataset, mask=None, feval=None):
  322. """
  323. The function of training on the given dataset and keeping valid result.
  324. Parameters
  325. ----------
  326. dataset: The node classification dataset used to be evaluated.
  327. mask: ``train``, ``val``, or ``test``.
  328. The dataset mask.
  329. feval: ``str``.
  330. The evaluation method used in this function.
  331. Returns
  332. -------
  333. res: The evaluation result on the given dataset.
  334. """
  335. data = dataset[0]
  336. data = data.to(self.device)
  337. test_mask = mask
  338. if feval is None:
  339. feval = self.feval
  340. else:
  341. feval = get_feval(feval)
  342. if test_mask is None:
  343. test_mask = data.test_mask
  344. elif test_mask == "test":
  345. test_mask = data.test_mask
  346. elif test_mask == "val":
  347. test_mask = data.val_mask
  348. elif test_mask == "train":
  349. test_mask = data.train_mask
  350. y_pred_prob = self.predict_proba(dataset, mask)
  351. y_pred = y_pred_prob.max(1)[1]
  352. y_true = data.y[test_mask]
  353. if not isinstance(feval, list):
  354. feval = [feval]
  355. return_signle = True
  356. else:
  357. return_signle = False
  358. res = []
  359. for f in feval:
  360. try:
  361. res.append(f.evaluate(y_pred_prob, y_true))
  362. except:
  363. res.append(f.evaluate(y_pred_prob.cpu().numpy(), y_true.cpu().numpy()))
  364. if return_signle:
  365. return res[0]
  366. return res
  367. def to(self, new_device):
  368. assert isinstance(new_device, torch.device)
  369. self.device = new_device
  370. if self.model is not None:
  371. self.model.to(self.device)
  372. def duplicate_from_hyper_parameter(self, hp: dict, model=None, restricted=True):
  373. """
  374. The function of duplicating a new instance from the given hyperparameter.
  375. Parameters
  376. ----------
  377. hp: ``dict``.
  378. The hyperparameter used in the new instance.
  379. model: The model used in the new instance of trainer.
  380. restricted: ``bool``.
  381. If False(True), the hyperparameter should (not) be updated from origin hyperparameter.
  382. Returns
  383. -------
  384. self: ``autogl.train.NodeClassificationTrainer``
  385. A new instance of trainer.
  386. """
  387. if not restricted:
  388. origin_hp = deepcopy(self.hyperparams)
  389. origin_hp.update(hp)
  390. hp = origin_hp
  391. if model is None:
  392. model = self.model
  393. model = model.from_hyper_parameter(
  394. dict(
  395. [
  396. x
  397. for x in hp.items()
  398. if x[0] in [y["parameterName"] for y in model.space]
  399. ]
  400. )
  401. )
  402. ret = self.__class__(
  403. model=model,
  404. num_features=self.num_features,
  405. num_classes=self.num_classes,
  406. optimizer=self.optimizer,
  407. lr=hp["lr"],
  408. max_epoch=hp["max_epoch"],
  409. early_stopping_round=hp["early_stopping_round"],
  410. device=self.device,
  411. weight_decay=hp["weight_decay"],
  412. feval=self.feval,
  413. init=True,
  414. *self.args,
  415. **self.kwargs
  416. )
  417. return ret
  418. def set_feval(self, feval):
  419. # """Set the evaluation metrics."""
  420. self.feval = get_feval(feval)
  421. @property
  422. def hyper_parameter_space(self):
  423. # """Get the space of hyperparameter."""
  424. return self.space
  425. @hyper_parameter_space.setter
  426. def hyper_parameter_space(self, space):
  427. # """Set the space of hyperparameter."""
  428. self.space = space
  429. NodeClassificationTrainer.space = space
  430. def get_hyper_parameter(self):
  431. # """Get the hyperparameter in this trainer."""
  432. return self.hyperparams