You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

node_classification_het.py 16 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. """
  2. Node classification Het Trainer Implementation
  3. """
  4. from . import register_trainer
  5. from .base import BaseNodeClassificationHetTrainer, EarlyStopping
  6. import torch
  7. from torch.optim.lr_scheduler import (
  8. StepLR,
  9. MultiStepLR,
  10. ExponentialLR,
  11. ReduceLROnPlateau,
  12. )
  13. import torch.nn.functional as F
  14. from ..model import BaseAutoModel
  15. from .evaluation import get_feval, Logloss
  16. from typing import Union
  17. from copy import deepcopy
  18. from sklearn.metrics import f1_score
  19. from ...utils import get_logger
  20. from ...backend import DependentBackend
  21. LOGGER = get_logger("node classification het trainer")
  22. def score(logits, labels):
  23. _, indices = torch.max(logits, dim=1)
  24. prediction = indices.long().cpu().numpy()
  25. labels = labels.cpu().numpy()
  26. accuracy = (prediction == labels).sum() / len(prediction)
  27. micro_f1 = f1_score(labels, prediction, average='micro')
  28. macro_f1 = f1_score(labels, prediction, average='macro')
  29. return accuracy, micro_f1, macro_f1
@register_trainer("NodeClassificationHet")
class NodeClassificationHetTrainer(BaseNodeClassificationHetTrainer):
    """
    The heterogeneous node classification trainer.

    Parameters
    ----------
    model: ``autogl.module.model.BaseAutoModel``
        Currently Heterogeneous trainer doesn't support decoupled model setting.
    num_features: ``int`` (Optional)
        The number of features in dataset. default None
    num_classes: ``int`` (Optional)
        The number of classes. default None
    optimizer: ``Optimizer`` of ``str``
        The (name of) optimizer used to train and predict. default torch.optim.Adam
    lr: ``float``
        The learning rate of node classification task. default 1e-4
    max_epoch: ``int``
        The max number of epochs in training. default 100
    early_stopping_round: ``int``
        The round of early stop. default 100
    weight_decay: ``float``
        weight decay ratio, default 1e-4
    device: ``torch.device`` or ``str``
        The device where model will be running on.
    init: ``bool``
        If True(False), the model will (not) be initialized.
    feval: (Sequence of) ``Evaluation`` or ``str``
        The evaluation functions, default ``[LogLoss]``
    loss: ``str``
        The loss used. Default ``nll_loss``.
    lr_scheduler_type: ``str`` (Optional)
        The lr scheduler type used. Default None.
    """

    def __init__(
        self,
        model: Union[BaseAutoModel, str] = None,
        dataset=None,
        num_features=None,
        num_classes=None,
        optimizer=torch.optim.AdamW,
        lr=1e-4,
        max_epoch=100,
        early_stopping_round=100,
        weight_decay=1e-4,
        device="auto",
        init=False,
        feval=[Logloss],
        loss="nll_loss",
        lr_scheduler_type=None,
        *args,
        **kwargs
    ):
        super().__init__(
            model,
            dataset,
            num_features,
            num_classes,
            device=device,
            feval=feval,
            loss=loss,
        )
        # Keep the raw optimizer argument so a duplicated trainer can
        # re-resolve it identically (see duplicate_from_hyper_parameter).
        self.opt_received = optimizer
        if isinstance(optimizer, str):
            # Resolve the optimizer by case-insensitive name.
            if optimizer.lower() == "adam": self.optimizer = torch.optim.Adam
            elif optimizer.lower() == "sgd": self.optimizer = torch.optim.SGD
            else: raise ValueError("Currently not support optimizer {}".format(optimizer))
        elif isinstance(optimizer, type) and issubclass(optimizer, torch.optim.Optimizer):
            # An optimizer class was passed directly.
            self.optimizer = optimizer
        else:
            raise ValueError("Currently not support optimizer {}".format(optimizer))
        self.lr_scheduler_type = lr_scheduler_type
        self.lr = lr
        self.max_epoch = max_epoch
        self.early_stopping_round = early_stopping_round
        # Extra positional/keyword args are stored so they can be forwarded
        # to the duplicated trainer in duplicate_from_hyper_parameter.
        self.args = args
        self.kwargs = kwargs
        self.weight_decay = weight_decay
        self.early_stopping = EarlyStopping(
            patience=early_stopping_round, verbose=False
        )
        # Filled in by train() when keep_valid_result is True.
        self.valid_result = None
        self.valid_result_prob = None
        self.valid_score = None
        # Name of the active graph backend reported by DependentBackend.
        self.pyg_dgl = DependentBackend.get_backend_name()
        # Hyper-parameter search space consumed by the HPO module.
        self.hyper_parameter_space = [
            {
                "parameterName": "max_epoch",
                "type": "INTEGER",
                "maxValue": 500,
                "minValue": 10,
                "scalingType": "LINEAR",
            },
            {
                "parameterName": "early_stopping_round",
                "type": "INTEGER",
                "maxValue": 30,
                "minValue": 10,
                "scalingType": "LINEAR",
            },
            {
                "parameterName": "lr",
                "type": "DOUBLE",
                "maxValue": 1e-1,
                "minValue": 1e-4,
                "scalingType": "LOG",
            },
            {
                "parameterName": "weight_decay",
                "type": "DOUBLE",
                "maxValue": 1e-2,
                "minValue": 1e-4,
                "scalingType": "LOG",
            },
        ]
        # Current concrete hyper-parameter values.
        self.hyper_parameters = {
            "max_epoch": self.max_epoch,
            "early_stopping_round": self.early_stopping_round,
            "lr": self.lr,
            "weight_decay": self.weight_decay,
        }
        if init is True:
            self.initialize()

    def _initialize(self):
        # Delegate weight initialization to the wrapped encoder model.
        self.encoder.initialize()

    @classmethod
    def get_task_name(cls):
        """
        Get task name ("NodeClassificationHet")
        """
        return "NodeClassificationHet"

    def _train_only(self, dataset, train_mask="train"):
        """Run the optimization loop on ``dataset`` (no valid bookkeeping)."""
        G = dataset[0].to(self.device)
        # Labels and masks live on the target node type of the hetero graph.
        field = dataset.schema["target_node_type"]
        labels = G.nodes[field].data['label'].to(self.device)
        train_mask = self._get_mask(dataset, train_mask).to(self.device)
        val_mask = self._get_mask(dataset, "val").to(self.device)
        model = self.encoder.model.to(self.device)
        optimizer = self.optimizer(
            model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        # Optionally build an LR scheduler from its (lower-case) string name.
        lr_scheduler_type = self.lr_scheduler_type
        if type(lr_scheduler_type) == str and lr_scheduler_type == "steplr":
            scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
        elif type(lr_scheduler_type) == str and lr_scheduler_type == "multisteplr":
            scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
        elif type(lr_scheduler_type) == str and lr_scheduler_type == "exponentiallr":
            scheduler = ExponentialLR(optimizer, gamma=0.1)
        elif (
            type(lr_scheduler_type) == str and lr_scheduler_type == "reducelronplateau"
        ):
            scheduler = ReduceLROnPlateau(optimizer, "min")
        else:
            scheduler = None
        # NOTE(review): range(1, max_epoch) performs max_epoch - 1 iterations,
        # one fewer than the documented "max number of epochs" — confirm intent.
        for epoch in range(1, self.max_epoch):
            model.train()
            optimizer.zero_grad()
            logits = model(G)
            # The loss is looked up by name on torch.nn.functional.
            if hasattr(F, self.loss):
                loss = getattr(F, self.loss)(logits[train_mask], labels[train_mask])
            else:
                raise TypeError(
                    "PyTorch does not support loss type {}".format(self.loss)
                )
            loss.backward()
            optimizer.step()
            if self.lr_scheduler_type:
                # NOTE(review): ReduceLROnPlateau.step expects a metrics
                # argument; the no-arg call here would fail for that scheduler
                # type — confirm.
                scheduler.step()
            if val_mask is not None:
                # Early-stop on the first feval when several are configured.
                if type(self.feval) is list:
                    feval = self.feval[0]
                else:
                    feval = self.feval
                val_loss = self.evaluate(dataset, mask=val_mask, feval=feval)
                # EarlyStopping minimizes, so negate higher-is-better scores.
                if feval.is_higher_better() is True:
                    val_loss = -val_loss
                self.early_stopping(val_loss, model)
                if self.early_stopping.early_stop:
                    LOGGER.debug("Early stopping at %d", epoch)
                    break
        if val_mask is not None:
            # Restore the best weights observed on the validation set.
            self.early_stopping.load_checkpoint(model)

    def _predict_only(self, dataset, mask=None):
        """Forward pass in eval mode; returns all-node output or ``res[mask]``."""
        model = self.encoder.model.to(self.device)
        model.eval()
        G = dataset[0].to(self.device)
        with torch.no_grad():
            res = model(G)
        if mask is None:
            return res
        else:
            return res[mask]

    def train(self, dataset, keep_valid_result=True, train_mask="train"):
        """
        The function of training on the given dataset and keeping valid result.

        Parameters
        ----------
        dataset: The node classification dataset used to be trained.
        keep_valid_result: ``bool``
            If True(False), save the validation result after training.
        train_mask: The mask for training data

        Returns
        -------
        self: ``autogl.train.NodeClassificationTrainer``
            A reference of current trainer.
        """
        self._train_only(dataset, train_mask)
        G = dataset[0].to(self.device)
        if keep_valid_result:
            # Cache validation predictions (hard labels and raw scores) and
            # the validation score for later retrieval via the getters below.
            val_mask = G.nodes[dataset.schema["target_node_type"]].data["val_mask"]
            self.valid_result = self._predict_only(dataset)[val_mask].max(1)[1]
            self.valid_result_prob = self._predict_only(dataset)[val_mask]
            self.valid_score = self.evaluate(
                dataset, mask=val_mask, feval=self.feval
            )

    def predict(self, dataset, mask="test"):
        """
        The function of predicting on the given dataset.

        Parameters
        ----------
        dataset: The node classification dataset used to be predicted.
        mask: ``train``, ``val``, or ``test``.
            The dataset mask.

        Returns
        -------
        The prediction result of ``predict_proba``.
        """
        # Arg-max over (log-)probabilities gives the hard class prediction.
        return self.predict_proba(dataset, mask=mask, in_log_format=True).max(1)[1]

    def predict_proba(self, dataset, mask="test", in_log_format=False):
        """
        The function of predicting the probability on the given dataset.

        Parameters
        ----------
        dataset: The node classification dataset used to be predicted.
        mask: ``train``, ``val``, ``test``, or ``Tensor``.
            The dataset mask.
        in_log_format: ``bool``.
            If True(False), the probability will (not) be log format.

        Returns
        -------
        The prediction result.
        """
        G = dataset[0].to(self.device)
        # Resolve a named split to its boolean mask on the target node type.
        if mask in ["train", "val", "test"]:
            mask = G.nodes[dataset.schema["target_node_type"]].data[f"{mask}_mask"]
        ret = self._predict_only(dataset, mask)
        if in_log_format is True:
            return ret
        else:
            # Model output is assumed to be in log space here — exp() recovers
            # probabilities (consistent with the default nll_loss setup).
            return torch.exp(ret)

    def get_valid_predict(self):
        """Get the valid result (hard predictions cached by ``train``)."""
        return self.valid_result

    def get_valid_predict_proba(self):
        """Get the valid result (prediction probability cached by ``train``)."""
        return self.valid_result_prob

    def get_valid_score(self, return_major=True):
        """
        The function of getting the valid score.

        Parameters
        ----------
        return_major: ``bool``.
            If True, the return only consists of the major result.
            If False, the return consists of the all results.

        Returns
        -------
        result: The valid score in training stage.
        """
        if isinstance(self.feval, list):
            if return_major:
                # The first feval is treated as the major metric.
                return self.valid_score[0], self.feval[0].is_higher_better()
            else:
                return self.valid_score, [f.is_higher_better() for f in self.feval]
        else:
            return self.valid_score, self.feval.is_higher_better()

    def __repr__(self) -> str:
        # Local import keeps yaml out of the module's hard dependencies.
        import yaml
        return yaml.dump(
            {
                "trainer_name": self.__class__.__name__,
                "optimizer": self.optimizer,
                "learning_rate": self.lr,
                "max_epoch": self.max_epoch,
                "early_stopping_round": self.early_stopping_round,
                "model": repr(self.model.model),
            }
        )

    def _get_mask(self, dataset, mask):
        """Resolve a split name to its mask tensor; pass through anything else."""
        if mask in ["train", "val", "test"]:
            return dataset[0].nodes[dataset.schema["target_node_type"]].data[f"{mask}_mask"]
        return mask

    def evaluate(self, dataset, mask='val', feval=None):
        """
        The function of evaluating on the given dataset.

        Parameters
        ----------
        dataset: The node classification dataset used to be evaluated.
        mask: ``train``, ``val``, or ``test``.
            The dataset mask.
        feval: ``str``.
            The evaluation method used in this function.

        Returns
        -------
        res: The evaluation result on the given dataset.
        """
        G = dataset[0].to(self.device)
        mask = self._get_mask(dataset, mask)
        label = G.nodes[dataset.schema["target_node_type"]].data['label'].to(self.device)
        if feval is None:
            feval = self.feval
        else:
            feval = get_feval(feval)
        y_pred_prob = self.predict_proba(dataset, mask)
        y_true = label[mask] if mask is not None else label
        # Normalize to a list of metrics; remember whether to unwrap the result.
        if not isinstance(feval, list):
            feval = [feval]
            return_signle = True
        else:
            return_signle = False
        res = []
        for f in feval:
            try:
                res.append(f.evaluate(y_pred_prob, y_true))
            except:
                # Fallback for metrics that only accept numpy arrays.
                # NOTE(review): the bare except also masks unrelated errors —
                # consider narrowing the exception type.
                res.append(f.evaluate(y_pred_prob.cpu().numpy(), y_true.cpu().numpy()))
        if return_signle:
            return res[0]
        return res

    def to(self, new_device):
        """Move the trainer (and its model, if initialized) to ``new_device``."""
        self.device = new_device
        if self.model is not None:
            self.model.to(self.device)

    def duplicate_from_hyper_parameter(self, hp: dict, model=None, restricted=True):
        """
        The function of duplicating a new instance from the given hyperparameter.

        Parameters
        ----------
        hp: ``dict``.
            The hyperparameter used in the new instance. Should contain 2 keys "trainer", "encoder"
            with corresponding hyperparameters as values.
        model: ``autogl.module.model.BaseAutoModel``
            Currently Heterogeneous trainer doesn't support decoupled model setting.
            If only encoder is specified, decoder will be default to "logsoftmax"
        restricted: ``bool``.
            If False(True), the hyperparameter should (not) be updated from origin hyperparameter.

        Returns
        -------
        self: ``autogl.train.NodeClassificationTrainer``
            A new instance of trainer.
        """
        trainer_hp = hp["trainer"]
        model_hp = hp["encoder"]
        if not restricted:
            # Merge the new trainer hyperparameters over the current ones.
            origin_hp = deepcopy(self.hyper_parameters)
            origin_hp.update(trainer_hp)
            trainer_hp = origin_hp
        if model is None:
            model = self.model
        model = model.from_hyper_parameter(model_hp)
        ret = self.__class__(
            model=model,
            dataset=self._dataset,
            num_features=self.num_features,
            num_classes=self.num_classes,
            optimizer=self.opt_received,
            lr=trainer_hp["lr"],
            max_epoch=trainer_hp["max_epoch"],
            early_stopping_round=trainer_hp["early_stopping_round"],
            device=self.device,
            weight_decay=trainer_hp["weight_decay"],
            feval=self.feval,
            loss=self.loss,
            lr_scheduler_type=self.lr_scheduler_type,
            init=True,
            *self.args,
            **self.kwargs
        )
        return ret