You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

basic_model.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :basic_model.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/11/21
  9. # Description :
  10. #
  11. #================================================================#
  12. import sys
  13. sys.path.append("..")
  14. import torch
  15. from torch.utils.data import Dataset
  16. import os
  17. from multiprocessing import Pool
  18. class XYDataset(Dataset):
  19. def __init__(self, X, Y, transform=None):
  20. self.X = X
  21. self.Y = torch.LongTensor(Y)
  22. self.n_sample = len(X)
  23. self.transform = transform
  24. def __len__(self):
  25. return len(self.X)
  26. def __getitem__(self, index):
  27. assert index < len(self), 'index range error'
  28. img = self.X[index]
  29. if self.transform is not None:
  30. img = self.transform(img)
  31. label = self.Y[index]
  32. return (img, label)
  33. class FakeRecorder():
  34. def __init__(self):
  35. pass
  36. def print(self, *x):
  37. pass
  38. class BasicModel():
  39. def __init__(self,
  40. model,
  41. criterion,
  42. optimizer,
  43. device,
  44. batch_size = 1,
  45. num_epochs = 1,
  46. stop_loss = 0.01,
  47. num_workers = 0,
  48. save_interval = None,
  49. save_dir = None,
  50. transform = None,
  51. collate_fn = None,
  52. recorder = None):
  53. self.model = model.to(device)
  54. self.batch_size = batch_size
  55. self.num_epochs = num_epochs
  56. self.stop_loss = stop_loss
  57. self.num_workers = num_workers
  58. self.criterion = criterion
  59. self.optimizer = optimizer
  60. self.transform = transform
  61. self.device = device
  62. if recorder is None:
  63. recorder = FakeRecorder()
  64. self.recorder = recorder
  65. self.save_interval = save_interval
  66. self.save_dir = save_dir
  67. self.collate_fn = collate_fn
  68. pass
  69. def _fit(self, data_loader, n_epoch, stop_loss):
  70. recorder = self.recorder
  71. recorder.print("model fitting")
  72. min_loss = 1e10
  73. for epoch in range(n_epoch):
  74. loss_value = self.train_epoch(data_loader)
  75. recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
  76. if min_loss < 0 or loss_value < min_loss:
  77. min_loss = loss_value
  78. if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
  79. assert self.save_dir is not None
  80. self.save(epoch + 1, self.save_dir)
  81. if stop_loss is not None and loss_value < stop_loss:
  82. break
  83. recorder.print("Model fitted, minimal loss is ", min_loss)
  84. return loss_value
  85. def fit(self, data_loader = None,
  86. X = None,
  87. y = None):
  88. if data_loader is None:
  89. data_loader = self._data_loader(X, y)
  90. return self._fit(data_loader, self.num_epochs, self.stop_loss)
  91. def train_epoch(self, data_loader):
  92. model = self.model
  93. criterion = self.criterion
  94. optimizer = self.optimizer
  95. device = self.device
  96. model.train()
  97. total_loss, total_num = 0.0, 0
  98. for data, target in data_loader:
  99. data, target = data.to(device), target.to(device)
  100. out = model(data)
  101. loss = criterion(out, target)
  102. optimizer.zero_grad()
  103. loss.backward()
  104. optimizer.step()
  105. total_loss += loss.item() * data.size(0)
  106. total_num += data.size(0)
  107. return total_loss / total_num
  108. def _predict(self, data_loader):
  109. model = self.model
  110. device = self.device
  111. model.eval()
  112. with torch.no_grad():
  113. results = []
  114. for data, _ in data_loader:
  115. data = data.to(device)
  116. out = model(data)
  117. results.append(out)
  118. return torch.cat(results, axis=0)
  119. def predict(self, data_loader = None, X = None, print_prefix = ""):
  120. recorder = self.recorder
  121. recorder.print('Start Predict Class ', print_prefix)
  122. if data_loader is None:
  123. data_loader = self._data_loader(X)
  124. return self._predict(data_loader).argmax(axis=1).cpu().numpy()
  125. def predict_proba(self, data_loader = None, X = None, print_prefix = ""):
  126. recorder = self.recorder
  127. recorder.print('Start Predict Probability ', print_prefix)
  128. if data_loader is None:
  129. data_loader = self._data_loader(X)
  130. return self._predict(data_loader).softmax(axis=1).cpu().numpy()
  131. def _val(self, data_loader):
  132. model = self.model
  133. criterion = self.criterion
  134. device = self.device
  135. model.eval()
  136. total_correct_num, total_num, total_loss = 0, 0, 0.0
  137. with torch.no_grad():
  138. for data, target in data_loader:
  139. data, target = data.to(device), target.to(device)
  140. out = model(data)
  141. correct_num = sum(target == out.argmax(axis=1)).item()
  142. loss = criterion(out, target)
  143. total_loss += loss.item() * data.size(0)
  144. total_correct_num += correct_num
  145. total_num += data.size(0)
  146. mean_loss = total_loss / total_num
  147. accuracy = total_correct_num / total_num
  148. return mean_loss, accuracy
  149. def val(self, data_loader = None, X = None, y = None, print_prefix = ""):
  150. recorder = self.recorder
  151. recorder.print('Start val ', print_prefix)
  152. if data_loader is None:
  153. data_loader = self._data_loader(X, y)
  154. mean_loss, accuracy = self._val(data_loader)
  155. recorder.print('[%s] Val loss: %f, accuray: %f' % (print_prefix, mean_loss, accuracy))
  156. return accuracy
  157. def score(self, data_loader = None, X = None, y = None, print_prefix = ""):
  158. return self.val(data_loader, X, y, print_prefix)
  159. def _data_loader(self, X, y = None):
  160. collate_fn = self.collate_fn
  161. transform = self.transform
  162. if y is None:
  163. y = [0] * len(X)
  164. dataset = XYDataset(X, y, transform=transform)
  165. sampler = None
  166. data_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, \
  167. shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \
  168. collate_fn=collate_fn)
  169. return data_loader
  170. def save(self, epoch_id, save_dir):
  171. recorder = self.recorder
  172. if not os.path.exists(save_dir):
  173. os.mkdir(save_dir)
  174. recorder.print("Saving model and opter")
  175. save_path = os.path.join(save_dir, str(epoch_id) + "_net.pth")
  176. torch.save(self.model.state_dict(), save_path)
  177. save_path = os.path.join(save_dir, str(epoch_id) + "_opt.pth")
  178. torch.save(self.optimizer.state_dict(), save_path)
  179. def load(self, epoch_id, load_dir):
  180. recorder = self.recorder
  181. recorder.print("Loading model and opter")
  182. load_path = os.path.join(load_dir, str(epoch_id) + "_net.pth")
  183. self.model.load_state_dict(torch.load(load_path))
  184. load_path = os.path.join(load_dir, str(epoch_id) + "_opt.pth")
  185. self.optimizer.load_state_dict(torch.load(load_path))
# No CLI behavior: this module is intended to be imported as a library.
if __name__ == "__main__":
    pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.