You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

basic_nn.py 13 kB

(line-number gutter from the code viewer — lines 1 to 409 of basic_nn.py)
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :basic_model.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/11/21
  9. # Description :
  10. #
  11. # ================================================================#
import os
from typing import Any, Callable, List, Optional, Tuple, TypeVar

import numpy
import torch
from torch.utils.data import DataLoader

from ..dataset import ClassificationDataset
from ..utils.logger import print_log

# ``typing.T`` was an undocumented internal alias and was removed in Python 3.9;
# declare the TypeVar explicitly so annotations such as Callable[[List[T]], Any] keep working.
T = TypeVar("T")
  19. class BasicNN:
  20. """
  21. Wrap NN models into the form of an sklearn estimator.
  22. Parameters
  23. ----------
  24. model : torch.nn.Module
  25. The PyTorch model to be trained or used for prediction.
  26. criterion : torch.nn.Module
  27. The loss function used for training.
  28. optimizer : torch.optim.Optimizer
  29. The optimizer used for training.
  30. device : torch.device, optional
  31. The device on which the model will be trained or used for prediction, by default torch.device("cpu").
  32. batch_size : int, optional
  33. The batch size used for training, by default 32.
  34. num_epochs : int, optional
  35. The number of epochs used for training, by default 1.
  36. stop_loss : Optional[float], optional
  37. The loss value at which to stop training, by default 0.01.
  38. num_workers : int
  39. The number of workers used for loading data, by default 0.
  40. save_interval : Optional[int], optional
  41. The interval at which to save the model during training, by default None.
  42. save_dir : Optional[str], optional
  43. The directory in which to save the model during training, by default None.
  44. transform : Callable[..., Any], optional
  45. A function/transform that takes in an object and returns a transformed version, by default None.
  46. collate_fn : Callable[[List[T]], Any], optional
  47. The function used to collate data, by default None.
  48. """
  49. def __init__(
  50. self,
  51. model: torch.nn.Module,
  52. criterion: torch.nn.Module,
  53. optimizer: torch.optim.Optimizer,
  54. device: torch.device = torch.device("cpu"),
  55. batch_size: int = 32,
  56. num_epochs: int = 1,
  57. stop_loss: Optional[float] = 0.01,
  58. num_workers: int = 0,
  59. save_interval: Optional[int] = None,
  60. save_dir: Optional[str] = None,
  61. transform: Callable[..., Any] = None,
  62. collate_fn: Callable[[List[T]], Any] = None,
  63. ) -> None:
  64. self.model = model.to(device)
  65. self.criterion = criterion
  66. self.optimizer = optimizer
  67. self.device = device
  68. self.batch_size = batch_size
  69. self.num_epochs = num_epochs
  70. self.stop_loss = stop_loss
  71. self.num_workers = num_workers
  72. self.save_interval = save_interval
  73. self.save_dir = save_dir
  74. self.transform = transform
  75. self.collate_fn = collate_fn
  76. def _fit(self, data_loader) -> float:
  77. """
  78. Internal method to fit the model on data for n epochs, with early stopping.
  79. Parameters
  80. ----------
  81. data_loader : DataLoader
  82. Data loader providing training samples.
  83. Returns
  84. -------
  85. float
  86. The loss value of the trained model.
  87. """
  88. loss_value = 1e9
  89. for epoch in range(self.num_epochs):
  90. loss_value = self.train_epoch(data_loader)
  91. if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
  92. if self.save_dir is None:
  93. raise ValueError("save_dir should not be None if save_interval is not None.")
  94. self.save(epoch + 1)
  95. if self.stop_loss is not None and loss_value < self.stop_loss:
  96. break
  97. return loss_value
  98. def fit(
  99. self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
  100. ) -> float:
  101. """
  102. Train the model.
  103. Parameters
  104. ----------
  105. data_loader : DataLoader, optional
  106. The data loader used for training, by default None.
  107. X : List[Any], optional
  108. The input data, by default None.
  109. y : List[int], optional
  110. The target data, by default None.
  111. Returns
  112. -------
  113. float
  114. The loss value of the trained model.
  115. """
  116. if data_loader is None:
  117. if X is None:
  118. raise ValueError("data_loader and X can not be None simultaneously.")
  119. else:
  120. data_loader = self._data_loader(X, y)
  121. return self._fit(data_loader)
  122. def train_epoch(self, data_loader: DataLoader) -> float:
  123. """
  124. Train the model for one epoch.
  125. Parameters
  126. ----------
  127. data_loader : DataLoader
  128. The data loader used for training.
  129. Returns
  130. -------
  131. float
  132. The loss value of the trained model.
  133. """
  134. model = self.model
  135. criterion = self.criterion
  136. optimizer = self.optimizer
  137. device = self.device
  138. model.train()
  139. total_loss, total_num = 0.0, 0
  140. for data, target in data_loader:
  141. data, target = data.to(device), target.to(device)
  142. out = model(data)
  143. loss = criterion(out, target)
  144. optimizer.zero_grad()
  145. loss.backward()
  146. optimizer.step()
  147. total_loss += loss.item() * data.size(0)
  148. total_num += data.size(0)
  149. return total_loss / total_num
  150. def _predict(self, data_loader) -> torch.Tensor:
  151. """
  152. Internal method to predict the outputs given a DataLoader.
  153. Parameters
  154. ----------
  155. data_loader : DataLoader
  156. The DataLoader providing input samples.
  157. Returns
  158. -------
  159. torch.Tensor
  160. Raw output from the model.
  161. """
  162. model = self.model
  163. device = self.device
  164. model.eval()
  165. with torch.no_grad():
  166. results = []
  167. for data in data_loader:
  168. data = data.to(device)
  169. out = model(data)
  170. results.append(out)
  171. return torch.cat(results, axis=0)
  172. def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
  173. """
  174. Predict the class of the input data.
  175. Parameters
  176. ----------
  177. data_loader : DataLoader, optional
  178. The data loader used for prediction, by default None.
  179. X : List[Any], optional
  180. The input data, by default None.
  181. Returns
  182. -------
  183. numpy.ndarray
  184. The predicted class of the input data.
  185. """
  186. if data_loader is None:
  187. if self.transform is not None:
  188. X = [self.transform(x) for x in X]
  189. data_loader = DataLoader(X, batch_size=self.batch_size)
  190. return self._predict(data_loader).argmax(axis=1).cpu().numpy()
  191. def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
  192. """
  193. Predict the probability of each class for the input data.
  194. Parameters
  195. ----------
  196. data_loader : DataLoader, optional
  197. The data loader used for prediction, by default None.
  198. X : List[Any], optional
  199. The input data, by default None.
  200. Returns
  201. -------
  202. numpy.ndarray
  203. The predicted probability of each class for the input data.
  204. """
  205. if data_loader is None:
  206. if self.transform is not None:
  207. X = [self.transform(x) for x in X]
  208. data_loader = DataLoader(X, batch_size=self.batch_size)
  209. return self._predict(data_loader).softmax(axis=1).cpu().numpy()
  210. def _score(self, data_loader) -> Tuple[float, float]:
  211. """
  212. Internal method to compute loss and accuracy for the data provided through a DataLoader.
  213. Parameters
  214. ----------
  215. data_loader : DataLoader
  216. Data loader to use for evaluation.
  217. Returns
  218. -------
  219. Tuple[float, float]
  220. mean_loss: float, The mean loss of the model on the provided data.
  221. accuracy: float, The accuracy of the model on the provided data.
  222. """
  223. model = self.model
  224. criterion = self.criterion
  225. device = self.device
  226. model.eval()
  227. total_correct_num, total_num, total_loss = 0, 0, 0.0
  228. with torch.no_grad():
  229. for data, target in data_loader:
  230. data, target = data.to(device), target.to(device)
  231. out = model(data)
  232. if len(out.shape) > 1:
  233. correct_num = (target == out.argmax(axis=1)).sum().item()
  234. else:
  235. correct_num = (target == (out > 0.5)).sum().item()
  236. loss = criterion(out, target)
  237. total_loss += loss.item() * data.size(0)
  238. total_correct_num += correct_num
  239. total_num += data.size(0)
  240. mean_loss = total_loss / total_num
  241. accuracy = total_correct_num / total_num
  242. return mean_loss, accuracy
  243. def score(
  244. self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
  245. ) -> float:
  246. """
  247. Validate the model.
  248. Parameters
  249. ----------
  250. data_loader : DataLoader, optional
  251. The data loader used for scoring, by default None.
  252. X : List[Any], optional
  253. The input data, by default None.
  254. y : List[int], optional
  255. The target data, by default None.
  256. Returns
  257. -------
  258. float
  259. The accuracy of the model.
  260. """
  261. print_log("Start machine learning model validation", logger="current")
  262. if data_loader is None:
  263. data_loader = self._data_loader(X, y)
  264. mean_loss, accuracy = self._score(data_loader)
  265. print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current")
  266. return accuracy
  267. def _data_loader(
  268. self,
  269. X: List[Any],
  270. y: List[int] = None,
  271. shuffle: bool = True,
  272. ) -> DataLoader:
  273. """
  274. Generate a DataLoader for user-provided input and target data.
  275. Parameters
  276. ----------
  277. X : List[Any]
  278. Input samples.
  279. y : List[int], optional
  280. Target labels. If None, dummy labels are created, by default None.
  281. Returns
  282. -------
  283. DataLoader
  284. A DataLoader providing batches of (X, y) pairs.
  285. """
  286. if X is None:
  287. raise ValueError("X should not be None.")
  288. if y is None:
  289. y = [0] * len(X)
  290. if not (len(y) == len(X)):
  291. raise ValueError("X and y should have equal length.")
  292. dataset = ClassificationDataset(X, y, transform=self.transform)
  293. data_loader = DataLoader(
  294. dataset,
  295. batch_size=self.batch_size,
  296. shuffle=shuffle,
  297. num_workers=int(self.num_workers),
  298. collate_fn=self.collate_fn,
  299. )
  300. return data_loader
  301. def save(self, epoch_id: int = 0, save_path: str = None) -> None:
  302. """
  303. Save the model and the optimizer.
  304. Parameters
  305. ----------
  306. epoch_id : int
  307. The epoch id.
  308. save_path : str, optional
  309. The path to save the model, by default None.
  310. """
  311. if self.save_dir is None and save_path is None:
  312. raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.")
  313. if save_path is not None:
  314. if not os.path.exists(os.path.dirname(save_path)):
  315. os.makedirs(os.path.dirname(save_path))
  316. else:
  317. save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth")
  318. if not os.path.exists(self.save_dir):
  319. os.makedirs(self.save_dir)
  320. print_log(f"Checkpoints will be saved to {save_path}", logger="current")
  321. save_parma_dic = {
  322. "model": self.model.state_dict(),
  323. "optimizer": self.optimizer.state_dict(),
  324. }
  325. torch.save(save_parma_dic, save_path)
  326. def load(self, load_path: str = "") -> None:
  327. """
  328. Load the model and the optimizer.
  329. Parameters
  330. ----------
  331. load_path : str
  332. The directory to load the model, by default "".
  333. """
  334. if load_path is None:
  335. raise ValueError("Load path should not be None.")
  336. print_log(
  337. f"Loads checkpoint by local backend from path: {load_path}",
  338. logger="current",
  339. )
  340. param_dic = torch.load(load_path)
  341. self.model.load_state_dict(param_dic["model"])
  342. if "optimizer" in param_dic.keys():
  343. self.optimizer.load_state_dict(param_dic["optimizer"])

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.