You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

basic_nn.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :basic_model.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/11/21
  9. # Description :
  10. #
  11. # ================================================================#
  12. import torch
  13. import numpy
  14. from torch.utils.data import DataLoader
  15. from ..utils.logger import print_log
  16. from ..dataset import ClassificationDataset
  17. import os
  18. from typing import List, Any, T, Optional, Callable, Tuple
  class BasicNN:
      """
      Wrap NN models into the form of an sklearn estimator.

      Parameters
      ----------
      model : torch.nn.Module
          The PyTorch model to be trained or used for prediction. It is moved
          to ``device`` on construction.
      criterion : torch.nn.Module
          The loss function used for training.
      optimizer : torch.optim.Optimizer
          The optimizer used for training.
      device : torch.device, optional
          The device on which the model will be trained or used for prediction,
          by default torch.device("cpu").
      batch_size : int, optional
          The batch size used when this class builds its own DataLoader,
          by default 32.
      num_epochs : int, optional
          The maximum number of epochs used for training, by default 1.
      stop_loss : Optional[float], optional
          Training stops early once an epoch's loss drops below this value.
          ``None`` disables early stopping. By default 0.01.
      num_workers : int
          The number of workers used for loading data, by default 0.
      save_interval : Optional[int], optional
          The interval (in epochs) at which to save the model during training,
          by default None (never save).
      save_dir : Optional[str], optional
          The directory in which to save the model during training,
          by default None.
      transform : Callable[..., Any], optional
          A function/transform forwarded to ``ClassificationDataset``,
          by default None.
      collate_fn : Callable[[List[Any]], Any], optional
          The function used to collate data, forwarded to the DataLoader,
          by default None.
      """

      def __init__(
          self,
          model: torch.nn.Module,
          criterion: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          device: torch.device = torch.device("cpu"),
          batch_size: int = 32,
          num_epochs: int = 1,
          stop_loss: Optional[float] = 0.01,
          num_workers: int = 0,
          save_interval: Optional[int] = None,
          save_dir: Optional[str] = None,
          transform: Optional[Callable[..., Any]] = None,
          collate_fn: Optional[Callable[[List[Any]], Any]] = None,
      ) -> None:
          self.model = model.to(device)
          self.criterion = criterion
          self.optimizer = optimizer
          self.device = device
          self.batch_size = batch_size
          self.num_epochs = num_epochs
          self.stop_loss = stop_loss
          self.num_workers = num_workers
          self.save_interval = save_interval
          self.save_dir = save_dir
          self.transform = transform
          self.collate_fn = collate_fn

      def _fit(self, data_loader: DataLoader) -> float:
          """
          Internal method to fit the model on data for n epochs, with early stopping.

          Parameters
          ----------
          data_loader : DataLoader
              Data loader providing training samples.

          Returns
          -------
          float
              The loss value of the last epoch that was run (1e9 if
              ``num_epochs`` is 0 and no epoch ran).

          Raises
          ------
          ValueError
              If a checkpoint is due (``save_interval`` reached) but
              ``save_dir`` is None.
          """
          loss_value = 1e9
          for epoch in range(self.num_epochs):
              loss_value = self.train_epoch(data_loader)

              # Periodic checkpointing, counted in 1-based epochs.
              if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
                  if self.save_dir is None:
                      raise ValueError(
                          "save_dir should not be None if save_interval is not None."
                      )
                  self.save(epoch + 1)

              # Early stopping once the epoch loss is below the threshold.
              if self.stop_loss is not None and loss_value < self.stop_loss:
                  break
          return loss_value

      def fit(
          self,
          data_loader: Optional[DataLoader] = None,
          X: Optional[List[Any]] = None,
          y: Optional[List[int]] = None,
      ) -> float:
          """
          Train the model.

          Parameters
          ----------
          data_loader : DataLoader, optional
              The data loader used for training, by default None. When given,
              ``X`` and ``y`` are ignored.
          X : List[Any], optional
              The input data, by default None.
          y : List[int], optional
              The target data, by default None.

          Returns
          -------
          float
              The loss value of the trained model.

          Raises
          ------
          ValueError
              If both ``data_loader`` and ``X`` are None.
          """
          if data_loader is None:
              if X is None:
                  raise ValueError("data_loader and X can not be None simultaneously.")
              # Training keeps the default shuffling behaviour.
              data_loader = self._data_loader(X, y)
          return self._fit(data_loader)

      def train_epoch(self, data_loader: DataLoader) -> float:
          """
          Train the model for one epoch.

          Parameters
          ----------
          data_loader : DataLoader
              The data loader used for training. It must yield
              ``(data, target)`` pairs of tensors.

          Returns
          -------
          float
              The sample-weighted mean loss over the epoch.

          Raises
          ------
          ZeroDivisionError
              If the data loader yields no samples.
          """
          model = self.model
          criterion = self.criterion
          optimizer = self.optimizer
          device = self.device

          model.train()
          total_loss, total_num = 0.0, 0
          for data, target in data_loader:
              data, target = data.to(device), target.to(device)
              out = model(data)
              loss = criterion(out, target)

              optimizer.zero_grad()
              loss.backward()
              optimizer.step()

              # Weight each batch loss by its size so the last (possibly
              # smaller) batch does not skew the epoch mean.
              batch_size = data.size(0)
              total_loss += loss.item() * batch_size
              total_num += batch_size
          return total_loss / total_num

      def _predict(self, data_loader: DataLoader) -> torch.Tensor:
          """
          Internal method to predict the outputs given a DataLoader.

          Parameters
          ----------
          data_loader : DataLoader
              The DataLoader providing input samples as ``(data, _)`` pairs;
              the second element is ignored.

          Returns
          -------
          torch.Tensor
              Raw output from the model, concatenated along dimension 0 in the
              order the loader yields batches.
          """
          model = self.model
          device = self.device

          model.eval()
          results = []
          with torch.no_grad():
              for data, _ in data_loader:
                  results.append(model(data.to(device)))
          return torch.cat(results, dim=0)

      def predict(
          self,
          data_loader: Optional[DataLoader] = None,
          X: Optional[List[Any]] = None,
      ) -> numpy.ndarray:
          """
          Predict the class of the input data.

          Parameters
          ----------
          data_loader : DataLoader, optional
              The data loader used for prediction, by default None.
          X : List[Any], optional
              The input data, by default None. Only used when ``data_loader``
              is None.

          Returns
          -------
          numpy.ndarray
              The predicted class of the input data. When built from ``X``,
              the i-th prediction corresponds to ``X[i]``.

          Raises
          ------
          ValueError
              If both ``data_loader`` and ``X`` are None.
          """
          if data_loader is None:
              # Never shuffle at inference time: callers align outputs with X.
              data_loader = self._data_loader(X, shuffle=False)
          return self._predict(data_loader).argmax(dim=1).cpu().numpy()

      def predict_proba(
          self,
          data_loader: Optional[DataLoader] = None,
          X: Optional[List[Any]] = None,
      ) -> numpy.ndarray:
          """
          Predict the probability of each class for the input data.

          Parameters
          ----------
          data_loader : DataLoader, optional
              The data loader used for prediction, by default None.
          X : List[Any], optional
              The input data, by default None. Only used when ``data_loader``
              is None.

          Returns
          -------
          numpy.ndarray
              Softmax of the model output over dimension 1. When built from
              ``X``, the i-th row corresponds to ``X[i]``.

          Raises
          ------
          ValueError
              If both ``data_loader`` and ``X`` are None.
          """
          if data_loader is None:
              # Never shuffle at inference time: callers align outputs with X.
              data_loader = self._data_loader(X, shuffle=False)
          return self._predict(data_loader).softmax(dim=1).cpu().numpy()

      def _score(self, data_loader: DataLoader) -> Tuple[float, float]:
          """
          Internal method to compute loss and accuracy for the data provided through a DataLoader.

          Parameters
          ----------
          data_loader : DataLoader
              Data loader to use for evaluation.

          Returns
          -------
          Tuple[float, float]
              mean_loss: float, The mean loss of the model on the provided data.
              accuracy: float, The accuracy of the model on the provided data.

          Raises
          ------
          ZeroDivisionError
              If the data loader yields no samples.
          """
          model = self.model
          criterion = self.criterion
          device = self.device

          model.eval()
          total_correct_num, total_num, total_loss = 0, 0, 0.0
          with torch.no_grad():
              for data, target in data_loader:
                  data, target = data.to(device), target.to(device)
                  out = model(data)

                  if out.dim() > 1:
                      # Multi-dimensional output: predicted class is the argmax
                      # over dimension 1.
                      correct_num = (target == out.argmax(dim=1)).sum().item()
                  else:
                      # One-dimensional output: threshold each score at 0.5.
                      correct_num = (target == (out > 0.5)).sum().item()

                  loss = criterion(out, target)
                  batch_size = data.size(0)
                  total_loss += loss.item() * batch_size
                  total_correct_num += correct_num
                  total_num += batch_size

          mean_loss = total_loss / total_num
          accuracy = total_correct_num / total_num
          return mean_loss, accuracy

      def score(
          self,
          data_loader: Optional[DataLoader] = None,
          X: Optional[List[Any]] = None,
          y: Optional[List[int]] = None,
      ) -> float:
          """
          Validate the model.

          Parameters
          ----------
          data_loader : DataLoader, optional
              The data loader used for scoring, by default None.
          X : List[Any], optional
              The input data, by default None.
          y : List[int], optional
              The target data, by default None.

          Returns
          -------
          float
              The accuracy of the model.

          Raises
          ------
          ValueError
              If both ``data_loader`` and ``X`` are None.
          """
          print_log("Start machine learning model validation", logger="current")
          if data_loader is None:
              # Evaluation does not benefit from shuffling; keep input order.
              data_loader = self._data_loader(X, y, shuffle=False)
          mean_loss, accuracy = self._score(data_loader)
          print_log(
              f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current"
          )
          return accuracy

      def _data_loader(
          self,
          X: Optional[List[Any]],
          y: Optional[List[int]] = None,
          shuffle: bool = True,
      ) -> DataLoader:
          """
          Generate a DataLoader for user-provided input and target data.

          Parameters
          ----------
          X : List[Any]
              Input samples.
          y : List[int], optional
              Target labels. If None, dummy labels (all 0) are created,
              by default None.
          shuffle : bool, optional
              Whether the loader reshuffles the data, by default True.
              Pass False when outputs must stay aligned with ``X``.

          Returns
          -------
          DataLoader
              A DataLoader providing batches of (X, y) pairs.

          Raises
          ------
          ValueError
              If ``X`` is None, or if ``X`` and ``y`` differ in length.
          """
          if X is None:
              raise ValueError("X should not be None.")
          if y is None:
              # Placeholder labels so the dataset can always yield pairs.
              y = [0] * len(X)
          if len(y) != len(X):
              raise ValueError("X and y should have equal length.")

          dataset = ClassificationDataset(X, y, transform=self.transform)
          return DataLoader(
              dataset,
              batch_size=self.batch_size,
              shuffle=shuffle,
              num_workers=int(self.num_workers),
              collate_fn=self.collate_fn,
          )

      def save(self, epoch_id: int = 0, save_path: Optional[str] = None) -> None:
          """
          Save the model and the optimizer.

          Parameters
          ----------
          epoch_id : int
              The epoch id, used in the default checkpoint file name.
          save_path : str, optional
              The path to save the model, by default None. If None, the
              checkpoint is written to
              ``<save_dir>/model_checkpoint_epoch_<epoch_id>.pth``.

          Raises
          ------
          ValueError
              If both ``save_dir`` and ``save_path`` are None.
          """
          if self.save_dir is None and save_path is None:
              raise ValueError(
                  "'save_dir' and 'save_path' should not be None simultaneously."
              )

          if save_path is None:
              # Only touch save_dir when the path is derived from it; an explicit
              # save_path must work even when save_dir is None.
              os.makedirs(self.save_dir, exist_ok=True)
              save_path = os.path.join(
                  self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth"
              )

          print_log(f"Checkpoints will be saved to {save_path}", logger="current")
          checkpoint = {
              "model": self.model.state_dict(),
              "optimizer": self.optimizer.state_dict(),
          }
          torch.save(checkpoint, save_path)

      def load(self, load_path: str = "") -> None:
          """
          Load the model and the optimizer.

          Parameters
          ----------
          load_path : str
              The path of the checkpoint file to load, by default "".

          Raises
          ------
          ValueError
              If ``load_path`` is None or empty.
          """
          # The default value is "", so an emptiness check is required in
          # addition to the None check.
          if not load_path:
              raise ValueError("Load path should not be None.")

          print_log(
              f"Loads checkpoint by local backend from path: {load_path}",
              logger="current",
          )
          # NOTE: torch.load unpickles data; only load checkpoints from trusted
          # sources. map_location lets a checkpoint saved on another device be
          # restored onto this estimator's device.
          param_dic = torch.load(load_path, map_location=self.device)
          self.model.load_state_dict(param_dic["model"])
          if "optimizer" in param_dic:
              self.optimizer.load_state_dict(param_dic["optimizer"])

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.