# coding: utf-8
#================================================================#
#   Copyright (C) 2020 Freecss All rights reserved.
#
#   File Name     : basic_model.py
#   Author        : freecss
#   Email         : karlfreecss@gmail.com
#   Created Date  : 2020/11/21
#   Description   :
#
#================================================================#

import os
import sys

sys.path.append("..")

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torch.nn import init
from torch.utils.data import Dataset


class resizeNormalize(object):
    """Convert an image to a tensor and normalize it to the range [-1, 1]."""

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()
        # Augmentations are intentionally disabled; uncomment to re-enable.
        self.transform = transforms.Compose([
            #transforms.ToPILImage(),
            #transforms.RandomHorizontalFlip(),
            #transforms.RandomVerticalFlip(),
            #transforms.RandomRotation(30),
            #transforms.RandomAffine(30),
            transforms.ToTensor(),
        ])

    def __call__(self, img):
        #img = img.resize(self.size, self.interpolation)
        img = self.transform(img)
        img.sub_(0.5).div_(0.5)  # map [0, 1] to [-1, 1]
        return img


class XYDataset(Dataset):
    """Dataset over parallel lists of inputs X and labels Y.

    __getitem__ also returns the sample index so a collate function can
    keep track of which samples ended up in the batch.
    """

    def __init__(self, X, Y, transform=None, target_transform=None):
        self.X = X
        self.Y = Y
        self.n_sample = len(X)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        img = self.X[index]
        if self.transform is not None:
            img = self.transform(img)
        label = self.Y[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return (img, label, index)


class alignCollate(object):
    """Collate function that resizes every image in a batch to one size."""

    def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        images, labels, img_keys = zip(*batch)

        imgH = self.imgH
        imgW = self.imgW
        if self.keep_ratio:
            ratios = []
            for image in images:
                h, w = image.shape[:2]  # numpy images are (H, W, ...)
                ratios.append(w / float(h))
            ratios.sort()
            max_ratio = ratios[-1]
            imgW = int(np.floor(max_ratio * imgH))
            imgW = max(imgH * self.min_ratio, imgW)  # assure imgW >= imgH * min_ratio

        transform = resizeNormalize((imgW, imgH))
        images = [transform(image) for image in images]
        images = torch.cat([t.unsqueeze(0) for t in images], 0)
        labels = torch.LongTensor(labels)

        return images, labels, img_keys


class FakeRecorder:
    """No-op recorder used when no real recorder is supplied."""

    def print(self, *x):
        pass


def weight_init(m):
    """Layer-wise initialization: Xavier for conv, small normal for linear."""
    if isinstance(m, nn.Conv2d):
        init.xavier_uniform_(m.weight.data)
        init.constant_(m.bias.data, 0.1)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
        m.bias.data.zero_()


class BasicModel:
    def __init__(self,
                 model,
                 criterion,
                 optimizer,
                 device,
                 params,
                 sign_list,
                 transform=None,
                 target_transform=None,
                 collate_fn=None,
                 pretrained=False,
                 recorder=None):
        self.model = model.to(device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.transform = transform
        self.target_transform = target_transform
        self.device = device

        # Map each sign to a class index, and keep the inverse mapping
        # to decode predictions back into signs.
        sign_list = sorted(list(set(sign_list)))
        self.mapping = dict(zip(sign_list, list(range(len(sign_list)))))
        self.remapping = dict(zip(list(range(len(sign_list))), sign_list))

        if recorder is None:
            recorder = FakeRecorder()
        self.recorder = recorder
        if pretrained:
            # the paths of the model and optimizer checkpoints come from params
            self.load(params.load_dir)
        else:
            self.model.apply(weight_init)

        self.save_interval = params.saveInterval
        self.params = params
        self.collate_fn = collate_fn

    def _fit(self, data_loader, n_epoch, stop_loss):
        recorder = self.recorder
        recorder.print("model fitting")

        min_loss = float("inf")
        for epoch in range(n_epoch):
            loss_value = self.train_epoch(data_loader)
            recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
            if loss_value < min_loss:
                min_loss = loss_value
            if epoch > 0 and self.save_interval is not None \
                    and epoch % self.save_interval == 0:
                assert hasattr(self.params, 'save_dir')
                self.save(self.params.save_dir)
            if stop_loss is not None and loss_value < stop_loss:
                break
        recorder.print("Model fitted, minimal loss is ", min_loss)
        return loss_value

    def str2ints(self, Y):
        return [self.mapping[y] for y in Y]

    def fit(self, data_loader=None, X=None, y=None):
        # params is needed below even when a ready-made data_loader is given
        params = self.params
        if data_loader is None:
            collate_fn = self.collate_fn
            transform = self.transform
            target_transform = self.target_transform

            Y = self.str2ints(y)
            train_dataset = XYDataset(X, Y, transform=transform,
                                      target_transform=target_transform)
            sampler = None
            data_loader = torch.utils.data.DataLoader(train_dataset,
                                                      batch_size=params.batchSize,
                                                      shuffle=True,
                                                      sampler=sampler,
                                                      num_workers=int(params.workers),
                                                      collate_fn=collate_fn)
        return self._fit(data_loader, params.n_epoch, params.stop_loss)

    def train_epoch(self, data_loader):
        self.model.train()
        loss_value = 0
        for i, data in enumerate(data_loader):
            X = data[0]
            Y = data[1]
            loss = self.train_batch(X, Y)
            loss_value += loss.item()
        return loss_value

    def train_batch(self, X, Y):
        model = self.model
        criterion = self.criterion
        optimizer = self.optimizer
        device = self.device

        # model predict
        X = X.to(device)
        Y = Y.to(device)
        pred_Y = model(X)

        # calculate loss
        loss = criterion(pred_Y, Y)

        # back propagation and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss

    def _predict(self, data_loader):
        model = self.model
        device = self.device
        model.eval()

        with torch.no_grad():
            results = []
            for i, data in enumerate(data_loader):
                X = data[0].to(device)
                pred_Y = model(X)
                results.append(pred_Y)

        return torch.cat(results, dim=0)

    def predict(self, data_loader=None, X=None, print_prefix=""):
        if data_loader is None:
            params = self.params
            collate_fn = self.collate_fn
            transform = self.transform
            target_transform = self.target_transform

            Y = [0] * len(X)  # placeholder labels; only X is used for inference
            val_dataset = XYDataset(X, Y, transform=transform,
                                    target_transform=target_transform)
            sampler = None
            data_loader = torch.utils.data.DataLoader(val_dataset,
                                                      batch_size=params.batchSize,
                                                      shuffle=False,
                                                      sampler=sampler,
                                                      num_workers=int(params.workers),
                                                      collate_fn=collate_fn)
        recorder = self.recorder
        recorder.print('Start Predict ', print_prefix)

        Y = self._predict(data_loader).argmax(dim=1)
        return [self.remapping[int(y)] for y in Y]

    def predict_proba(self, data_loader=None, X=None, print_prefix=""):
        if data_loader is None:
            params = self.params
            collate_fn = self.collate_fn
            transform = self.transform
            target_transform = self.target_transform

            Y = [0] * len(X)  # placeholder labels; only X is used for inference
            val_dataset = XYDataset(X, Y, transform=transform,
                                    target_transform=target_transform)
            sampler = None
            data_loader = torch.utils.data.DataLoader(val_dataset,
                                                      batch_size=params.batchSize,
                                                      shuffle=False,
                                                      sampler=sampler,
                                                      num_workers=int(params.workers),
                                                      collate_fn=collate_fn)
        recorder = self.recorder
        recorder.print('Start Predict ', print_prefix)
        return torch.softmax(self._predict(data_loader), dim=1)

    def _val(self, data_loader, print_prefix):
        model = self.model
        criterion = self.criterion
        recorder = self.recorder
        device = self.device
        recorder.print('Start val ', print_prefix)

        model.eval()
        n_correct = 0
        pred_num = 0
        loss_value = 0
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                X = data[0].to(device)
                Y = data[1].to(device)
                pred_Y = model(X)
                correct_num = (Y == pred_Y.argmax(dim=1)).sum().item()
                loss = criterion(pred_Y, Y)
                loss_value += loss.item()
                n_correct += correct_num
                pred_num += len(X)

        accuracy = float(n_correct) / float(pred_num)
        recorder.print('[%s] Val loss: %f, accuracy: %f'
                       % (print_prefix, loss_value, accuracy))
        return accuracy

    def val(self, data_loader=None, X=None, y=None, print_prefix=""):
        if data_loader is None:
            params = self.params
            collate_fn = self.collate_fn
            transform = self.transform
            target_transform = self.target_transform

            Y = self.str2ints(y)
            val_dataset = XYDataset(X, Y, transform=transform,
                                    target_transform=target_transform)
            sampler = None
            data_loader = torch.utils.data.DataLoader(val_dataset,
                                                      batch_size=params.batchSize,
                                                      shuffle=True,
                                                      sampler=sampler,
                                                      num_workers=int(params.workers),
                                                      collate_fn=collate_fn)
        return self._val(data_loader, print_prefix)

    def score(self, data_loader=None, X=None, y=None, print_prefix=""):
        return self.val(data_loader, X, y, print_prefix)

    def save(self, save_dir):
        recorder = self.recorder
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        recorder.print("Saving model and optimizer")
        save_path = os.path.join(save_dir, "net.pth")
        torch.save(self.model.state_dict(), save_path)
        save_path = os.path.join(save_dir, "opt.pth")
        torch.save(self.optimizer.state_dict(), save_path)

    def load(self, load_dir):
        recorder = self.recorder
        recorder.print("Loading model and optimizer")
        load_path = os.path.join(load_dir, "net.pth")
        self.model.load_state_dict(torch.load(load_path))
        load_path = os.path.join(load_dir, "opt.pth")
        self.optimizer.load_state_dict(torch.load(load_path))


if __name__ == "__main__":
    pass
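
# A minimal usage sketch (not part of the original module): it wires a toy
# linear classifier into BasicModel and runs a short fit/predict cycle on
# random data. The `params` attribute names (batchSize, workers, n_epoch,
# stop_loss, saveInterval, load_dir) mirror what BasicModel reads; the toy
# network, the random data, and the use of SimpleNamespace are illustrative
# assumptions, not the configuration the original repository ships with.
if __name__ == "__main__":
    from types import SimpleNamespace

    params = SimpleNamespace(batchSize=4, workers=0, n_epoch=2,
                             stop_loss=None, saveInterval=None,
                             load_dir=None)
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    model = BasicModel(net,
                       criterion=nn.CrossEntropyLoss(),
                       optimizer=torch.optim.Adam(net.parameters()),
                       device=torch.device("cpu"),
                       params=params,
                       sign_list=list(range(10)),
                       transform=resizeNormalize((28, 28)))

    # Sixteen random 28x28 grayscale "images" with labels drawn from sign_list.
    X = [np.random.rand(28, 28).astype(np.float32) for _ in range(16)]
    y = [int(v) for v in np.random.randint(0, 10, size=16)]

    model.fit(X=X, y=y)
    print(model.predict(X=X)[:5])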