You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

basic_nn.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :basic_model.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/11/21
  9. # Description :
  10. #
  11. # ================================================================#
import logging
import os
from typing import Any, Callable, List, Optional, Tuple, TypeVar

import numpy
import torch
from torch.utils.data import DataLoader

from ..dataset import ClassificationDataset, PredictionDataset
from ..utils.logger import print_log

# ``typing.T`` was a private implementation detail of the typing module and was
# removed in Python 3.9; declare the TypeVar explicitly so annotations such as
# ``Callable[[List[T]], Any]`` keep working on every supported interpreter.
T = TypeVar("T")
  20. class BasicNN:
  21. """
  22. Wrap NN models into the form of an sklearn estimator.
  23. Parameters
  24. ----------
  25. model : torch.nn.Module
  26. The PyTorch model to be trained or used for prediction.
  27. criterion : torch.nn.Module
  28. The loss function used for training.
  29. optimizer : torch.optim.Optimizer
  30. The optimizer used for training.
  31. device : torch.device, optional
  32. The device on which the model will be trained or used for prediction, by default torch.device("cpu").
  33. batch_size : int, optional
  34. The batch size used for training, by default 32.
  35. num_epochs : int, optional
  36. The number of epochs used for training, by default 1.
  37. stop_loss : Optional[float], optional
  38. The loss value at which to stop training, by default 0.01.
  39. num_workers : int
  40. The number of workers used for loading data, by default 0.
  41. save_interval : Optional[int], optional
  42. The interval at which to save the model during training, by default None.
  43. save_dir : Optional[str], optional
  44. The directory in which to save the model during training, by default None.
  45. transform : Callable[..., Any], optional
  46. A function/transform that takes in an object and returns a transformed version, by default None.
  47. collate_fn : Callable[[List[T]], Any], optional
  48. The function used to collate data, by default None.
  49. """
  50. def __init__(
  51. self,
  52. model: torch.nn.Module,
  53. criterion: torch.nn.Module,
  54. optimizer: torch.optim.Optimizer,
  55. device: torch.device = torch.device("cpu"),
  56. batch_size: int = 32,
  57. num_epochs: int = 1,
  58. stop_loss: Optional[float] = 0.01,
  59. num_workers: int = 0,
  60. save_interval: Optional[int] = None,
  61. save_dir: Optional[str] = None,
  62. transform: Callable[..., Any] = None,
  63. collate_fn: Callable[[List[T]], Any] = None,
  64. ) -> None:
  65. self.model = model.to(device)
  66. self.criterion = criterion
  67. self.optimizer = optimizer
  68. self.device = device
  69. self.batch_size = batch_size
  70. self.num_epochs = num_epochs
  71. self.stop_loss = stop_loss
  72. self.num_workers = num_workers
  73. self.save_interval = save_interval
  74. self.save_dir = save_dir
  75. self.transform = transform
  76. self.collate_fn = collate_fn
  77. def _fit(self, data_loader) -> float:
  78. """
  79. Internal method to fit the model on data for n epochs, with early stopping.
  80. Parameters
  81. ----------
  82. data_loader : DataLoader
  83. Data loader providing training samples.
  84. Returns
  85. -------
  86. float
  87. The loss value of the trained model.
  88. """
  89. loss_value = 1e9
  90. for epoch in range(self.num_epochs):
  91. loss_value = self.train_epoch(data_loader)
  92. if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
  93. if self.save_dir is None:
  94. raise ValueError("save_dir should not be None if save_interval is not None.")
  95. self.save(epoch + 1)
  96. if self.stop_loss is not None and loss_value < self.stop_loss:
  97. break
  98. return loss_value
  99. def fit(
  100. self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
  101. ) -> float:
  102. """
  103. Train the model.
  104. Parameters
  105. ----------
  106. data_loader : DataLoader, optional
  107. The data loader used for training, by default None.
  108. X : List[Any], optional
  109. The input data, by default None.
  110. y : List[int], optional
  111. The target data, by default None.
  112. Returns
  113. -------
  114. float
  115. The loss value of the trained model.
  116. """
  117. if data_loader is None:
  118. if X is None:
  119. raise ValueError("data_loader and X can not be None simultaneously.")
  120. else:
  121. data_loader = self._data_loader(X, y)
  122. return self._fit(data_loader)
  123. def train_epoch(self, data_loader: DataLoader) -> float:
  124. """
  125. Train the model for one epoch.
  126. Parameters
  127. ----------
  128. data_loader : DataLoader
  129. The data loader used for training.
  130. Returns
  131. -------
  132. float
  133. The loss value of the trained model.
  134. """
  135. model = self.model
  136. criterion = self.criterion
  137. optimizer = self.optimizer
  138. device = self.device
  139. model.train()
  140. total_loss, total_num = 0.0, 0
  141. for data, target in data_loader:
  142. data, target = data.to(device), target.to(device)
  143. out = model(data)
  144. loss = criterion(out, target)
  145. optimizer.zero_grad()
  146. loss.backward()
  147. optimizer.step()
  148. total_loss += loss.item() * data.size(0)
  149. total_num += data.size(0)
  150. return total_loss / total_num
  151. def _predict(self, data_loader) -> torch.Tensor:
  152. """
  153. Internal method to predict the outputs given a DataLoader.
  154. Parameters
  155. ----------
  156. data_loader : DataLoader
  157. The DataLoader providing input samples.
  158. Returns
  159. -------
  160. torch.Tensor
  161. Raw output from the model.
  162. """
  163. model = self.model
  164. device = self.device
  165. model.eval()
  166. with torch.no_grad():
  167. results = []
  168. for data in data_loader:
  169. data = data.to(device)
  170. out = model(data)
  171. results.append(out)
  172. return torch.cat(results, axis=0)
  173. def predict(
  174. self,
  175. data_loader: DataLoader = None,
  176. X: List[Any] = None,
  177. test_transform: Callable[..., Any] = None,
  178. ) -> numpy.ndarray:
  179. """
  180. Predict the class of the input data.
  181. Parameters
  182. ----------
  183. data_loader : DataLoader, optional
  184. The data loader used for prediction, by default None.
  185. X : List[Any], optional
  186. The input data, by default None.
  187. Returns
  188. -------
  189. numpy.ndarray
  190. The predicted class of the input data.
  191. """
  192. if data_loader is None:
  193. if test_transform is None:
  194. print_log(
  195. "Transform used in the training phase will be used in prediction.",
  196. "current",
  197. level=logging.WARNING,
  198. )
  199. dataset = PredictionDataset(X, self.transform)
  200. else:
  201. dataset = PredictionDataset(X, test_transform)
  202. data_loader = DataLoader(
  203. dataset,
  204. batch_size=self.batch_size,
  205. num_workers=int(self.num_workers),
  206. collate_fn=self.collate_fn,
  207. )
  208. return self._predict(data_loader).argmax(axis=1).cpu().numpy()
  209. def predict_proba(
  210. self,
  211. data_loader: DataLoader = None,
  212. X: List[Any] = None,
  213. test_transform: Callable[..., Any] = None,
  214. ) -> numpy.ndarray:
  215. """
  216. Predict the probability of each class for the input data.
  217. Parameters
  218. ----------
  219. data_loader : DataLoader, optional
  220. The data loader used for prediction, by default None.
  221. X : List[Any], optional
  222. The input data, by default None.
  223. Returns
  224. -------
  225. numpy.ndarray
  226. The predicted probability of each class for the input data.
  227. """
  228. if data_loader is None:
  229. if test_transform is None:
  230. print_log(
  231. "Transform used in the training phase will be used in prediction.",
  232. "current",
  233. level=logging.WARNING,
  234. )
  235. dataset = PredictionDataset(X, self.transform)
  236. else:
  237. dataset = PredictionDataset(X, test_transform)
  238. data_loader = DataLoader(
  239. dataset,
  240. batch_size=self.batch_size,
  241. num_workers=int(self.num_workers),
  242. collate_fn=self.collate_fn,
  243. )
  244. return self._predict(data_loader).softmax(axis=1).cpu().numpy()
  245. def _score(self, data_loader) -> Tuple[float, float]:
  246. """
  247. Internal method to compute loss and accuracy for the data provided through a DataLoader.
  248. Parameters
  249. ----------
  250. data_loader : DataLoader
  251. Data loader to use for evaluation.
  252. Returns
  253. -------
  254. Tuple[float, float]
  255. mean_loss: float, The mean loss of the model on the provided data.
  256. accuracy: float, The accuracy of the model on the provided data.
  257. """
  258. model = self.model
  259. criterion = self.criterion
  260. device = self.device
  261. model.eval()
  262. total_correct_num, total_num, total_loss = 0, 0, 0.0
  263. with torch.no_grad():
  264. for data, target in data_loader:
  265. data, target = data.to(device), target.to(device)
  266. out = model(data)
  267. if len(out.shape) > 1:
  268. correct_num = (target == out.argmax(axis=1)).sum().item()
  269. else:
  270. correct_num = (target == (out > 0.5)).sum().item()
  271. loss = criterion(out, target)
  272. total_loss += loss.item() * data.size(0)
  273. total_correct_num += correct_num
  274. total_num += data.size(0)
  275. mean_loss = total_loss / total_num
  276. accuracy = total_correct_num / total_num
  277. return mean_loss, accuracy
  278. def score(
  279. self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
  280. ) -> float:
  281. """
  282. Validate the model.
  283. Parameters
  284. ----------
  285. data_loader : DataLoader, optional
  286. The data loader used for scoring, by default None.
  287. X : List[Any], optional
  288. The input data, by default None.
  289. y : List[int], optional
  290. The target data, by default None.
  291. Returns
  292. -------
  293. float
  294. The accuracy of the model.
  295. """
  296. print_log("Start machine learning model validation", logger="current")
  297. if data_loader is None:
  298. data_loader = self._data_loader(X, y)
  299. mean_loss, accuracy = self._score(data_loader)
  300. print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current")
  301. return accuracy
  302. def _data_loader(
  303. self,
  304. X: List[Any],
  305. y: List[int] = None,
  306. shuffle: bool = True,
  307. ) -> DataLoader:
  308. """
  309. Generate a DataLoader for user-provided input and target data.
  310. Parameters
  311. ----------
  312. X : List[Any]
  313. Input samples.
  314. y : List[int], optional
  315. Target labels. If None, dummy labels are created, by default None.
  316. Returns
  317. -------
  318. DataLoader
  319. A DataLoader providing batches of (X, y) pairs.
  320. """
  321. if X is None:
  322. raise ValueError("X should not be None.")
  323. if y is None:
  324. y = [0] * len(X)
  325. if not (len(y) == len(X)):
  326. raise ValueError("X and y should have equal length.")
  327. dataset = ClassificationDataset(X, y, transform=self.transform)
  328. data_loader = DataLoader(
  329. dataset,
  330. batch_size=self.batch_size,
  331. shuffle=shuffle,
  332. num_workers=int(self.num_workers),
  333. collate_fn=self.collate_fn,
  334. )
  335. return data_loader
  336. def save(self, epoch_id: int = 0, save_path: str = None) -> None:
  337. """
  338. Save the model and the optimizer.
  339. Parameters
  340. ----------
  341. epoch_id : int
  342. The epoch id.
  343. save_path : str, optional
  344. The path to save the model, by default None.
  345. """
  346. if self.save_dir is None and save_path is None:
  347. raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.")
  348. if save_path is not None:
  349. if not os.path.exists(os.path.dirname(save_path)):
  350. os.makedirs(os.path.dirname(save_path))
  351. else:
  352. save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth")
  353. if not os.path.exists(self.save_dir):
  354. os.makedirs(self.save_dir)
  355. print_log(f"Checkpoints will be saved to {save_path}", logger="current")
  356. save_parma_dic = {
  357. "model": self.model.state_dict(),
  358. "optimizer": self.optimizer.state_dict(),
  359. }
  360. torch.save(save_parma_dic, save_path)
  361. def load(self, load_path: str = "") -> None:
  362. """
  363. Load the model and the optimizer.
  364. Parameters
  365. ----------
  366. load_path : str
  367. The directory to load the model, by default "".
  368. """
  369. if load_path is None:
  370. raise ValueError("Load path should not be None.")
  371. print_log(
  372. f"Loads checkpoint by local backend from path: {load_path}",
  373. logger="current",
  374. )
  375. param_dic = torch.load(load_path)
  376. self.model.load_state_dict(param_dic["model"])
  377. if "optimizer" in param_dic.keys():
  378. self.optimizer.load_state_dict(param_dic["optimizer"])

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.