
basic_model.py 7.8 kB

# coding: utf-8
# ================================================================#
# Copyright (C) 2020 Freecss All rights reserved.
#
# File Name    : basic_model.py
# Author       : freecss
# Email        : karlfreecss@gmail.com
# Created Date : 2020/11/21
# Description  :
#
# ================================================================#
import sys

sys.path.append("..")

import os

import torch
from torch.utils.data import Dataset


class BasicDataset(Dataset):
    """Minimal dataset wrapping paired inputs and labels."""

    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        assert index < len(self), "index range error"
        img = self.X[index]
        label = self.Y[index]
        return (img, label)


class XYDataset(Dataset):
    """Dataset with integer labels and an optional per-sample transform."""

    def __init__(self, X, Y, transform=None):
        self.X = X
        self.Y = torch.LongTensor(Y)
        self.n_sample = len(X)
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        assert index < len(self), "index range error"
        img = self.X[index]
        if self.transform is not None:
            img = self.transform(img)
        label = self.Y[index]
        return (img, label)


class FakeRecorder:
    """No-op recorder used when no logger is supplied."""

    def __init__(self):
        pass

    def print(self, *x):
        pass


class BasicModel:
    """Thin wrapper around a PyTorch model providing fit/predict/val/save/load."""

    def __init__(
        self,
        model,
        criterion,
        optimizer,
        device,
        batch_size=1,
        num_epochs=1,
        stop_loss=0.01,
        num_workers=0,
        save_interval=None,
        save_dir=None,
        transform=None,
        collate_fn=None,
        recorder=None,
    ):
        self.model = model.to(device)
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.stop_loss = stop_loss
        self.num_workers = num_workers
        self.criterion = criterion
        self.optimizer = optimizer
        self.transform = transform
        self.device = device

        if recorder is None:
            recorder = FakeRecorder()
        self.recorder = recorder

        self.save_interval = save_interval
        self.save_dir = save_dir
        self.collate_fn = collate_fn

    def _fit(self, data_loader, n_epoch, stop_loss):
        recorder = self.recorder
        recorder.print("model fitting")

        min_loss = float("inf")
        loss_value = min_loss
        for epoch in range(n_epoch):
            loss_value = self.train_epoch(data_loader)
            recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
            if loss_value < min_loss:
                min_loss = loss_value
            # Periodically checkpoint model and optimizer state.
            if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
                assert self.save_dir is not None
                self.save(epoch + 1, self.save_dir)
            # Early stopping once the training loss falls below the threshold.
            if stop_loss is not None and loss_value < stop_loss:
                break
        recorder.print("Model fitted, minimal loss is ", min_loss)
        return loss_value

    def fit(self, data_loader=None, X=None, y=None):
        if data_loader is None:
            data_loader = self._data_loader(X, y)
        return self._fit(data_loader, self.num_epochs, self.stop_loss)

    def train_epoch(self, data_loader):
        model = self.model
        criterion = self.criterion
        optimizer = self.optimizer
        device = self.device

        model.train()
        total_loss, total_num = 0.0, 0
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            out = model(data)
            loss = criterion(out, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * data.size(0)
            total_num += data.size(0)
        return total_loss / total_num

    def _predict(self, data_loader):
        model = self.model
        device = self.device

        model.eval()
        with torch.no_grad():
            results = []
            for data, _ in data_loader:
                data = data.to(device)
                out = model(data)
                results.append(out)
        return torch.cat(results, dim=0)

    def predict(self, data_loader=None, X=None, print_prefix=""):
        recorder = self.recorder
        recorder.print("Start Predict Class ", print_prefix)
        if data_loader is None:
            data_loader = self._data_loader(X)
        return self._predict(data_loader).argmax(dim=1).cpu().numpy()

    def predict_proba(self, data_loader=None, X=None, print_prefix=""):
        recorder = self.recorder
        # recorder.print('Start Predict Probability ', print_prefix)
        if data_loader is None:
            data_loader = self._data_loader(X)
        return self._predict(data_loader).softmax(dim=1).cpu().numpy()

    def _val(self, data_loader):
        model = self.model
        criterion = self.criterion
        device = self.device

        model.eval()
        total_correct_num, total_num, total_loss = 0, 0, 0.0
        with torch.no_grad():
            for data, target in data_loader:
                data, target = data.to(device), target.to(device)
                out = model(data)
                # Multi-class outputs use argmax; single-logit outputs use a 0.5 threshold.
                if len(out.shape) > 1:
                    correct_num = (target == out.argmax(dim=1)).sum().item()
                else:
                    correct_num = (target == (out > 0.5)).sum().item()
                loss = criterion(out, target)

                total_loss += loss.item() * data.size(0)
                total_correct_num += correct_num
                total_num += data.size(0)
        mean_loss = total_loss / total_num
        accuracy = total_correct_num / total_num
        return mean_loss, accuracy

    def val(self, data_loader=None, X=None, y=None, print_prefix=""):
        recorder = self.recorder
        recorder.print("Start val ", print_prefix)
        if data_loader is None:
            data_loader = self._data_loader(X, y)
        mean_loss, accuracy = self._val(data_loader)
        recorder.print(
            "[%s] Val loss: %f, accuracy: %f" % (print_prefix, mean_loss, accuracy)
        )
        return accuracy

    def score(self, data_loader=None, X=None, y=None, print_prefix=""):
        return self.val(data_loader, X, y, print_prefix)

    def _data_loader(self, X, y=None):
        collate_fn = self.collate_fn
        transform = self.transform

        # At prediction time no labels are given, so dummy labels are used.
        if y is None:
            y = [0] * len(X)
        dataset = XYDataset(X, y, transform=transform)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            sampler=None,
            num_workers=int(self.num_workers),
            collate_fn=collate_fn,
        )
        return data_loader

    def save(self, epoch_id, save_dir):
        recorder = self.recorder
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        recorder.print("Saving model and optimizer")
        save_path = os.path.join(save_dir, str(epoch_id) + "_net.pth")
        torch.save(self.model.state_dict(), save_path)
        save_path = os.path.join(save_dir, str(epoch_id) + "_opt.pth")
        torch.save(self.optimizer.state_dict(), save_path)

    def load(self, epoch_id, load_dir):
        recorder = self.recorder
        recorder.print("Loading model and optimizer")
        load_path = os.path.join(load_dir, str(epoch_id) + "_net.pth")
        self.model.load_state_dict(torch.load(load_path))
        load_path = os.path.join(load_dir, str(epoch_id) + "_opt.pth")
        self.optimizer.load_state_dict(torch.load(load_path))


if __name__ == "__main__":
    pass
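
Usage sketch (not part of the file): a minimal example of driving the BasicModel wrapper above. The toy MLP, random data, and hyperparameters here are illustrative assumptions, not something this module ships.

# Illustrative only: train and evaluate a toy 3-class classifier with BasicModel.
import torch
import torch.nn as nn

from basic_model import BasicModel

# Assumed toy network over 10-dimensional inputs.
net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BasicModel(
    net,
    criterion,
    optimizer,
    device,
    batch_size=32,
    num_epochs=5,
    stop_loss=0.01,
)

# Random tensors standing in for real features and labels.
X = [torch.randn(10) for _ in range(128)]
y = [i % 3 for i in range(128)]

model.fit(X=X, y=y)               # train until num_epochs or stop_loss is reached
preds = model.predict(X=X)        # predicted class indices as a numpy array
probs = model.predict_proba(X=X)  # softmax probabilities as a numpy array
acc = model.val(X=X, y=y)         # mean accuracy on the given data

Passing X and y directly lets BasicModel build its own DataLoader via _data_loader; alternatively, a prebuilt DataLoader can be supplied through the data_loader argument of fit, predict, and val.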

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.