  1. """
  2. Node classification Het Trainer Implementation
  3. """
  4. from . import register_trainer
  5. from .base import BaseNodeClassificationHetTrainer, EarlyStopping
  6. import torch
  7. from torch.optim.lr_scheduler import (
  8. StepLR,
  9. MultiStepLR,
  10. ExponentialLR,
  11. ReduceLROnPlateau,
  12. )
  13. import torch.nn.functional as F
  14. import numpy as np
  15. from ..model import MODEL_DICT, BaseModel
  16. from .evaluation import get_feval, Logloss
  17. from typing import Union
  18. from copy import deepcopy
  19. from sklearn.metrics import f1_score
  20. from ...utils import get_logger
  21. from ...backend import DependentBackend
  22. LOGGER = get_logger("node classification het trainer")
  23. def score(logits, labels):
  24. _, indices = torch.max(logits, dim=1)
  25. prediction = indices.long().cpu().numpy()
  26. labels = labels.cpu().numpy()
  27. accuracy = (prediction == labels).sum() / len(prediction)
  28. micro_f1 = f1_score(labels, prediction, average='micro')
  29. macro_f1 = f1_score(labels, prediction, average='macro')
  30. return accuracy, micro_f1, macro_f1
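
# A minimal sketch of how ``score`` behaves; the tensor values below are
# illustrative assumptions, not taken from this repository:
#
#   >>> logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
#   >>> labels = torch.tensor([1, 0])
#   >>> score(logits, labels)  # -> (accuracy, micro_f1, macro_f1)
#   (1.0, 1.0, 1.0)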

@register_trainer("NodeClassificationHet")
class NodeClassificationHetTrainer(BaseNodeClassificationHetTrainer):
    """
    The heterogeneous node classification trainer.

    Used to automatically train models for the node classification problem
    on heterogeneous graphs.

    Parameters
    ----------
    model: ``BaseModel`` or ``str``
        The (name of the) model used to train and predict.
    optimizer: ``Optimizer`` or ``str``
        The (name of the) optimizer used to train and predict.
    lr: ``float``
        The learning rate of the node classification task.
    max_epoch: ``int``
        The maximum number of training epochs.
    early_stopping_round: ``int``
        The patience (in rounds) of early stopping.
    device: ``torch.device`` or ``str``
        The device the model will run on.
    init: ``bool``
        If True (False), the model will (not) be initialized.
    """

    def __init__(
        self,
        model: Union[BaseModel, str] = None,
        G=None,
        meta_paths=None,
        num_features=None,
        num_classes=None,
        optimizer=None,
        lr=None,
        max_epoch=None,
        early_stopping_round=None,
        weight_decay=1e-4,
        device="auto",
        init=True,
        feval=[Logloss],
        loss="nll_loss",
        lr_scheduler_type=None,
        *args,
        **kwargs
    ):
        super().__init__(
            model,
            G,
            meta_paths,
            num_features,
            num_classes,
            device=device,
            init=init,
            feval=feval,
            loss=loss,
        )
        self.opt_received = optimizer
        if isinstance(optimizer, str) and optimizer.lower() == "adam":
            self.optimizer = torch.optim.Adam
        elif isinstance(optimizer, str) and optimizer.lower() == "sgd":
            self.optimizer = torch.optim.SGD
        else:
            self.optimizer = torch.optim.Adam
        self.lr_scheduler_type = lr_scheduler_type
        self.lr = lr if lr is not None else 1e-4
        self.max_epoch = max_epoch if max_epoch is not None else 100
        self.early_stopping_round = (
            early_stopping_round if early_stopping_round is not None else 100
        )
        self.args = args
        self.kwargs = kwargs
        self.feval = get_feval(feval)
        self.weight_decay = weight_decay
        # Use the resolved round so a ``None`` argument falls back to the default.
        self.early_stopping = EarlyStopping(
            patience=self.early_stopping_round, verbose=False
        )
        self.valid_result = None
        self.valid_result_prob = None
        self.valid_score = None
        self.initialized = False
        self.pyg_dgl = DependentBackend.get_backend_name()
        self.space = [
            {
                "parameterName": "max_epoch",
                "type": "INTEGER",
                "maxValue": 500,
                "minValue": 10,
                "scalingType": "LINEAR",
            },
            {
                "parameterName": "early_stopping_round",
                "type": "INTEGER",
                "maxValue": 30,
                "minValue": 10,
                "scalingType": "LINEAR",
            },
            {
                "parameterName": "lr",
                "type": "DOUBLE",
                "maxValue": 1e-1,
                "minValue": 1e-4,
                "scalingType": "LOG",
            },
            {
                "parameterName": "weight_decay",
                "type": "DOUBLE",
                "maxValue": 1e-2,
                "minValue": 1e-4,
                "scalingType": "LOG",
            },
        ]
        self.hyperparams = {
            "max_epoch": self.max_epoch,
            "early_stopping_round": self.early_stopping_round,
            "lr": self.lr,
            "weight_decay": self.weight_decay,
        }
        if init is True:
            self.initialize()
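
    # A minimal construction sketch. The model name ("han"), graph,
    # meta-paths, and dimensions below are assumptions for illustration,
    # not values taken from this repository:
    #
    #   trainer = NodeClassificationHetTrainer(
    #       model="han",
    #       G=hetero_graph,
    #       meta_paths=meta_paths,
    #       num_features=feature_dim,
    #       num_classes=num_classes,
    #       optimizer="adam",
    #       lr=5e-3,
    #       max_epoch=200,
    #       early_stopping_round=30,
    #       device="auto",
    #   )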

    def initialize(self):
        """Initialize the auto model in the trainer."""
        if self.initialized is True:
            return
        self.initialized = True
        self.model.initialize()

    def get_model(self):
        """Get the auto model used in the trainer."""
        return self.model

    @classmethod
    def get_task_name(cls):
        """Get the task name, i.e., ``NodeClassificationHet``."""
        return "NodeClassificationHet"

    def train_only(self, data, G, train_mask=None):
        """
        Train on the given dataset and mask.

        Parameters
        ----------
        data: The node classification dataset to train on. It should contain
            the labels and masks, e.g. ``val_mask``.
        G: The heterogeneous graph.
        train_mask: The mask used in the training stage.

        Returns
        -------
        self: ``autogl.train.NodeClassificationHetTrainer``
            A reference to the current trainer.
        """
        labels = data["labels"].to(self.device)
        val_mask = data["val_mask"].to(self.device)
        optimizer = self.optimizer(
            self.model.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        lr_scheduler_type = self.lr_scheduler_type
        if lr_scheduler_type == "steplr":
            scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
        elif lr_scheduler_type == "multisteplr":
            scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
        elif lr_scheduler_type == "exponentiallr":
            scheduler = ExponentialLR(optimizer, gamma=0.1)
        elif lr_scheduler_type == "reducelronplateau":
            scheduler = ReduceLROnPlateau(optimizer, "min")
        else:
            scheduler = None
        for epoch in range(1, self.max_epoch):
            self.model.model.train()
            optimizer.zero_grad()
            if hasattr(self.model.model, "cls_forward"):
                logits = self.model.model.cls_forward(G, "paper")
            else:
                logits = self.model.model.forward(G, "paper")
            # Manually defined loss for now; should be made configurable later.
            loss_fcn = torch.nn.CrossEntropyLoss()
            # logits[train_mask] has shape [1, 4025, 3] and labels[train_mask]
            # has shape [1, 4025] on the reference dataset.
            loss = loss_fcn(logits[train_mask], labels[train_mask])
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                # ``ReduceLROnPlateau`` requires a metric to monitor; the
                # other schedulers step unconditionally.
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss)
                else:
                    scheduler.step()
            if val_mask is not None:
                if isinstance(self.feval, list):
                    feval = self.feval[0]
                else:
                    feval = self.feval
                val_loss = self.evaluate(G, labels, val_mask, loss_fcn)[0]
                if feval.is_higher_better() is True:
                    val_loss = -val_loss
                self.early_stopping(val_loss, self.model.model)
                if self.early_stopping.early_stop:
                    LOGGER.debug("Early stopping at %d", epoch)
                    break
        if data.get("val_mask") is not None:
            self.early_stopping.load_checkpoint(self.model.model)
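
    # The ``data`` mapping layout expected above, inferred from the accesses
    # in ``train_only`` (a sketch, not a formal schema):
    #
    #   data = {
    #       "labels": ...,    # per-node class ids, torch.LongTensor
    #       "val_mask": ...,  # validation node mask, torch.BoolTensor
    #   }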

    def predict_only(self, G, mask=None):
        """
        Predict on the given graph and mask.

        Parameters
        ----------
        G: The heterogeneous graph to predict on.
        mask: The mask selecting the nodes to predict.

        Returns
        -------
        res: The prediction result on the given graph.
        """
        self.model.model.eval()
        with torch.no_grad():
            if hasattr(self.model.model, "cls_forward"):
                res = self.model.model.cls_forward(G, "paper")
            else:
                res = self.model.model.forward(G, "paper")
        if mask is None:
            return res
        else:
            return res[mask]

    def train(self, data, G, keep_valid_result=True, train_mask=None):
        """
        Train on the given dataset and keep the validation result.

        Parameters
        ----------
        data: The node classification dataset to train on.
        G: The heterogeneous graph.
        keep_valid_result: ``bool``
            If True (False), save the validation result after training.
        train_mask: The mask for the training data.

        Returns
        -------
        self: ``autogl.train.NodeClassificationHetTrainer``
            A reference to the current trainer.
        """
        self.train_only(data, G, train_mask)
        if keep_valid_result:
            # Generate predictions and scores on the validation split.
            val_mask = data["val_mask"]
            self.valid_result = self.predict_only(G)[val_mask].max(1)[1]
            self.valid_result_prob = self.predict_only(G)[val_mask]
            self.valid_score = self.evaluate(
                G,
                data["labels"].to(self.device),
                mask=val_mask,
                loss_func=torch.nn.CrossEntropyLoss(),
            )

    def predict(self, G, mask=None):
        """
        Predict on the given graph.

        Parameters
        ----------
        G: The heterogeneous graph to predict on.
        mask: ``train``, ``val``, or ``test``.
            The dataset mask.

        Returns
        -------
        The class predictions derived from ``predict_proba``.
        """
        return self.predict_proba(G, mask=mask, in_log_format=True).max(1)[1]

    def predict_proba(self, G, mask=None, in_log_format=False):
        """
        Predict the class probabilities on the given graph.

        Parameters
        ----------
        G: The heterogeneous graph to predict on.
        mask: ``train``, ``val``, ``test``, or ``Tensor``.
            The dataset mask.
        in_log_format: ``bool``.
            If True (False), the probability will (not) be in log format.

        Returns
        -------
        The prediction result.
        """
        ret = self.predict_only(G, mask)
        if in_log_format is True:
            return ret
        else:
            # Assumes the model outputs log-probabilities (e.g. log_softmax),
            # so exponentiating recovers plain probabilities.
            return torch.exp(ret)

    def get_valid_predict(self):
        """Get the validation result."""
        return self.valid_result

    def get_valid_predict_proba(self):
        """Get the validation result (prediction probability)."""
        return self.valid_result_prob

    def get_valid_score(self, return_major=True):
        """
        Get the validation score.

        Parameters
        ----------
        return_major: ``bool``.
            If True, return only the major result.
            If False, return all results.

        Returns
        -------
        result: The validation score obtained in the training stage.
        """
        if isinstance(self.feval, list):
            if return_major:
                return self.valid_score[0], self.feval[0].is_higher_better()
            else:
                return self.valid_score, [f.is_higher_better() for f in self.feval]
        else:
            return self.valid_score, self.feval.is_higher_better()
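
    # Call sketch: the method returns the score together with a flag telling
    # whether a higher value is better, e.g.
    #
    #   score_value, higher_better = trainer.get_valid_score()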

    def __repr__(self) -> str:
        import yaml

        return yaml.dump(
            {
                "trainer_name": self.__class__.__name__,
                "optimizer": self.optimizer,
                "learning_rate": self.lr,
                "max_epoch": self.max_epoch,
                "early_stopping_round": self.early_stopping_round,
                "model": repr(self.model),
            }
        )

    def evaluate(self, G, labels, mask=None, loss_func=None):
        """
        Evaluate on the given graph and mask.

        Parameters
        ----------
        G: The heterogeneous graph to evaluate on.
        labels: The ground-truth node labels.
        mask: The mask selecting the nodes to evaluate.
        loss_func: The loss function used for evaluation.

        Returns
        -------
        res: The evaluation result (loss, accuracy, micro-F1, macro-F1).
        """
        mask = mask.bool()
        self.model.model.eval()
        with torch.no_grad():
            logits = self.model.model(G, "paper")
            loss = loss_func(logits[mask], labels[mask])
            accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask])
        return loss, accuracy, micro_f1, macro_f1
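
    # Example call, assuming ``trainer`` is a fitted instance and ``G``,
    # ``labels``, and ``val_mask`` come from the same heterogeneous dataset:
    #
    #   loss, acc, micro_f1, macro_f1 = trainer.evaluate(
    #       G, labels, mask=val_mask, loss_func=torch.nn.CrossEntropyLoss()
    #   )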

    def to(self, new_device):
        assert isinstance(new_device, torch.device)
        self.device = new_device
        if self.model is not None:
            self.model.to(self.device)

    def duplicate_from_hyper_parameter(self, hp: dict, model=None, restricted=True):
        """
        Duplicate a new trainer instance from the given hyperparameters.

        Parameters
        ----------
        hp: ``dict``.
            The hyperparameters used in the new instance.
        model: The model used in the new instance of the trainer.
        restricted: ``bool``.
            If False, hyperparameters missing from ``hp`` fall back to this
            trainer's original values; if True, ``hp`` is used as given.

        Returns
        -------
        ret: ``autogl.train.NodeClassificationHetTrainer``
            A new instance of the trainer.
        """
        if not restricted:
            origin_hp = deepcopy(self.hyperparams)
            origin_hp.update(hp)
            hp = origin_hp
        if model is None:
            model = self.model
        model = model.from_hyper_parameter(
            {
                key: value
                for key, value in hp.items()
                if key in [d["parameterName"] for d in model.space]
            }
        )
        ret = self.__class__(
            model=model,
            num_features=self.num_features,
            num_classes=self.num_classes,
            optimizer=self.opt_received,
            lr=hp["lr"],
            max_epoch=hp["max_epoch"],
            early_stopping_round=hp["early_stopping_round"],
            device=self.device,
            weight_decay=hp["weight_decay"],
            feval=self.feval,
            loss=self.loss,
            lr_scheduler_type=self.lr_scheduler_type,
            init=True,
            *self.args,
            **self.kwargs
        )
        return ret
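
    # Usage sketch: the keys of ``hp`` mirror ``self.space`` /
    # ``self.hyperparams``; the concrete values below are illustrative only:
    #
    #   new_trainer = trainer.duplicate_from_hyper_parameter({
    #       "lr": 1e-3,
    #       "max_epoch": 200,
    #       "early_stopping_round": 20,
    #       "weight_decay": 5e-4,
    #   })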

    @property
    def hyper_parameter_space(self):
        """Get the hyperparameter space."""
        return self.space

    @hyper_parameter_space.setter
    def hyper_parameter_space(self, space):
        """Set the hyperparameter space."""
        self.space = space

    def get_hyper_parameter(self):
        """Get the hyperparameters of this trainer."""
        return self.hyperparams