You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

basic_model.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :basic_model.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/11/21
  9. # Description :
  10. #
  11. # ================================================================#
  12. import sys
  13. sys.path.append("..")
  14. import torch
  15. import numpy
  16. from torch.utils.data import Dataset, DataLoader
  17. import os
  18. from multiprocessing import Pool
  19. from typing import List, Any, T, Tuple, Optional, Callable
  20. class BasicDataset(Dataset):
  21. def __init__(self, X: List[Any], Y: List[Any]):
  22. """Initialize a basic dataset.
  23. Parameters
  24. ----------
  25. X : List[Any]
  26. A list of objects representing the input data.
  27. Y : List[Any]
  28. A list of objects representing the output data.
  29. """
  30. self.X = X
  31. self.Y = Y
  32. def __len__(self):
  33. """Return the length of the dataset.
  34. Returns
  35. -------
  36. int
  37. The length of the dataset.
  38. """
  39. return len(self.X)
  40. def __getitem__(self, index: int) -> Tuple(Any, Any):
  41. """Get an item from the dataset.
  42. Parameters
  43. ----------
  44. index : int
  45. The index of the item to retrieve.
  46. Returns
  47. -------
  48. Tuple(Any, Any)
  49. A tuple containing the input and output data at the specified index.
  50. """
  51. if index >= len(self):
  52. raise ValueError("index range error")
  53. img = self.X[index]
  54. label = self.Y[index]
  55. return (img, label)
  56. class XYDataset(Dataset):
  57. def __init__(self, X: List[Any], Y: List[int], transform: Callable[...] = None):
  58. """
  59. Initialize the dataset used for classification task.
  60. Parameters
  61. ----------
  62. X : List[Any]
  63. The input data.
  64. Y : List[int]
  65. The target data.
  66. transform : callable, optional
  67. A function/transform that takes in an object and returns a transformed version. Defaults to None.
  68. """
  69. self.X = X
  70. self.Y = torch.LongTensor(Y)
  71. self.n_sample = len(X)
  72. self.transform = transform
  73. def __len__(self) -> int:
  74. """
  75. Return the length of the dataset.
  76. Returns
  77. -------
  78. int
  79. The length of the dataset.
  80. """
  81. return len(self.X)
  82. def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
  83. """
  84. Get the item at the given index.
  85. Parameters
  86. ----------
  87. index : int
  88. The index of the item to get.
  89. Returns
  90. -------
  91. Tuple[Any, torch.Tensor]
  92. A tuple containing the object and its label.
  93. """
  94. if index >= len(self):
  95. raise ValueError("index range error")
  96. img = self.X[index]
  97. if self.transform is not None:
  98. img = self.transform(img)
  99. label = self.Y[index]
  100. return (img, label)
  101. class FakeRecorder:
  102. def __init__(self):
  103. pass
  104. def print(self, *x):
  105. pass
  106. class BasicModel:
  107. """
  108. Wrap NN models into the form of an sklearn estimator
  109. Parameters
  110. ----------
  111. model : torch.nn.Module
  112. The PyTorch model to be trained or used for prediction.
  113. criterion : torch.nn.Module
  114. The loss function used for training.
  115. optimizer : torch.nn.Module
  116. The optimizer used for training.
  117. device : torch.device
  118. The device on which the model will be trained or used for prediction.
  119. batch_size : int, optional
  120. The batch size used for training, by default 1.
  121. num_epochs : int, optional
  122. The number of epochs used for training, by default 1.
  123. stop_loss : Optional[float], optional
  124. The loss value at which to stop training, by default 0.01.
  125. num_workers : int, optional
  126. The number of workers used for loading data, by default 0.
  127. save_interval : Optional[int], optional
  128. The interval at which to save the model during training, by default None.
  129. save_dir : Optional[str], optional
  130. The directory in which to save the model during training, by default None.
  131. transform : Callable[..., Any], optional
  132. The transformation function used for data augmentation, by default None.
  133. collate_fn : Callable[[List[T]], Any], optional
  134. The function used to collate data, by default None.
  135. recorder : Any, optional
  136. The recorder used to record training progress, by default None.
  137. Attributes
  138. ----------
  139. model : torch.nn.Module
  140. The PyTorch model to be trained or used for prediction.
  141. batch_size : int
  142. The batch size used for training.
  143. num_epochs : int
  144. The number of epochs used for training.
  145. stop_loss : Optional[float]
  146. The loss value at which to stop training.
  147. num_workers : int
  148. The number of workers used for loading data.
  149. criterion : torch.nn.Module
  150. The loss function used for training.
  151. optimizer : torch.nn.Module
  152. The optimizer used for training.
  153. transform : Callable[..., Any]
  154. The transformation function used for data augmentation.
  155. device : torch.device
  156. The device on which the model will be trained or used for prediction.
  157. recorder : Any
  158. The recorder used to record training progress.
  159. save_interval : Optional[int]
  160. The interval at which to save the model during training.
  161. save_dir : Optional[str]
  162. The directory in which to save the model during training.
  163. collate_fn : Callable[[List[T]], Any]
  164. The function used to collate data.
  165. Methods
  166. -------
  167. fit(data_loader=None, X=None, y=None)
  168. Train the model.
  169. train_epoch(data_loader)
  170. Train the model for one epoch.
  171. predict(data_loader=None, X=None, print_prefix="")
  172. Predict the class of the input data.
  173. predict_proba(data_loader=None, X=None, print_prefix="")
  174. Predict the probability of each class for the input data.
  175. val(data_loader=None, X=None, y=None, print_prefix="")
  176. Validate the model.
  177. score(data_loader=None, X=None, y=None, print_prefix="")
  178. Score the model.
  179. _data_loader(X, y=None)
  180. Generate the data_loader.
  181. save(epoch_id, save_dir="")
  182. Save the model.
  183. load(epoch_id, load_dir="")
  184. Load the model.
  185. """
  186. def __init__(
  187. self,
  188. model: torch.nn.Module,
  189. criterion: torch.nn.Module,
  190. optimizer: torch.nn.Module,
  191. device: torch.device,
  192. batch_size: int = 1,
  193. num_epochs: int = 1,
  194. stop_loss: Optional[float] = 0.01,
  195. num_workers: int = 0,
  196. save_interval: Optional[int] = None,
  197. save_dir: Optional[str] = None,
  198. transform: Callable[...] = None,
  199. collate_fn: Callable[[List[T]], Any] = None,
  200. recorder=None,
  201. ):
  202. self.model = model.to(device)
  203. self.batch_size = batch_size
  204. self.num_epochs = num_epochs
  205. self.stop_loss = stop_loss
  206. self.num_workers = num_workers
  207. self.criterion = criterion
  208. self.optimizer = optimizer
  209. self.transform = transform
  210. self.device = device
  211. if recorder is None:
  212. recorder = FakeRecorder()
  213. self.recorder = recorder
  214. self.save_interval = save_interval
  215. self.save_dir = save_dir
  216. self.collate_fn = collate_fn
  217. def _fit(self, data_loader, n_epoch, stop_loss):
  218. recorder = self.recorder
  219. recorder.print("model fitting")
  220. min_loss = 1e10
  221. for epoch in range(n_epoch):
  222. loss_value = self.train_epoch(data_loader)
  223. recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
  224. if min_loss < 0 or loss_value < min_loss:
  225. min_loss = loss_value
  226. if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
  227. if self.save_dir is None:
  228. raise ValueError(
  229. "save_dir should not be None if save_interval is not None"
  230. )
  231. self.save(epoch + 1, self.save_dir)
  232. if stop_loss is not None and loss_value < stop_loss:
  233. break
  234. recorder.print("Model fitted, minimal loss is ", min_loss)
  235. return loss_value
  236. def fit(
  237. self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
  238. ) -> float:
  239. """
  240. Train the model.
  241. Parameters
  242. ----------
  243. data_loader : DataLoader, optional
  244. The data loader used for training, by default None
  245. X : List[Any], optional
  246. The input data, by default None
  247. y : List[int], optional
  248. The target data, by default None
  249. Returns
  250. -------
  251. float
  252. The loss value of the trained model.
  253. """
  254. if data_loader is None:
  255. data_loader = self._data_loader(X, y)
  256. return self._fit(data_loader, self.num_epochs, self.stop_loss)
  257. def train_epoch(self, data_loader: DataLoader):
  258. """
  259. Train the model for one epoch.
  260. Parameters
  261. ----------
  262. data_loader : DataLoader
  263. The data loader used for training.
  264. Returns
  265. -------
  266. float
  267. The loss value of the trained model.
  268. """
  269. model = self.model
  270. criterion = self.criterion
  271. optimizer = self.optimizer
  272. device = self.device
  273. model.train()
  274. total_loss, total_num = 0.0, 0
  275. for data, target in data_loader:
  276. data, target = data.to(device), target.to(device)
  277. out = model(data)
  278. loss = criterion(out, target)
  279. optimizer.zero_grad()
  280. loss.backward()
  281. optimizer.step()
  282. total_loss += loss.item() * data.size(0)
  283. total_num += data.size(0)
  284. return total_loss / total_num
  285. def _predict(self, data_loader):
  286. model = self.model
  287. device = self.device
  288. model.eval()
  289. with torch.no_grad():
  290. results = []
  291. for data, _ in data_loader:
  292. data = data.to(device)
  293. out = model(data)
  294. results.append(out)
  295. return torch.cat(results, axis=0)
  296. def predict(
  297. self,
  298. data_loader: DataLoader = None,
  299. X: List[Any] = None,
  300. print_prefix: str = "",
  301. ) -> numpy.ndarray:
  302. """
  303. Predict the class of the input data.
  304. Parameters
  305. ----------
  306. data_loader : DataLoader, optional
  307. The data loader used for prediction, by default None
  308. X : List[Any], optional
  309. The input data, by default None
  310. print_prefix : str, optional
  311. The prefix used for printing, by default ""
  312. Returns
  313. -------
  314. numpy.ndarray
  315. The predicted class of the input data.
  316. """
  317. recorder = self.recorder
  318. recorder.print("Start Predict Class ", print_prefix)
  319. if data_loader is None:
  320. data_loader = self._data_loader(X)
  321. return self._predict(data_loader).argmax(axis=1).cpu().numpy()
  322. def predict_proba(
  323. self,
  324. data_loader: DataLoader = None,
  325. X: List[Any] = None,
  326. print_prefix: str = "",
  327. ) -> numpy.ndarray:
  328. """
  329. Predict the probability of each class for the input data.
  330. Parameters
  331. ----------
  332. data_loader : DataLoader, optional
  333. The data loader used for prediction, by default None
  334. X : List[Any], optional
  335. The input data, by default None
  336. print_prefix : str, optional
  337. The prefix used for printing, by default ""
  338. Returns
  339. -------
  340. numpy.ndarray
  341. The predicted probability of each class for the input data.
  342. """
  343. recorder = self.recorder
  344. recorder.print("Start Predict Probability ", print_prefix)
  345. if data_loader is None:
  346. data_loader = self._data_loader(X)
  347. return self._predict(data_loader).softmax(axis=1).cpu().numpy()
  348. def _score(self, data_loader):
  349. model = self.model
  350. criterion = self.criterion
  351. device = self.device
  352. model.eval()
  353. total_correct_num, total_num, total_loss = 0, 0, 0.0
  354. with torch.no_grad():
  355. for data, target in data_loader:
  356. data, target = data.to(device), target.to(device)
  357. out = model(data)
  358. if len(out.shape) > 1:
  359. correct_num = sum(target == out.argmax(axis=1)).item()
  360. else:
  361. correct_num = sum(target == (out > 0.5)).item()
  362. loss = criterion(out, target)
  363. total_loss += loss.item() * data.size(0)
  364. total_correct_num += correct_num
  365. total_num += data.size(0)
  366. mean_loss = total_loss / total_num
  367. accuracy = total_correct_num / total_num
  368. return mean_loss, accuracy
  369. def score(
  370. self,
  371. data_loader: DataLoader = None,
  372. X: List[Any] = None,
  373. y: List[int] = None,
  374. print_prefix: str = "",
  375. ) -> float:
  376. """
  377. Validate the model.
  378. Parameters
  379. ----------
  380. data_loader : DataLoader, optional
  381. The data loader used for scoring, by default None
  382. X : List[Any], optional
  383. The input data, by default None
  384. y : List[int], optional
  385. The target data, by default None
  386. print_prefix : str, optional
  387. The prefix used for printing, by default ""
  388. Returns
  389. -------
  390. float
  391. The accuracy of the model.
  392. """
  393. recorder = self.recorder
  394. recorder.print("Start validation ", print_prefix)
  395. if data_loader is None:
  396. data_loader = self._data_loader(X, y)
  397. mean_loss, accuracy = self._score(data_loader)
  398. recorder.print(
  399. "[%s] mean loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy)
  400. )
  401. return accuracy
  402. def _data_loader(
  403. self,
  404. X: List[Any],
  405. y: List[int] = None,
  406. ) -> DataLoader:
  407. """
  408. Generate data_loader for user provided data.
  409. Parameters
  410. ----------
  411. X : List[Any]
  412. The input data.
  413. y : List[int], optional
  414. The target data, by default None
  415. Returns
  416. -------
  417. DataLoader
  418. The data loader.
  419. """
  420. collate_fn = self.collate_fn
  421. transform = self.transform
  422. if y is None:
  423. y = [0] * len(X)
  424. dataset = XYDataset(X, y, transform=transform)
  425. sampler = None
  426. data_loader = DataLoader(
  427. dataset,
  428. batch_size=self.batch_size,
  429. shuffle=False,
  430. sampler=sampler,
  431. num_workers=int(self.num_workers),
  432. collate_fn=collate_fn,
  433. )
  434. return data_loader
  435. def save(self, epoch_id: int, save_dir: str = ""):
  436. """
  437. Save the model and the optimizer.
  438. Parameters
  439. ----------
  440. epoch_id : int
  441. The epoch id.
  442. save_dir : str, optional
  443. The directory to save the model, by default ""
  444. """
  445. recorder = self.recorder
  446. if not os.path.exists(save_dir):
  447. os.makedirs(save_dir)
  448. recorder.print("Saving model and opter")
  449. save_path = os.path.join(save_dir, str(epoch_id) + "_net.pth")
  450. torch.save(self.model.state_dict(), save_path)
  451. save_path = os.path.join(save_dir, str(epoch_id) + "_opt.pth")
  452. torch.save(self.optimizer.state_dict(), save_path)
  453. def load(self, epoch_id: int, load_dir: str = ""):
  454. """
  455. Load the model and the optimizer.
  456. Parameters
  457. ----------
  458. epoch_id : int
  459. The epoch id.
  460. load_dir : str, optional
  461. The directory to load the model, by default ""
  462. """
  463. recorder = self.recorder
  464. recorder.print("Loading model and opter")
  465. load_path = os.path.join(load_dir, str(epoch_id) + "_net.pth")
  466. self.model.load_state_dict(torch.load(load_path))
  467. load_path = os.path.join(load_dir, str(epoch_id) + "_opt.pth")
  468. self.optimizer.load_state_dict(torch.load(load_path))
  469. if __name__ == "__main__":
  470. pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.