# coding: utf-8
#================================================================#
#   Copyright (C) 2020 Freecss All rights reserved.
#
#   File Name     :basic_model.py
#   Author        :freecss
#   Email         :karlfreecss@gmail.com
#   Created Date  :2020/11/21
#   Description   :
#
#================================================================#
"""Thin training / inference wrapper around a PyTorch model."""

import sys
sys.path.append("..")

import torch
from torch.utils.data import Dataset

import os
from multiprocessing import Pool


class XYDataset(Dataset):
    """Dataset pairing an indexable collection ``X`` with integer labels ``Y``."""

    def __init__(self, X, Y, transform=None):
        """Store the samples, convert ``Y`` to a ``LongTensor`` and keep ``transform``."""
        self.X = X
        self.Y = torch.LongTensor(Y)

        self.n_sample = len(X)
        self.transform = transform

    def __len__(self):
        """Return the number of samples in ``X``."""
        return len(self.X)

    def __getitem__(self, index):
        """Return ``(sample, label)`` for ``index``, applying ``transform`` if set."""
        assert index < len(self), 'index range error'

        img = self.X[index]
        if self.transform is not None:
            img = self.transform(img)

        label = self.Y[index]

        return (img, label)


class FakeRecorder():
    """No-op recorder used when the caller does not supply one."""

    def __init__(self):
        pass

    def print(self, *x):
        """Accept and discard any arguments."""
        pass


class BasicModel():
    """Bundle a model with its criterion and optimizer for fit/predict/val/save/load.

    Every public method accepts either a ready-made ``data_loader`` or raw
    ``X`` (and ``y``); in the latter case a ``DataLoader`` over an
    ``XYDataset`` is built from the settings given to the constructor.
    """

    def __init__(self,
                 model,
                 criterion,
                 optimizer,
                 device,
                 batch_size = 1,
                 num_epochs = 10,
                 stop_loss = 0.01,
                 num_workers = 0,
                 save_interval = None,
                 save_dir = None,
                 transform = None,
                 collate_fn = None,
                 recorder = None):
        """Move ``model`` to ``device`` and remember the training configuration.

        Args:
            model: module to train; it is moved to ``device``.
            criterion: callable ``criterion(pred, target)`` returning the loss.
            optimizer: optimizer stepping the model's parameters.
            device: device that the model and every batch are moved to.
            batch_size: batch size for internally built data loaders.
            num_epochs: number of epochs ``fit`` runs at most.
            stop_loss: ``fit`` stops early once an epoch's summed loss drops
                below this value; ``None`` disables early stopping.
            num_workers: worker count for internally built data loaders.
            save_interval: if not ``None``, checkpoint every this many epochs
                (epoch 0 excluded); requires ``save_dir``.
            save_dir: directory used for those periodic checkpoints.
            transform: optional per-sample transform handed to ``XYDataset``.
            collate_fn: optional collate function handed to ``DataLoader``.
            recorder: object with a ``print(*args)`` method used for logging;
                a silent ``FakeRecorder`` is used when ``None``.
        """
        self.model = model.to(device)

        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.stop_loss = stop_loss
        self.num_workers = num_workers

        self.criterion = criterion
        self.optimizer = optimizer
        self.transform = transform
        self.device = device

        if recorder is None:
            recorder = FakeRecorder()
        self.recorder = recorder

        self.save_interval = save_interval
        self.save_dir = save_dir
        self.collate_fn = collate_fn

    def _make_loader(self, X, y, shuffle):
        """Build a ``DataLoader`` over ``XYDataset(X, y)`` from the stored settings."""
        dataset = XYDataset(X, y, transform=self.transform)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=self.batch_size,
                                           shuffle=shuffle,
                                           num_workers=int(self.num_workers),
                                           collate_fn=self.collate_fn)

    def _fit(self, data_loader, n_epoch, stop_loss):
        """Train for up to ``n_epoch`` epochs and return the last epoch's loss.

        Returns ``None`` when ``n_epoch`` is 0, since no epoch was run.
        """
        recorder = self.recorder
        recorder.print("model fitting")

        min_loss = float("inf")
        # Stays None when the loop body never runs (n_epoch == 0).
        loss_value = None
        for epoch in range(n_epoch):
            loss_value = self.train_epoch(data_loader)
            recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
            if loss_value < min_loss:
                min_loss = loss_value
            if epoch > 0 and self.save_interval is not None and epoch % self.save_interval == 0:
                assert self.save_dir is not None
                self.save(self.save_dir)
            if stop_loss is not None and loss_value < stop_loss:
                break
        recorder.print("Model fitted, minimal loss is ", min_loss)
        return loss_value

    def fit(self, data_loader = None, X = None, y = None):
        """Train on ``data_loader``, or on a shuffled loader built from ``X``/``y``.

        Returns:
            The summed loss of the last epoch that was run (see ``_fit``).
        """
        if data_loader is None:
            data_loader = self._make_loader(X, y, shuffle=True)
        return self._fit(data_loader, self.num_epochs, self.stop_loss)

    def train_epoch(self, data_loader):
        """Run one optimisation pass over ``data_loader``.

        Returns:
            The sum of the per-batch loss values as a Python float.
        """
        model = self.model
        criterion = self.criterion
        optimizer = self.optimizer
        device = self.device

        model.train()

        loss_value = 0
        for data in data_loader:
            X = data[0].to(device)
            Y = data[1].to(device)

            pred_Y = model(X)

            loss = criterion(pred_Y, Y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value += loss.item()

        return loss_value

    def _predict(self, data_loader):
        """Return the concatenated raw model outputs for every batch in ``data_loader``."""
        model = self.model
        device = self.device

        model.eval()

        with torch.no_grad():
            results = []
            for data in data_loader:
                X = data[0].to(device)
                pred_Y = model(X)
                results.append(pred_Y)

        return torch.cat(results, dim=0)

    def predict(self, data_loader = None, X = None, print_prefix = ""):
        """Return the arg-max class index (over dim 1) for every sample as a list of ints."""
        if data_loader is None:
            # XYDataset requires labels; the zeros are placeholders and are never read.
            data_loader = self._make_loader(X, [0] * len(X), shuffle=False)

        recorder = self.recorder
        recorder.print('Start Predict ', print_prefix)
        labels = self._predict(data_loader).argmax(dim=1)
        # One bulk conversion instead of a per-element int() on tensor scalars.
        return labels.tolist()

    def predict_proba(self, data_loader = None, X = None, print_prefix = ""):
        """Return the softmax (over dim 1) of the model outputs as a NumPy array."""
        if data_loader is None:
            # XYDataset requires labels; the zeros are placeholders and are never read.
            data_loader = self._make_loader(X, [0] * len(X), shuffle=False)

        recorder = self.recorder
        recorder.print('Start Predict ', print_prefix)
        return torch.softmax(self._predict(data_loader), dim=1).cpu().numpy()

    def _val(self, data_loader, print_prefix):
        """Evaluate on ``data_loader``, log summed loss and accuracy, return the accuracy."""
        model = self.model
        criterion = self.criterion
        recorder = self.recorder
        device = self.device
        recorder.print('Start val ', print_prefix)

        model.eval()

        n_correct = 0
        pred_num = 0
        loss_value = 0
        with torch.no_grad():
            for data in data_loader:
                X = data[0].to(device)
                Y = data[1].to(device)

                pred_Y = model(X)

                # Tensor-level reduction; the builtin sum() would iterate
                # element by element in Python and accumulate a tensor.
                correct_num = (Y == pred_Y.argmax(dim=1)).sum().item()
                loss = criterion(pred_Y, Y)
                loss_value += loss.item()

                n_correct += correct_num
                pred_num += len(X)

        accuracy = float(n_correct) / float(pred_num)
        recorder.print('[%s] Val loss: %f, accuray: %f' % (print_prefix, loss_value, accuracy))
        return accuracy

    def val(self, data_loader = None, X = None, y = None, print_prefix = ""):
        """Return the accuracy on ``data_loader``, or on a loader built from ``X``/``y``."""
        if data_loader is None:
            data_loader = self._make_loader(X, y, shuffle=True)
        return self._val(data_loader, print_prefix)

    def score(self, data_loader = None, X = None, y = None, print_prefix = ""):
        """Alias of ``val``."""
        return self.val(data_loader, X, y, print_prefix)

    def save(self, save_dir):
        """Write model and optimizer state dicts to ``net.pth`` / ``opt.pth`` in ``save_dir``.

        ``save_dir`` is created (including missing parent directories) if needed.
        """
        recorder = self.recorder
        # makedirs handles nested paths and an already-existing directory.
        os.makedirs(save_dir, exist_ok=True)
        recorder.print("Saving model and opter")
        save_path = os.path.join(save_dir, "net.pth")
        torch.save(self.model.state_dict(), save_path)

        save_path = os.path.join(save_dir, "opt.pth")
        torch.save(self.optimizer.state_dict(), save_path)

    def load(self, load_dir):
        """Restore model and optimizer state from ``net.pth`` / ``opt.pth`` in ``load_dir``."""
        recorder = self.recorder
        recorder.print("Loading model and opter")
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored onto the device this instance runs on.
        load_path = os.path.join(load_dir, "net.pth")
        self.model.load_state_dict(torch.load(load_path, map_location=self.device))

        load_path = os.path.join(load_dir, "opt.pth")
        self.optimizer.load_state_dict(torch.load(load_path, map_location=self.device))


if __name__ == "__main__":
    pass