You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

basic_model.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :basic_model.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/11/21
  9. # Description :
  10. #
  11. #================================================================#
  12. import sys
  13. sys.path.append("..")
  14. import torch
  15. from torch.autograd import Variable
  16. from torch.utils.data import Dataset
  17. import torchvision
  18. import utils.utils as mutils
  19. import os
  20. from multiprocessing import Pool
  21. import random
  22. import torch
  23. from torch.utils.data import Dataset
  24. from torch.utils.data import sampler
  25. import torchvision.transforms as transforms
  26. import six
  27. import sys
  28. from PIL import Image
  29. import numpy as np
  30. import collections
  31. class resizeNormalize(object):
  32. def __init__(self, size, interpolation=Image.BILINEAR):
  33. self.size = size
  34. self.interpolation = interpolation
  35. self.toTensor = transforms.ToTensor()
  36. self.transform = transforms.Compose([
  37. #transforms.ToPILImage(),
  38. #transforms.RandomHorizontalFlip(),
  39. #transforms.RandomVerticalFlip(),
  40. #transforms.RandomRotation(30),
  41. #transforms.RandomAffine(30),
  42. transforms.ToTensor(),
  43. ])
  44. def __call__(self, img):
  45. #img = img.resize(self.size, self.interpolation)
  46. #img = self.toTensor(img)
  47. img = self.transform(img)
  48. img.sub_(0.5).div_(0.5)
  49. return img
  50. class XYDataset(Dataset):
  51. def __init__(self, X, Y, transform=None, target_transform=None):
  52. self.X = X
  53. self.Y = Y
  54. self.n_sample = len(X)
  55. self.transform = transform
  56. self.target_transform = target_transform
  57. def __len__(self):
  58. return len(self.X)
  59. def __getitem__(self, index):
  60. assert index < len(self), 'index range error'
  61. img = self.X[index]
  62. if self.transform is not None:
  63. img = self.transform(img)
  64. label = self.Y[index]
  65. if self.target_transform is not None:
  66. label = self.target_transform(label)
  67. return (img, label, index)
  68. class alignCollate(object):
  69. def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
  70. self.imgH = imgH
  71. self.imgW = imgW
  72. self.keep_ratio = keep_ratio
  73. self.min_ratio = min_ratio
  74. def __call__(self, batch):
  75. images, labels, img_keys = zip(*batch)
  76. imgH = self.imgH
  77. imgW = self.imgW
  78. if self.keep_ratio:
  79. ratios = []
  80. for image in images:
  81. w, h = image.shape[:2]
  82. ratios.append(w / float(h))
  83. ratios.sort()
  84. max_ratio = ratios[-1]
  85. imgW = int(np.floor(max_ratio * imgH))
  86. imgW = max(imgH * self.min_ratio, imgW) # assure imgH >= imgW
  87. transform = resizeNormalize((imgW, imgH))
  88. images = [transform(image) for image in images]
  89. images = torch.cat([t.unsqueeze(0) for t in images], 0)
  90. labels = torch.LongTensor(labels)
  91. return images, labels, img_keys
  92. class FakeRecorder():
  93. def __init__(self):
  94. pass
  95. def print(self, *x):
  96. pass
  97. from torch.nn import init
  98. from torch import nn
  99. def weigth_init(m):
  100. if isinstance(m, nn.Conv2d):
  101. init.xavier_uniform_(m.weight.data)
  102. init.constant_(m.bias.data,0.1)
  103. elif isinstance(m, nn.BatchNorm2d):
  104. m.weight.data.fill_(1)
  105. m.bias.data.zero_()
  106. elif isinstance(m, nn.Linear):
  107. m.weight.data.normal_(0,0.01)
  108. m.bias.data.zero_()
class BasicModel():
    """sklearn-style training/evaluation wrapper around a torch model.

    Raw labels ("signs") may be arbitrary sortable hashable values; they are
    mapped to contiguous integer class ids internally (self.mapping and its
    inverse self.remapping). Exposes fit / predict / predict_proba / val /
    score / save / load.
    """

    def __init__(self,
                 model,
                 criterion,
                 optimizer,
                 converter,
                 device,
                 params,
                 sign_list,
                 recorder = None):
        # model     : nn.Module — moved to `device` and re-initialized below.
        # criterion : loss callable taking (pred_Y, Y).
        # optimizer : torch optimizer over model's parameters.
        # converter : stored but not used directly in this class.
        # params    : options object; must provide batchSize, workers, imgH,
        #             imgW, keep_ratio and saveInterval attributes.
        # sign_list : iterable of raw label values (deduplicated and sorted).
        # recorder  : logger with a print(*args) method; defaults to a no-op.
        self.model = model.to(device)
        self.model.apply(weigth_init)
        self.criterion = criterion
        self.optimizer = optimizer
        self.converter = converter
        self.device = device
        sign_list = sorted(list(set(sign_list)))
        # raw label -> class id, and the inverse for decoding predictions
        self.mapping = dict(zip(sign_list, list(range(len(sign_list)))))
        self.remapping = dict(zip(list(range(len(sign_list))), sign_list))
        if recorder is None:
            recorder = FakeRecorder()
        self.recorder = recorder
        self.save_interval = params.saveInterval
        self.params = params
        pass

    def _fit(self, data_loader, n_epoch, stop_loss):
        """Train for up to n_epoch epochs; stop early once the epoch loss
        drops below stop_loss. Returns the last epoch's loss value."""
        recorder = self.recorder
        recorder.print("model fitting")
        min_loss = 999999999
        for epoch in range(n_epoch):
            loss_value = self.train_epoch(data_loader)
            recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
            if loss_value < min_loss:
                min_loss = loss_value
            if loss_value < stop_loss:
                break
        recorder.print("Model fitted, minimal loss is ", min_loss)
        return loss_value

    def str2ints(self, Y):
        """Map raw labels to their integer class ids (see self.mapping)."""
        return [self.mapping[y] for y in Y]

    def fit(self, data_loader = None,
            X = None,
            y = None,
            n_epoch = 100,
            stop_loss = 0.001):
        """Fit the model. Pass either a ready data_loader, or raw (X, y)
        from which a DataLoader is built using the stored params."""
        if data_loader is None:
            params = self.params
            Y = self.str2ints(y)
            train_dataset = XYDataset(X, Y)
            sampler = None
            data_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=params.batchSize,
                shuffle=True, sampler=sampler, num_workers=int(params.workers),
                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
        return self._fit(data_loader, n_epoch, stop_loss)

    def train_epoch(self, data_loader):
        """Run one pass over data_loader; return the mean batch loss."""
        loss_avg = mutils.averager()
        for i, data in enumerate(data_loader):
            X = data[0]
            Y = data[1]
            cost = self.train_batch(X, Y)
            loss_avg.add(cost)
        loss_value = float(loss_avg.val())
        loss_avg.reset()
        return loss_value

    def train_batch(self, X, Y):
        """One optimization step on a single batch; returns the loss tensor."""
        model = self.model
        criterion = self.criterion
        optimizer = self.optimizer
        converter = self.converter
        device = self.device
        # set training mode (re-enable grads in case predict/val froze them)
        for p in model.parameters():
            p.requires_grad = True
        model.train()
        # init training status
        # NOTE(review): anomaly detection is (re-)enabled on every batch; this
        # adds autograd overhead and looks like a debugging leftover — confirm.
        torch.autograd.set_detect_anomaly(True)
        optimizer.zero_grad()
        # model predict
        X = X.to(device)
        Y = Y.to(device)
        pred_Y = model(X)
        # calculate loss
        loss = criterion(pred_Y, Y)
        # back propagation and optimize
        loss.backward()
        optimizer.step()
        return loss

    def _predict(self, data_loader):
        """Forward the whole loader through the model with grads disabled;
        returns the concatenated raw outputs (logits)."""
        model = self.model
        criterion = self.criterion
        converter = self.converter
        params = self.params
        device = self.device
        # freeze parameters and switch to eval mode
        for p in model.parameters():
            p.requires_grad = False
        model.eval()
        n_correct = 0
        results = []
        for i, data in enumerate(data_loader):
            X = data[0].to(device)
            pred_Y = model(X)
            results.append(pred_Y)
        return torch.cat(results, axis=0)

    def predict(self, data_loader = None, X = None, print_prefix = ""):
        """Predict decoded class labels (via self.remapping) for a loader
        or for raw X."""
        params = self.params
        if data_loader is None:
            Y = [0] * len(X)  # dummy labels; only X is used for prediction
            val_dataset = XYDataset(X, Y)
            sampler = None
            data_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=params.batchSize,
                shuffle=False, sampler=sampler, num_workers=int(params.workers),
                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
        recorder = self.recorder
        recorder.print('Start Predict ', print_prefix)
        Y = self._predict(data_loader).argmax(axis=1)
        return [self.remapping[int(y)] for y in Y]

    def predict_proba(self, data_loader = None, X = None, print_prefix = ""):
        """Return class probabilities (softmax over the model's logits)."""
        params = self.params
        if data_loader is None:
            Y = [0] * len(X)  # dummy labels; only X is used for prediction
            val_dataset = XYDataset(X, Y)
            sampler = None
            data_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=params.batchSize,
                shuffle=False, sampler=sampler, num_workers=int(params.workers),
                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
        recorder = self.recorder
        recorder.print('Start Predict ', print_prefix)
        return torch.softmax(self._predict(data_loader), axis=1)

    def _val(self, data_loader, print_prefix):
        """Evaluate on data_loader; logs loss/accuracy, returns accuracy."""
        model = self.model
        criterion = self.criterion
        recorder = self.recorder
        converter = self.converter
        params = self.params
        device = self.device
        recorder.print('Start val ', print_prefix)
        for p in model.parameters():
            p.requires_grad = False
        model.eval()
        n_correct = 0
        pred_num = 0
        loss_avg = mutils.averager()
        for i, data in enumerate(data_loader):
            X = data[0].to(device)
            Y = data[1].to(device)
            pred_Y = model(X)
            correct_num = sum(Y == pred_Y.argmax(axis=1))
            loss = criterion(pred_Y, Y)
            loss_avg.add(loss)
            n_correct += correct_num
            pred_num += len(X)
        accuracy = float(n_correct) / float(pred_num)
        recorder.print('[%s] Val loss: %f, accuray: %f' % (print_prefix, loss_avg.val(), accuracy))
        return accuracy

    def val(self, data_loader = None, X = None, y = None, print_prefix = ""):
        """Validate on a loader or raw (X, y); returns accuracy."""
        params = self.params
        if data_loader is None:
            y = self.str2ints(y)
            val_dataset = XYDataset(X, y)
            sampler = None
            data_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=params.batchSize,
                shuffle=True, sampler=sampler, num_workers=int(params.workers),
                collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
        return self._val(data_loader, print_prefix)

    def score(self, data_loader = None, X = None, y = None, print_prefix = ""):
        """Alias for val() — sklearn-style naming."""
        return self.val(data_loader, X, y, print_prefix)

    def save(self, save_dir):
        """Save model and optimizer state_dicts to save_dir/net.pth and
        save_dir/opt.pth; save_dir is created if missing (parent must exist)."""
        recorder = self.recorder
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        recorder.print("Saving model and opter")
        save_path = os.path.join(save_dir, "net.pth")
        torch.save(self.model.state_dict(), save_path)
        save_path = os.path.join(save_dir, "opt.pth")
        torch.save(self.optimizer.state_dict(), save_path)

    def load(self, load_dir):
        """Load model and optimizer state_dicts from load_dir."""
        recorder = self.recorder
        recorder.print("Loading model and opter")
        load_path = os.path.join(load_dir, "net.pth")
        self.model.load_state_dict(torch.load(load_path))
        load_path = os.path.join(load_dir, "opt.pth")
        self.optimizer.load_state_dict(torch.load(load_path))
if __name__ == "__main__":
    # Library-only module; no standalone behavior.
    pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.