You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

basic_nn.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. import logging
  2. import os
  3. from typing import Any, Callable, List, Optional, T, Tuple
  4. import numpy
  5. import torch
  6. from torch.utils.data import DataLoader
  7. from ..dataset import ClassificationDataset, PredictionDataset
  8. from ..utils.logger import print_log
  class BasicNN:
      """
      Wrap NN models into the form of an sklearn estimator.

      Parameters
      ----------
      model : torch.nn.Module
          The PyTorch model to be trained or used for prediction.
      criterion : torch.nn.Module
          The loss function used for training and scoring.
      optimizer : torch.optim.Optimizer
          The optimizer used for training.
      device : torch.device, optional
          The device on which the model will be trained or used for prediction,
          by default torch.device("cpu").
      batch_size : int, optional
          The batch size used for data loaders built from raw ``X``, by default 32.
      num_epochs : int, optional
          The number of epochs used for training, by default 1.
      stop_loss : Optional[float], optional
          Training stops early once an epoch's mean loss falls below this value,
          by default 0.01. ``None`` disables early stopping.
      num_workers : int, optional
          The number of workers used for loading data, by default 0.
      save_interval : Optional[int], optional
          Save a checkpoint every ``save_interval`` epochs during training,
          by default None (no periodic saving). Requires ``save_dir``.
      save_dir : Optional[str], optional
          The directory in which periodic checkpoints are saved, by default None.
      train_transform : Callable[..., Any], optional
          A function/transform applied to each sample when ``fit`` builds a data
          loader from raw ``X``, by default None.
      test_transform : Callable[..., Any], optional
          A function/transform applied to each sample when ``predict``,
          ``predict_proba`` or ``score`` build a data loader from raw ``X``,
          by default None. If only ``train_transform`` is given, it is reused here.
      collate_fn : Callable[[List[T]], Any], optional
          The function used to collate data, by default None.
      """

      def __init__(
          self,
          model: torch.nn.Module,
          criterion: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          device: torch.device = torch.device("cpu"),
          batch_size: int = 32,
          num_epochs: int = 1,
          stop_loss: Optional[float] = 0.01,
          num_workers: int = 0,
          save_interval: Optional[int] = None,
          save_dir: Optional[str] = None,
          train_transform: Optional[Callable[..., Any]] = None,
          test_transform: Optional[Callable[..., Any]] = None,
          collate_fn: Optional[Callable[[List[T]], Any]] = None,
      ) -> None:
          self.model = model.to(device)
          self.criterion = criterion
          self.optimizer = optimizer
          self.device = device
          self.batch_size = batch_size
          self.num_epochs = num_epochs
          self.stop_loss = stop_loss
          self.num_workers = num_workers
          self.save_interval = save_interval
          self.save_dir = save_dir
          self.train_transform = train_transform
          self.test_transform = test_transform
          self.collate_fn = collate_fn

          # Fall back to the training transform at prediction time so inputs are
          # preprocessed the same way in both phases.
          if self.train_transform is not None and self.test_transform is None:
              print_log(
                  "Transform used in the training phase will be used in prediction.",
                  "current",
                  level=logging.WARNING,
              )
              self.test_transform = self.train_transform

      def _fit(self, data_loader: DataLoader) -> float:
          """
          Internal method to fit the model on data for n epochs, with early stopping.

          Parameters
          ----------
          data_loader : DataLoader
              Data loader providing (data, target) training batches.

          Returns
          -------
          float
              The mean loss of the last epoch that was run.

          Raises
          ------
          ValueError
              If a periodic save is due but ``save_dir`` is None.
          """
          loss_value = 1e9
          for epoch in range(self.num_epochs):
              loss_value = self.train_epoch(data_loader)
              if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
                  if self.save_dir is None:
                      raise ValueError("save_dir should not be None if save_interval is not None.")
                  self.save(epoch + 1)
              if self.stop_loss is not None and loss_value < self.stop_loss:
                  break
          return loss_value

      def fit(
          self,
          data_loader: Optional[DataLoader] = None,
          X: Optional[List[Any]] = None,
          y: Optional[List[int]] = None,
      ) -> float:
          """
          Train the model.

          Parameters
          ----------
          data_loader : DataLoader, optional
              The data loader used for training, by default None. Takes precedence
              over ``X``/``y`` when given.
          X : List[Any], optional
              The input data, by default None.
          y : List[int], optional
              The target data, by default None (dummy zero labels are used).

          Returns
          -------
          float
              The loss value of the trained model.

          Raises
          ------
          ValueError
              If both ``data_loader`` and ``X`` are None.
          """
          if data_loader is None:
              if X is None:
                  raise ValueError("data_loader and X can not be None simultaneously.")
              data_loader = self._data_loader(X, y)
          return self._fit(data_loader)

      def train_epoch(self, data_loader: DataLoader) -> float:
          """
          Train the model for one epoch.

          Parameters
          ----------
          data_loader : DataLoader
              The data loader used for training; yields (data, target) batches.

          Returns
          -------
          float
              The sample-weighted mean loss over the epoch.

          Raises
          ------
          ValueError
              If ``data_loader`` yields no samples.
          """
          model = self.model
          criterion = self.criterion
          optimizer = self.optimizer
          device = self.device

          model.train()
          total_loss, total_num = 0.0, 0
          for data, target in data_loader:
              data, target = data.to(device), target.to(device)
              out = model(data)
              loss = criterion(out, target)
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()

              # Weight each batch's loss by its size so a short final batch does not
              # skew the epoch mean (assumes criterion averages over the batch).
              batch_num = data.size(0)
              total_loss += loss.item() * batch_num
              total_num += batch_num

          if total_num == 0:
              raise ValueError("data_loader yielded no samples.")
          return total_loss / total_num

      def _predict(self, data_loader: DataLoader) -> torch.Tensor:
          """
          Internal method to predict the outputs given a DataLoader.

          Parameters
          ----------
          data_loader : DataLoader
              The DataLoader providing input samples (no targets).

          Returns
          -------
          torch.Tensor
              Raw model outputs for all samples, concatenated along dim 0.

          Raises
          ------
          ValueError
              If ``data_loader`` yields no samples.
          """
          model = self.model
          device = self.device

          model.eval()
          results = []
          with torch.no_grad():
              for data in data_loader:
                  results.append(model(data.to(device)))

          # torch.cat on an empty list fails with an obscure error; be explicit.
          if not results:
              raise ValueError("data_loader yielded no samples.")
          return torch.cat(results, dim=0)

      def _prediction_loader(
          self, data_loader: Optional[DataLoader], X: Optional[List[Any]]
      ) -> DataLoader:
          """
          Return ``data_loader`` if given, otherwise build an unshuffled loader over ``X``
          that applies ``test_transform``.

          Raises
          ------
          ValueError
              If both ``data_loader`` and ``X`` are None.
          """
          if data_loader is not None:
              return data_loader
          if X is None:
              raise ValueError("data_loader and X can not be None simultaneously.")
          dataset = PredictionDataset(X, self.test_transform)
          return DataLoader(
              dataset,
              batch_size=self.batch_size,
              num_workers=int(self.num_workers),
              collate_fn=self.collate_fn,
          )

      def predict(
          self, data_loader: Optional[DataLoader] = None, X: Optional[List[Any]] = None
      ) -> numpy.ndarray:
          """
          Predict the class of the input data.

          Parameters
          ----------
          data_loader : DataLoader, optional
              The data loader used for prediction, by default None.
          X : List[Any], optional
              The input data, by default None.

          Returns
          -------
          numpy.ndarray
              The predicted class (argmax over dim 1 of the model output) per sample.

          Raises
          ------
          ValueError
              If both ``data_loader`` and ``X`` are None.
          """
          data_loader = self._prediction_loader(data_loader, X)
          return self._predict(data_loader).argmax(dim=1).cpu().numpy()

      def predict_proba(
          self, data_loader: Optional[DataLoader] = None, X: Optional[List[Any]] = None
      ) -> numpy.ndarray:
          """
          Predict the probability of each class for the input data.

          Parameters
          ----------
          data_loader : DataLoader, optional
              The data loader used for prediction, by default None.
          X : List[Any], optional
              The input data, by default None.

          Returns
          -------
          numpy.ndarray
              Softmax (over dim 1) of the model output for each sample.

          Raises
          ------
          ValueError
              If both ``data_loader`` and ``X`` are None.
          """
          data_loader = self._prediction_loader(data_loader, X)
          return self._predict(data_loader).softmax(dim=1).cpu().numpy()

      def _score(self, data_loader: DataLoader) -> Tuple[float, float]:
          """
          Internal method to compute loss and accuracy for the data provided through a DataLoader.

          Parameters
          ----------
          data_loader : DataLoader
              Data loader yielding (data, target) batches.

          Returns
          -------
          Tuple[float, float]
              mean_loss: float, The sample-weighted mean loss on the provided data.
              accuracy: float, The accuracy of the model on the provided data.

          Raises
          ------
          ValueError
              If ``data_loader`` yields no samples.
          """
          model = self.model
          criterion = self.criterion
          device = self.device

          model.eval()
          total_correct_num, total_num, total_loss = 0, 0, 0.0
          with torch.no_grad():
              for data, target in data_loader:
                  data, target = data.to(device), target.to(device)
                  out = model(data)
                  if out.dim() > 1:
                      # Multi-class output: the predicted class is the largest logit.
                      pred = out.argmax(dim=1)
                  else:
                      # 1-D output: treated as a probability and thresholded at 0.5.
                      pred = out > 0.5
                  batch_num = data.size(0)
                  total_correct_num += (target == pred).sum().item()
                  total_loss += criterion(out, target).item() * batch_num
                  total_num += batch_num

          if total_num == 0:
              raise ValueError("data_loader yielded no samples.")
          return total_loss / total_num, total_correct_num / total_num

      def score(
          self,
          data_loader: Optional[DataLoader] = None,
          X: Optional[List[Any]] = None,
          y: Optional[List[int]] = None,
      ) -> float:
          """
          Validate the model.

          When ``data_loader`` is not given, a loader is built from ``X``/``y`` using
          ``test_transform`` and without shuffling.

          Parameters
          ----------
          data_loader : DataLoader, optional
              The data loader used for scoring, by default None.
          X : List[Any], optional
              The input data, by default None.
          y : List[int], optional
              The target data, by default None.

          Returns
          -------
          float
              The accuracy of the model.

          Raises
          ------
          ValueError
              If both ``data_loader`` and ``X`` are None, or ``X``/``y`` lengths differ.
          """
          print_log("Start machine learning model validation", logger="current")
          if data_loader is None:
              if X is None:
                  raise ValueError("data_loader and X can not be None simultaneously.")
              # Evaluation uses the test-time transform; order does not matter.
              data_loader = self._data_loader(X, y, shuffle=False, train=False)
          mean_loss, accuracy = self._score(data_loader)
          print_log(f"mean loss: {mean_loss:.3f}, accuracy: {accuracy:.3f}", logger="current")
          return accuracy

      def _data_loader(
          self,
          X: List[Any],
          y: Optional[List[int]] = None,
          shuffle: bool = True,
          train: bool = True,
      ) -> DataLoader:
          """
          Generate a DataLoader for user-provided input and target data.

          Parameters
          ----------
          X : List[Any]
              Input samples.
          y : List[int], optional
              Target labels. If None, dummy zero labels are created, by default None.
          shuffle : bool, optional
              Whether to shuffle the data, by default True.
          train : bool, optional
              Apply ``train_transform`` if True, otherwise ``test_transform``,
              by default True.

          Returns
          -------
          DataLoader
              A DataLoader providing batches of (X, y) pairs.

          Raises
          ------
          ValueError
              If ``X`` is None or ``X`` and ``y`` have different lengths.
          """
          if X is None:
              raise ValueError("X should not be None.")
          if y is None:
              y = [0] * len(X)
          if len(y) != len(X):
              raise ValueError("X and y should have equal length.")

          transform = self.train_transform if train else self.test_transform
          dataset = ClassificationDataset(X, y, transform=transform)
          return DataLoader(
              dataset,
              batch_size=self.batch_size,
              shuffle=shuffle,
              num_workers=int(self.num_workers),
              collate_fn=self.collate_fn,
          )

      def save(self, epoch_id: int = 0, save_path: Optional[str] = None) -> None:
          """
          Save the model and the optimizer state dicts.

          Parameters
          ----------
          epoch_id : int
              The epoch id, used in the file name when ``save_path`` is None.
          save_path : str, optional
              The file path to save to, by default None, in which case
              ``<save_dir>/model_checkpoint_epoch_<epoch_id>.pth`` is used.

          Raises
          ------
          ValueError
              If both ``save_dir`` and ``save_path`` are None.
          """
          if self.save_dir is None and save_path is None:
              raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.")
          if save_path is None:
              save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth")

          # A bare file name has an empty dirname; os.makedirs("") would raise, so only
          # create a directory when there is one. exist_ok avoids a check-then-create race.
          save_dirname = os.path.dirname(save_path)
          if save_dirname:
              os.makedirs(save_dirname, exist_ok=True)

          print_log(f"Checkpoints will be saved to {save_path}", logger="current")
          checkpoint = {
              "model": self.model.state_dict(),
              "optimizer": self.optimizer.state_dict(),
          }
          torch.save(checkpoint, save_path)

      def load(self, load_path: str = "") -> None:
          """
          Load the model and (if present) the optimizer state from a checkpoint file.

          Parameters
          ----------
          load_path : str
              The checkpoint file path, by default "".

          Raises
          ------
          ValueError
              If ``load_path`` is None.
          """
          if load_path is None:
              raise ValueError("Load path should not be None.")
          print_log(
              f"Loads checkpoint by local backend from path: {load_path}",
              logger="current",
          )
          # NOTE: torch.load unpickles the file -- only load checkpoints from trusted
          # sources. map_location lets a checkpoint saved on another device (e.g. a GPU)
          # be loaded onto this model's device.
          param_dic = torch.load(load_path, map_location=self.device)
          self.model.load_state_dict(param_dic["model"])
          if "optimizer" in param_dic:
              self.optimizer.load_state_dict(param_dic["optimizer"])

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.