You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

basic_nn.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :basic_model.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/11/21
  9. # Description :
  10. #
  11. # ================================================================#
  12. import sys
  13. sys.path.append("..")
  14. import torch
  15. import numpy
  16. from torch.utils.data import DataLoader
  17. from ..utils.logger import print_log
  18. from ..dataset import ClassificationDataset
  19. import os
  20. from typing import List, Any, T, Optional, Callable
  21. class BasicNN:
  22. """
  23. Wrap NN models into the form of an sklearn estimator
  24. Parameters
  25. ----------
  26. model : torch.nn.Module
  27. The PyTorch model to be trained or used for prediction.
  28. criterion : torch.nn.Module
  29. The loss function used for training.
  30. optimizer : torch.nn.Module
  31. The optimizer used for training.
  32. device : torch.device, optional
  33. The device on which the model will be trained or used for prediction, by default torch.decive("cpu").
  34. batch_size : int, optional
  35. The batch size used for training, by default 1.
  36. num_epochs : int, optional
  37. The number of epochs used for training, by default 1.
  38. stop_loss : Optional[float], optional
  39. The loss value at which to stop training, by default 0.01.
  40. num_workers : int, optional
  41. The number of workers used for loading data, by default 0.
  42. save_interval : Optional[int], optional
  43. The interval at which to save the model during training, by default None.
  44. save_dir : Optional[str], optional
  45. The directory in which to save the model during training, by default None.
  46. transform : Callable[..., Any], optional
  47. A function/transform that takes in an object and returns a transformed version. Defaults to None.
  48. collate_fn : Callable[[List[T]], Any], optional
  49. The function used to collate data, by default None.
  50. Attributes
  51. ----------
  52. model : torch.nn.Module
  53. The PyTorch model to be trained or used for prediction.
  54. batch_size : int
  55. The batch size used for training.
  56. num_epochs : int
  57. The number of epochs used for training.
  58. stop_loss : Optional[float]
  59. The loss value at which to stop training.
  60. num_workers : int
  61. The number of workers used for loading data.
  62. criterion : torch.nn.Module
  63. The loss function used for training.
  64. optimizer : torch.nn.Module
  65. The optimizer used for training.
  66. transform : Callable[..., Any]
  67. The transformation function used for data augmentation.
  68. device : torch.device
  69. The device on which the model will be trained or used for prediction.
  70. save_interval : Optional[int]
  71. The interval at which to save the model during training.
  72. save_dir : Optional[str]
  73. The directory in which to save the model during training.
  74. collate_fn : Callable[[List[T]], Any]
  75. The function used to collate data.
  76. Methods
  77. -------
  78. fit(data_loader=None, X=None, y=None)
  79. Train the model.
  80. train_epoch(data_loader)
  81. Train the model for one epoch.
  82. predict(data_loader=None, X=None, print_prefix="")
  83. Predict the class of the input data.
  84. predict_proba(data_loader=None, X=None, print_prefix="")
  85. Predict the probability of each class for the input data.
  86. val(data_loader=None, X=None, y=None, print_prefix="")
  87. Validate the model.
  88. score(data_loader=None, X=None, y=None, print_prefix="")
  89. Score the model.
  90. _data_loader(X, y=None)
  91. Generate the data_loader.
  92. save(epoch_id, save_dir="")
  93. Save the model.
  94. load(epoch_id, load_dir="")
  95. Load the model.
  96. """
  97. def __init__(
  98. self,
  99. model: torch.nn.Module,
  100. criterion: torch.nn.Module,
  101. optimizer: torch.nn.Module,
  102. device: torch.device = torch.device("cpu"),
  103. batch_size: int = 1,
  104. num_epochs: int = 1,
  105. stop_loss: Optional[float] = 0.01,
  106. num_workers: int = 0,
  107. save_interval: Optional[int] = None,
  108. save_dir: Optional[str] = None,
  109. transform: Callable[..., Any] = None,
  110. collate_fn: Callable[[List[T]], Any] = None,
  111. ):
  112. self.model = model.to(device)
  113. self.batch_size = batch_size
  114. self.num_epochs = num_epochs
  115. self.stop_loss = stop_loss
  116. self.num_workers = num_workers
  117. self.criterion = criterion
  118. self.optimizer = optimizer
  119. self.transform = transform
  120. self.device = device
  121. self.save_interval = save_interval
  122. self.save_dir = save_dir
  123. self.collate_fn = collate_fn
  124. def _fit(self, data_loader, n_epoch, stop_loss):
  125. min_loss = 1e10
  126. for epoch in range(n_epoch):
  127. loss_value = self.train_epoch(data_loader)
  128. if min_loss < 0 or loss_value < min_loss:
  129. min_loss = loss_value
  130. if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
  131. if self.save_dir is None:
  132. raise ValueError(
  133. "save_dir should not be None if save_interval is not None"
  134. )
  135. self.save(epoch + 1, self.save_dir)
  136. if stop_loss is not None and loss_value < stop_loss:
  137. break
  138. return min_loss
  139. def fit(
  140. self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
  141. ) -> float:
  142. """
  143. Train the model.
  144. Parameters
  145. ----------
  146. data_loader : DataLoader, optional
  147. The data loader used for training, by default None
  148. X : List[Any], optional
  149. The input data, by default None
  150. y : List[int], optional
  151. The target data, by default None
  152. Returns
  153. -------
  154. float
  155. The loss value of the trained model.
  156. """
  157. if data_loader is None:
  158. data_loader = self._data_loader(X, y)
  159. return self._fit(data_loader, self.num_epochs, self.stop_loss)
  160. def train_epoch(self, data_loader: DataLoader):
  161. """
  162. Train the model for one epoch.
  163. Parameters
  164. ----------
  165. data_loader : DataLoader
  166. The data loader used for training.
  167. Returns
  168. -------
  169. float
  170. The loss value of the trained model.
  171. """
  172. model = self.model
  173. criterion = self.criterion
  174. optimizer = self.optimizer
  175. device = self.device
  176. model.train()
  177. total_loss, total_num = 0.0, 0
  178. for data, target in data_loader:
  179. data, target = data.to(device), target.to(device)
  180. out = model(data)
  181. loss = criterion(out, target)
  182. optimizer.zero_grad()
  183. loss.backward()
  184. optimizer.step()
  185. total_loss += loss.item() * data.size(0)
  186. total_num += data.size(0)
  187. return total_loss / total_num
  188. def _predict(self, data_loader):
  189. model = self.model
  190. device = self.device
  191. model.eval()
  192. with torch.no_grad():
  193. results = []
  194. for data, _ in data_loader:
  195. data = data.to(device)
  196. out = model(data)
  197. results.append(out)
  198. return torch.cat(results, axis=0)
  199. def predict(
  200. self,
  201. data_loader: DataLoader = None,
  202. X: List[Any] = None,
  203. print_prefix: str = "",
  204. ) -> numpy.ndarray:
  205. """
  206. Predict the class of the input data.
  207. Parameters
  208. ----------
  209. data_loader : DataLoader, optional
  210. The data loader used for prediction, by default None
  211. X : List[Any], optional
  212. The input data, by default None
  213. print_prefix : str, optional
  214. The prefix used for printing, by default ""
  215. Returns
  216. -------
  217. numpy.ndarray
  218. The predicted class of the input data.
  219. """
  220. if data_loader is None:
  221. data_loader = self._data_loader(X)
  222. return self._predict(data_loader).argmax(axis=1).cpu().numpy()
  223. def predict_proba(
  224. self,
  225. data_loader: DataLoader = None,
  226. X: List[Any] = None,
  227. print_prefix: str = "",
  228. ) -> numpy.ndarray:
  229. """
  230. Predict the probability of each class for the input data.
  231. Parameters
  232. ----------
  233. data_loader : DataLoader, optional
  234. The data loader used for prediction, by default None
  235. X : List[Any], optional
  236. The input data, by default None
  237. print_prefix : str, optional
  238. The prefix used for printing, by default ""
  239. Returns
  240. -------
  241. numpy.ndarray
  242. The predicted probability of each class for the input data.
  243. """
  244. if data_loader is None:
  245. data_loader = self._data_loader(X)
  246. return self._predict(data_loader).softmax(axis=1).cpu().numpy()
  247. def _score(self, data_loader):
  248. model = self.model
  249. criterion = self.criterion
  250. device = self.device
  251. model.eval()
  252. total_correct_num, total_num, total_loss = 0, 0, 0.0
  253. with torch.no_grad():
  254. for data, target in data_loader:
  255. data, target = data.to(device), target.to(device)
  256. out = model(data)
  257. if len(out.shape) > 1:
  258. correct_num = sum(target == out.argmax(axis=1)).item()
  259. else:
  260. correct_num = sum(target == (out > 0.5)).item()
  261. loss = criterion(out, target)
  262. total_loss += loss.item() * data.size(0)
  263. total_correct_num += correct_num
  264. total_num += data.size(0)
  265. mean_loss = total_loss / total_num
  266. accuracy = total_correct_num / total_num
  267. return mean_loss, accuracy
  268. def score(
  269. self,
  270. data_loader: DataLoader = None,
  271. X: List[Any] = None,
  272. y: List[int] = None,
  273. print_prefix: str = "",
  274. ) -> float:
  275. """
  276. Validate the model.
  277. Parameters
  278. ----------
  279. data_loader : DataLoader, optional
  280. The data loader used for scoring, by default None
  281. X : List[Any], optional
  282. The input data, by default None
  283. y : List[int], optional
  284. The target data, by default None
  285. print_prefix : str, optional
  286. The prefix used for printing, by default ""
  287. Returns
  288. -------
  289. float
  290. The accuracy of the model.
  291. """
  292. print_log(f"Start machine learning model validation", logger="current")
  293. if data_loader is None:
  294. data_loader = self._data_loader(X, y)
  295. mean_loss, accuracy = self._score(data_loader)
  296. print_log(f"{print_prefix} mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current")
  297. return accuracy
  298. def _data_loader(
  299. self,
  300. X: List[Any],
  301. y: List[int] = None,
  302. ) -> DataLoader:
  303. """
  304. Generate data_loader for user provided data.
  305. Parameters
  306. ----------
  307. X : List[Any]
  308. The input data.
  309. y : List[int], optional
  310. The target data, by default None
  311. Returns
  312. -------
  313. DataLoader
  314. The data loader.
  315. """
  316. collate_fn = self.collate_fn
  317. transform = self.transform
  318. if y is None:
  319. y = [0] * len(X)
  320. dataset = ClassificationDataset(X, y, transform=transform)
  321. sampler = None
  322. data_loader = DataLoader(
  323. dataset,
  324. batch_size=self.batch_size,
  325. shuffle=False,
  326. sampler=sampler,
  327. num_workers=int(self.num_workers),
  328. collate_fn=collate_fn,
  329. )
  330. return data_loader
  331. def save(self, epoch_id: int = 0, save_dir: str = None, save_path: str = None):
  332. """
  333. Save the model and the optimizer.
  334. Parameters
  335. ----------
  336. epoch_id : int
  337. The epoch id.
  338. save_dir : str, optional
  339. The directory to save the model, by default ""
  340. """
  341. if save_dir and (not os.path.exists(save_dir)):
  342. os.makedirs(save_dir)
  343. print_log(f"Checkpoints will be saved to {save_dir}", logger="current")
  344. if save_path is None:
  345. save_path = os.path.join(save_dir, str(epoch_id) + ".pth")
  346. print_log(f"Checkpoints will be saved to {save_path}", logger="current")
  347. save_parma_dic = {
  348. "model": self.model.state_dict(),
  349. "optimizer": self.optimizer.state_dict(),
  350. }
  351. torch.save(save_parma_dic, save_path)
  352. def load(self, epoch_id: int = 0, load_dir: str = "", load_path: str = None):
  353. """
  354. Load the model and the optimizer.
  355. Parameters
  356. ----------
  357. epoch_id : int
  358. The epoch id.
  359. load_dir : str, optional
  360. The directory to load the model, by default ""
  361. """
  362. if load_path is not None:
  363. print_log(f"Loads checkpoint by local backend from path: {load_path}", logger="current")
  364. else:
  365. print_log(f"Loads checkpoint by local backend from dir: {load_dir}", logger="current")
  366. load_path = os.path.join(load_dir, str(epoch_id) + ".pth")
  367. param_dic = torch.load(load_path)
  368. self.model.load_state_dict(param_dic["model"])
  369. if "optimizer" in param_dic.keys():
  370. self.optimizer.load_state_dict(param_dic["optimizer"])
# Script entry-point guard: intentionally a no-op, so running this module
# directly does nothing.
if __name__ == "__main__":
    pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.