You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

basic_nn.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :basic_model.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/11/21
  9. # Description :
  10. #
  11. # ================================================================#
import logging
import os
from typing import Any, Callable, List, Optional, Tuple, TypeVar

import numpy
import torch
from torch.utils.data import DataLoader

from ..dataset import ClassificationDataset, PredictionDataset
from ..utils.logger import print_log

# `typing.T` is a private name that was removed from `typing` in Python 3.9+;
# declare the TypeVar locally so annotations referencing T keep working.
T = TypeVar("T")
  20. class BasicNN:
  21. """
  22. Wrap NN models into the form of an sklearn estimator.
  23. Parameters
  24. ----------
  25. model : torch.nn.Module
  26. The PyTorch model to be trained or used for prediction.
  27. criterion : torch.nn.Module
  28. The loss function used for training.
  29. optimizer : torch.optim.Optimizer
  30. The optimizer used for training.
  31. device : torch.device, optional
  32. The device on which the model will be trained or used for prediction, by default torch.device("cpu").
  33. batch_size : int, optional
  34. The batch size used for training, by default 32.
  35. num_epochs : int, optional
  36. The number of epochs used for training, by default 1.
  37. stop_loss : Optional[float], optional
  38. The loss value at which to stop training, by default 0.01.
  39. num_workers : int
  40. The number of workers used for loading data, by default 0.
  41. save_interval : Optional[int], optional
  42. The interval at which to save the model during training, by default None.
  43. save_dir : Optional[str], optional
  44. The directory in which to save the model during training, by default None.
  45. train_transform : Callable[..., Any], optional
  46. A function/transform that takes in an object and returns a transformed version used in the `fit` and `train_epoch` methods, by default None.
  47. test_transform : Callable[..., Any], optional
  48. A function/transform that takes in an object and returns a transformed version in the `predict`, `predict_proba` and `score` methods, , by default None.
  49. collate_fn : Callable[[List[T]], Any], optional
  50. The function used to collate data, by default None.
  51. """
  52. def __init__(
  53. self,
  54. model: torch.nn.Module,
  55. criterion: torch.nn.Module,
  56. optimizer: torch.optim.Optimizer,
  57. device: torch.device = torch.device("cpu"),
  58. batch_size: int = 32,
  59. num_epochs: int = 1,
  60. stop_loss: Optional[float] = 0.01,
  61. num_workers: int = 0,
  62. save_interval: Optional[int] = None,
  63. save_dir: Optional[str] = None,
  64. train_transform: Callable[..., Any] = None,
  65. test_transform: Callable[..., Any] = None,
  66. collate_fn: Callable[[List[T]], Any] = None,
  67. ) -> None:
  68. self.model = model.to(device)
  69. self.criterion = criterion
  70. self.optimizer = optimizer
  71. self.device = device
  72. self.batch_size = batch_size
  73. self.num_epochs = num_epochs
  74. self.stop_loss = stop_loss
  75. self.num_workers = num_workers
  76. self.save_interval = save_interval
  77. self.save_dir = save_dir
  78. self.train_transform = train_transform
  79. self.test_transform = test_transform
  80. self.collate_fn = collate_fn
  81. if self.train_transform is not None and self.test_transform is None:
  82. print_log(
  83. "Transform used in the training phase will be used in prediction.",
  84. "current",
  85. level=logging.WARNING,
  86. )
  87. self.test_transform = self.train_transform
  88. def _fit(self, data_loader: DataLoader) -> float:
  89. """
  90. Internal method to fit the model on data for n epochs, with early stopping.
  91. Parameters
  92. ----------
  93. data_loader : DataLoader
  94. Data loader providing training samples.
  95. Returns
  96. -------
  97. float
  98. The loss value of the trained model.
  99. """
  100. loss_value = 1e9
  101. for epoch in range(self.num_epochs):
  102. loss_value = self.train_epoch(data_loader)
  103. if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
  104. if self.save_dir is None:
  105. raise ValueError("save_dir should not be None if save_interval is not None.")
  106. self.save(epoch + 1)
  107. if self.stop_loss is not None and loss_value < self.stop_loss:
  108. break
  109. return loss_value
  110. def fit(
  111. self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
  112. ) -> float:
  113. """
  114. Train the model.
  115. Parameters
  116. ----------
  117. data_loader : DataLoader, optional
  118. The data loader used for training, by default None.
  119. X : List[Any], optional
  120. The input data, by default None.
  121. y : List[int], optional
  122. The target data, by default None.
  123. Returns
  124. -------
  125. float
  126. The loss value of the trained model.
  127. """
  128. if data_loader is None:
  129. if X is None:
  130. raise ValueError("data_loader and X can not be None simultaneously.")
  131. else:
  132. data_loader = self._data_loader(X, y)
  133. return self._fit(data_loader)
  134. def train_epoch(self, data_loader: DataLoader) -> float:
  135. """
  136. Train the model for one epoch.
  137. Parameters
  138. ----------
  139. data_loader : DataLoader
  140. The data loader used for training.
  141. Returns
  142. -------
  143. float
  144. The loss value of the trained model.
  145. """
  146. model = self.model
  147. criterion = self.criterion
  148. optimizer = self.optimizer
  149. device = self.device
  150. model.train()
  151. total_loss, total_num = 0.0, 0
  152. for data, target in data_loader:
  153. data, target = data.to(device), target.to(device)
  154. out = model(data)
  155. loss = criterion(out, target)
  156. optimizer.zero_grad()
  157. loss.backward()
  158. optimizer.step()
  159. total_loss += loss.item() * data.size(0)
  160. total_num += data.size(0)
  161. return total_loss / total_num
  162. def _predict(self, data_loader: DataLoader) -> torch.Tensor:
  163. """
  164. Internal method to predict the outputs given a DataLoader.
  165. Parameters
  166. ----------
  167. data_loader : DataLoader
  168. The DataLoader providing input samples.
  169. Returns
  170. -------
  171. torch.Tensor
  172. Raw output from the model.
  173. """
  174. model = self.model
  175. device = self.device
  176. model.eval()
  177. with torch.no_grad():
  178. results = []
  179. for data in data_loader:
  180. data = data.to(device)
  181. out = model(data)
  182. results.append(out)
  183. return torch.cat(results, axis=0)
  184. def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
  185. """
  186. Predict the class of the input data.
  187. Parameters
  188. ----------
  189. data_loader : DataLoader, optional
  190. The data loader used for prediction, by default None.
  191. X : List[Any], optional
  192. The input data, by default None.
  193. Returns
  194. -------
  195. numpy.ndarray
  196. The predicted class of the input data.
  197. """
  198. if data_loader is None:
  199. dataset = PredictionDataset(X, self.test_transform)
  200. data_loader = DataLoader(
  201. dataset,
  202. batch_size=self.batch_size,
  203. num_workers=int(self.num_workers),
  204. collate_fn=self.collate_fn,
  205. )
  206. return self._predict(data_loader).argmax(axis=1).cpu().numpy()
  207. def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
  208. """
  209. Predict the probability of each class for the input data.
  210. Parameters
  211. ----------
  212. data_loader : DataLoader, optional
  213. The data loader used for prediction, by default None.
  214. X : List[Any], optional
  215. The input data, by default None.
  216. Returns
  217. -------
  218. numpy.ndarray
  219. The predicted probability of each class for the input data.
  220. """
  221. if data_loader is None:
  222. dataset = PredictionDataset(X, self.test_transform)
  223. data_loader = DataLoader(
  224. dataset,
  225. batch_size=self.batch_size,
  226. num_workers=int(self.num_workers),
  227. collate_fn=self.collate_fn,
  228. )
  229. return self._predict(data_loader).softmax(axis=1).cpu().numpy()
  230. def _score(self, data_loader: DataLoader) -> Tuple[float, float]:
  231. """
  232. Internal method to compute loss and accuracy for the data provided through a DataLoader.
  233. Parameters
  234. ----------
  235. data_loader : DataLoader
  236. Data loader to use for evaluation.
  237. Returns
  238. -------
  239. Tuple[float, float]
  240. mean_loss: float, The mean loss of the model on the provided data.
  241. accuracy: float, The accuracy of the model on the provided data.
  242. """
  243. model = self.model
  244. criterion = self.criterion
  245. device = self.device
  246. model.eval()
  247. total_correct_num, total_num, total_loss = 0, 0, 0.0
  248. with torch.no_grad():
  249. for data, target in data_loader:
  250. data, target = data.to(device), target.to(device)
  251. out = model(data)
  252. if len(out.shape) > 1:
  253. correct_num = (target == out.argmax(axis=1)).sum().item()
  254. else:
  255. correct_num = (target == (out > 0.5)).sum().item()
  256. loss = criterion(out, target)
  257. total_loss += loss.item() * data.size(0)
  258. total_correct_num += correct_num
  259. total_num += data.size(0)
  260. mean_loss = total_loss / total_num
  261. accuracy = total_correct_num / total_num
  262. return mean_loss, accuracy
  263. def score(
  264. self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
  265. ) -> float:
  266. """
  267. Validate the model.
  268. Parameters
  269. ----------
  270. data_loader : DataLoader, optional
  271. The data loader used for scoring, by default None.
  272. X : List[Any], optional
  273. The input data, by default None.
  274. y : List[int], optional
  275. The target data, by default None.
  276. Returns
  277. -------
  278. float
  279. The accuracy of the model.
  280. """
  281. print_log("Start machine learning model validation", logger="current")
  282. if data_loader is None:
  283. data_loader = self._data_loader(X, y)
  284. mean_loss, accuracy = self._score(data_loader)
  285. print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current")
  286. return accuracy
  287. def _data_loader(self, X: List[Any], y: List[int] = None, shuffle: bool = True) -> DataLoader:
  288. """
  289. Generate a DataLoader for user-provided input and target data.
  290. Parameters
  291. ----------
  292. X : List[Any]
  293. Input samples.
  294. y : List[int], optional
  295. Target labels. If None, dummy labels are created, by default None.
  296. shuffle : bool, optional
  297. Whether to shuffle the data, by default True.
  298. Returns
  299. -------
  300. DataLoader
  301. A DataLoader providing batches of (X, y) pairs.
  302. """
  303. if X is None:
  304. raise ValueError("X should not be None.")
  305. if y is None:
  306. y = [0] * len(X)
  307. if not (len(y) == len(X)):
  308. raise ValueError("X and y should have equal length.")
  309. dataset = ClassificationDataset(X, y, transform=self.train_transform)
  310. data_loader = DataLoader(
  311. dataset,
  312. batch_size=self.batch_size,
  313. shuffle=shuffle,
  314. num_workers=int(self.num_workers),
  315. collate_fn=self.collate_fn,
  316. )
  317. return data_loader
  318. def save(self, epoch_id: int = 0, save_path: str = None) -> None:
  319. """
  320. Save the model and the optimizer.
  321. Parameters
  322. ----------
  323. epoch_id : int
  324. The epoch id.
  325. save_path : str, optional
  326. The path to save the model, by default None.
  327. """
  328. if self.save_dir is None and save_path is None:
  329. raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.")
  330. if save_path is not None:
  331. if not os.path.exists(os.path.dirname(save_path)):
  332. os.makedirs(os.path.dirname(save_path))
  333. else:
  334. save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth")
  335. if not os.path.exists(self.save_dir):
  336. os.makedirs(self.save_dir)
  337. print_log(f"Checkpoints will be saved to {save_path}", logger="current")
  338. save_parma_dic = {
  339. "model": self.model.state_dict(),
  340. "optimizer": self.optimizer.state_dict(),
  341. }
  342. torch.save(save_parma_dic, save_path)
  343. def load(self, load_path: str = "") -> None:
  344. """
  345. Load the model and the optimizer.
  346. Parameters
  347. ----------
  348. load_path : str
  349. The directory to load the model, by default "".
  350. """
  351. if load_path is None:
  352. raise ValueError("Load path should not be None.")
  353. print_log(
  354. f"Loads checkpoint by local backend from path: {load_path}",
  355. logger="current",
  356. )
  357. param_dic = torch.load(load_path)
  358. self.model.load_state_dict(param_dic["model"])
  359. if "optimizer" in param_dic.keys():
  360. self.optimizer.load_state_dict(param_dic["optimizer"])

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.