You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

basic_model.py 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :basic_model.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/11/21
  9. # Description :
  10. #
  11. #================================================================#
  12. import sys
  13. sys.path.append("..")
  14. import torch
  15. from torch.utils.data import Dataset
  16. import os
  17. from multiprocessing import Pool
  18. class XYDataset(Dataset):
  19. def __init__(self, X, Y, transform=None, target_transform=None):
  20. self.X = X
  21. self.Y = Y
  22. self.n_sample = len(X)
  23. self.transform = transform
  24. self.target_transform = target_transform
  25. def __len__(self):
  26. return len(self.X)
  27. def __getitem__(self, index):
  28. assert index < len(self), 'index range error'
  29. img = self.X[index]
  30. if self.transform is not None:
  31. img = self.transform(img)
  32. label = self.Y[index]
  33. if self.target_transform is not None:
  34. label = self.target_transform(label)
  35. return (img, label, index)
  36. class FakeRecorder():
  37. def __init__(self):
  38. pass
  39. def print(self, *x):
  40. pass
  41. class BasicModel():
  42. def __init__(self,
  43. model,
  44. criterion,
  45. optimizer,
  46. device,
  47. params,
  48. sign_list,
  49. transform = None,
  50. target_transform=None,
  51. collate_fn = None,
  52. recorder = None):
  53. self.model = model.to(device)
  54. self.criterion = criterion
  55. self.optimizer = optimizer
  56. self.transform = transform
  57. self.target_transform = target_transform
  58. self.device = device
  59. self.sign_list = sorted(list(set(sign_list)))
  60. self.mapping = dict(zip(sign_list, list(range(len(sign_list)))))
  61. self.remapping = dict(zip(list(range(len(sign_list))), sign_list))
  62. if recorder is None:
  63. recorder = FakeRecorder()
  64. self.recorder = recorder
  65. self.save_interval = params.saveInterval
  66. self.params = params
  67. self.collate_fn = collate_fn
  68. pass
  69. def _fit(self, data_loader, n_epoch, stop_loss):
  70. recorder = self.recorder
  71. recorder.print("model fitting")
  72. min_loss = 999999999
  73. for epoch in range(n_epoch):
  74. loss_value = self.train_epoch(data_loader)
  75. recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
  76. if loss_value < min_loss:
  77. min_loss = loss_value
  78. if epoch > 0 and self.save_interval is not None and epoch % self.save_interval == 0:
  79. assert hasattr(self.params, 'save_dir')
  80. self.save(self.params.save_dir)
  81. if stop_loss is not None and loss_value < stop_loss:
  82. break
  83. recorder.print("Model fitted, minimal loss is ", min_loss)
  84. return loss_value
  85. def str2ints(self, Y):
  86. return [self.mapping[y] for y in Y]
  87. def fit(self, data_loader = None,
  88. X = None,
  89. y = None):
  90. if data_loader is None:
  91. params = self.params
  92. collate_fn = self.collate_fn
  93. transform = self.transform
  94. target_transform = self.target_transform
  95. Y = self.str2ints(y)
  96. train_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
  97. sampler = None
  98. data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batchSize, \
  99. shuffle=True, sampler=sampler, num_workers=int(params.workers), \
  100. collate_fn=collate_fn)
  101. return self._fit(data_loader, params.n_epoch, params.stop_loss)
  102. def train_epoch(self, data_loader):
  103. model = self.model
  104. criterion = self.criterion
  105. optimizer = self.optimizer
  106. device = self.device
  107. model.train()
  108. loss_value = 0
  109. for _, data in enumerate(data_loader):
  110. X = data[0].to(device)
  111. Y = data[1].to(device)
  112. pred_Y = model(X)
  113. loss = criterion(pred_Y, Y)
  114. optimizer.zero_grad()
  115. loss.backward()
  116. optimizer.step()
  117. loss_value += loss.item()
  118. return loss_value
  119. def _predict(self, data_loader):
  120. model = self.model
  121. device = self.device
  122. model.eval()
  123. with torch.no_grad():
  124. results = []
  125. for i, data in enumerate(data_loader):
  126. X = data[0].to(device)
  127. pred_Y = model(X)
  128. results.append(pred_Y)
  129. return torch.cat(results, axis=0)
  130. def predict(self, data_loader = None, X = None, print_prefix = ""):
  131. if data_loader is None:
  132. params = self.params
  133. collate_fn = self.collate_fn
  134. transform = self.transform
  135. target_transform = self.target_transform
  136. Y = [0] * len(X)
  137. val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
  138. sampler = None
  139. data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
  140. shuffle=False, sampler=sampler, num_workers=int(params.workers), \
  141. collate_fn=collate_fn)
  142. recorder = self.recorder
  143. recorder.print('Start Predict ', print_prefix)
  144. Y = self._predict(data_loader).argmax(axis=1)
  145. return [self.remapping[int(y)] for y in Y]
  146. def predict_proba(self, data_loader = None, X = None, print_prefix = ""):
  147. if data_loader is None:
  148. params = self.params
  149. collate_fn = self.collate_fn
  150. transform = self.transform
  151. target_transform = self.target_transform
  152. Y = [0] * len(X)
  153. val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
  154. sampler = None
  155. data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
  156. shuffle=False, sampler=sampler, num_workers=int(params.workers), \
  157. collate_fn=collate_fn)
  158. recorder = self.recorder
  159. recorder.print('Start Predict ', print_prefix)
  160. return torch.softmax(self._predict(data_loader), axis=1).cpu()
  161. def _val(self, data_loader, print_prefix):
  162. model = self.model
  163. criterion = self.criterion
  164. recorder = self.recorder
  165. device = self.device
  166. recorder.print('Start val ', print_prefix)
  167. model.eval()
  168. n_correct = 0
  169. pred_num = 0
  170. loss_value = 0
  171. with torch.no_grad():
  172. for i, data in enumerate(data_loader):
  173. X = data[0].to(device)
  174. Y = data[1].to(device)
  175. pred_Y = model(X)
  176. correct_num = sum(Y == pred_Y.argmax(axis=1))
  177. loss = criterion(pred_Y, Y)
  178. loss_value += loss.item()
  179. n_correct += correct_num
  180. pred_num += len(X)
  181. accuracy = float(n_correct) / float(pred_num)
  182. recorder.print('[%s] Val loss: %f, accuray: %f' % (print_prefix, loss_value, accuracy))
  183. return accuracy
  184. def val(self, data_loader = None, X = None, y = None, print_prefix = ""):
  185. if data_loader is None:
  186. params = self.params
  187. collate_fn = self.collate_fn
  188. transform = self.transform
  189. target_transform = self.target_transform
  190. Y = self.str2ints(y)
  191. val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
  192. sampler = None
  193. data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \
  194. shuffle=True, sampler=sampler, num_workers=int(params.workers), \
  195. collate_fn=collate_fn)
  196. return self._val(data_loader, print_prefix)
  197. def score(self, data_loader = None, X = None, y = None, print_prefix = ""):
  198. return self.val(data_loader, X, y, print_prefix)
  199. def save(self, save_dir):
  200. recorder = self.recorder
  201. if not os.path.exists(save_dir):
  202. os.mkdir(save_dir)
  203. recorder.print("Saving model and opter")
  204. save_path = os.path.join(save_dir, "net.pth")
  205. torch.save(self.model.state_dict(), save_path)
  206. save_path = os.path.join(save_dir, "opt.pth")
  207. torch.save(self.optimizer.state_dict(), save_path)
  208. def load(self, load_dir):
  209. recorder = self.recorder
  210. recorder.print("Loading model and opter")
  211. load_path = os.path.join(load_dir, "net.pth")
  212. self.model.load_state_dict(torch.load(load_path))
  213. load_path = os.path.join(load_dir, "opt.pth")
  214. self.optimizer.load_state_dict(torch.load(load_path))
if __name__ == "__main__":
    # Library module: nothing to run when executed directly.
    pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.