Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10525413master
@@ -305,6 +305,7 @@ class Trainers(object):
     face_detection_scrfd = 'face-detection-scrfd'
     card_detection_scrfd = 'card-detection-scrfd'
     image_inpainting = 'image-inpainting'
+    image_classification_team = 'image-classification-team'

     # nlp trainers
     bert_sentiment_analysis = 'bert-sentiment-analysis'
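The new registry key is what `build_trainer` resolves through the `TRAINERS` registry; the test at the bottom of this review wires it up end to end. A quick check of the mapping:

from modelscope.metainfo import Trainers

# The enum value must match the string passed to TRAINERS.register_module
# in team_trainer.py below.
assert Trainers.image_classification_team == 'image-classification-team'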
@@ -5,9 +5,13 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .clip import CLIPTrainer
+    from .team import TEAMImgClsTrainer
 else:
-    _import_structure = {'clip': ['CLIPTrainer']}
+    _import_structure = {
+        'clip': ['CLIPTrainer'],
+        'team': ['TEAMImgClsTrainer']
+    }

     import sys
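In this multi_modal trainers `__init__`, `_import_structure` maps each submodule to its public symbols, so `TEAMImgClsTrainer` is only imported on first access rather than at package import time. A minimal sketch of the general pattern (PEP 562 module-level `__getattr__` dropped into a package `__init__.py`, not ModelScope's exact `LazyImportModule` internals):

import importlib

_import_structure = {'team': ['TEAMImgClsTrainer']}
_symbol_to_module = {
    symbol: module
    for module, symbols in _import_structure.items()
    for symbol in symbols
}


def __getattr__(name):
    # Called only when `name` is not found by normal module lookup (PEP 562).
    if name in _symbol_to_module:
        submodule = importlib.import_module('.' + _symbol_to_module[name],
                                            __package__)
        return getattr(submodule, name)
    raise AttributeError(name)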
@@ -0,0 +1,3 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .team_trainer import TEAMImgClsTrainer
@@ -0,0 +1,144 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from collections import OrderedDict
+from typing import Callable, Dict, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from sklearn.metrics import confusion_matrix
+from torch.optim import AdamW
+from torch.utils.data import DataLoader, Dataset
+
+from modelscope.metainfo import Trainers
+from modelscope.models.base import Model
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers.base import BaseTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.multi_modal.team.team_trainer_utils import (
+    get_optimizer, train_mapping, val_mapping)
+from modelscope.utils.config import Config
+from modelscope.utils.constant import DownloadMode, ModeKeys
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+@TRAINERS.register_module(module_name=Trainers.image_classification_team)
+class TEAMImgClsTrainer(BaseTrainer):
+
+    def __init__(self, cfg_file: str, model: str, device_id: int,
+                 data_collator: Callable, train_dataset: Dataset,
+                 val_dataset: Dataset, *args, **kwargs):
+        super().__init__(cfg_file)
+        self.cfg = Config.from_file(cfg_file)
+
+        team_model = Model.from_pretrained(model)
+        image_model = team_model.model.image_model.vision_transformer
+        classification_model = nn.Sequential(
+            OrderedDict([('encoder', image_model),
+                         ('classifier',
+                          nn.Linear(768, self.cfg.dataset.class_num))]))
+        self.model = classification_model
+        for pname, param in self.model.named_parameters():
+            if 'encoder' in pname:
+                param.requires_grad = False
+
+        self.device_id = device_id
+        self.total_epoch = self.cfg.train.epoch
+        self.train_batch_size = self.cfg.train.batch_size
+        self.val_batch_size = self.cfg.evaluation.batch_size
+        self.ckpt_dir = self.cfg.train.ckpt_dir
+        self.collate_fn = data_collator
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.criterion = nn.CrossEntropyLoss().to(self.device_id)
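Freezing every parameter whose name contains 'encoder' turns this into a linear probe: only the 768-to-class_num head is trained on top of fixed ViT features. A self-contained sanity check with a stand-in encoder (hypothetical module, not the real TEAM tower):

from collections import OrderedDict

import torch.nn as nn

probe = nn.Sequential(
    OrderedDict([('encoder', nn.Linear(768, 768)),
                 ('classifier', nn.Linear(768, 101))]))
for name, param in probe.named_parameters():
    if 'encoder' in name:
        param.requires_grad = False

print([n for n, p in probe.named_parameters() if p.requires_grad])
# ['classifier.weight', 'classifier.bias'] -- only the head trains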
+    def train(self, *args, **kwargs):
+        self.model.train()
+        self.model.to(self.device_id)
+        optimizer = get_optimizer(self.model)
+
+        for epoch in range(self.total_epoch):
+            train_params = {
+                'pin_memory': True,
+                'collate_fn': self.collate_fn,
+                'batch_size': self.train_batch_size,
+                'shuffle': True,
+                'drop_last': True,
+                'num_workers': 8
+            }
+            train_loader = DataLoader(self.train_dataset, **train_params)
+
+            for batch_idx, data in enumerate(train_loader):
+                img_tensor, label_tensor = data['pixel_values'], data['labels']
+                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
+                label_tensor = label_tensor.to(
+                    self.device_id, non_blocking=True)
+
+                pred_logits = self.model(img_tensor)
+                loss = self.criterion(pred_logits, label_tensor)
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                if batch_idx % 10 == 0:
+                    logger.info(
+                        'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
+                            epoch, batch_idx, len(train_loader), loss.item()))
+
+            os.makedirs(self.ckpt_dir, exist_ok=True)
+            torch.save(self.model.state_dict(),
+                       '{}/epoch{}.pth'.format(self.ckpt_dir, epoch))
+            self.evaluate()
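Each epoch ends by writing `epoch{N}.pth` under `ckpt_dir` and re-running evaluation, so `evaluate` can also be called later with a saved checkpoint, e.g. (paths assuming the test config below):

trainer.train()                        # writes ./ckpt/epoch0.pth, ...
trainer.evaluate('./ckpt/epoch0.pth')  # reloads the state dict, then evaluates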
+    def evaluate(self,
+                 checkpoint_path: Optional[str] = None,
+                 *args,
+                 **kwargs) -> Dict[str, float]:
+        if checkpoint_path is not None:
+            checkpoint_params = torch.load(checkpoint_path, 'cpu')
+            self.model.load_state_dict(checkpoint_params)
+        self.model.eval()
+        self.model.to(self.device_id)
+
+        val_params = {
+            'collate_fn': self.collate_fn,
+            'batch_size': self.val_batch_size,
+            'shuffle': False,
+            'drop_last': False,
+            'num_workers': 8
+        }
+        val_loader = DataLoader(self.val_dataset, **val_params)
+
+        tp_cnt, processed_cnt = 0, 0
+        all_pred_labels, all_gt_labels = [], []
+        with torch.no_grad():
+            for batch_idx, data in enumerate(val_loader):
+                img_tensor, label_tensor = data['pixel_values'], data['labels']
+                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
+                label_tensor = label_tensor.to(
+                    self.device_id, non_blocking=True)
+
+                pred_logits = self.model(img_tensor)
+                pred_labels = torch.max(pred_logits, dim=1)[1]
+                tp_cnt += torch.sum(pred_labels == label_tensor).item()
+                processed_cnt += img_tensor.shape[0]
+                logger.info('Accuracy: {:.3f}'.format(tp_cnt / processed_cnt))
+
+                all_pred_labels.extend(pred_labels.tolist())
+                all_gt_labels.extend(label_tensor.tolist())
+
+        conf_mat = confusion_matrix(all_gt_labels, all_pred_labels)
+        acc_mean_per_class = np.mean(conf_mat.diagonal()
+                                     / conf_mat.sum(axis=1))
+        logger.info(
+            'Accuracy mean per class: {:.3f}'.format(acc_mean_per_class))
+        # Return the metrics so the signature's Dict[str, float] is honored.
+        return {
+            'accuracy': tp_cnt / processed_cnt,
+            'mean_per_class_accuracy': float(acc_mean_per_class)
+        }
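`conf_mat.diagonal() / conf_mat.sum(axis=1)` is per-class recall, and its mean weights every class equally, which is the conventional metric for imbalanced datasets such as Caltech101. A small worked example:

import numpy as np
from sklearn.metrics import confusion_matrix

gt = [0, 0, 0, 1]                     # class 0 has 3 samples, class 1 has 1
pred = [0, 0, 1, 1]
cm = confusion_matrix(gt, pred)       # [[2, 1], [0, 1]]
per_class = cm.diagonal() / cm.sum(axis=1)   # [0.667, 1.0]
print(np.mean(per_class))             # 0.833, vs. overall accuracy 0.75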
@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+from torch.optim import AdamW
+
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+train_transforms = transforms.Compose([
+    transforms.RandomResizedCrop(224),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                         (0.26862954, 0.26130258, 0.27577711)),
+])
+
+val_transforms = transforms.Compose([
+    transforms.Resize(256),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                         (0.26862954, 0.26130258, 0.27577711)),
+])
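The mean/std tuples are the CLIP image-normalization constants, presumably matching the statistics the TEAM image tower was pretrained with. A quick shape check on a dummy image:

from PIL import Image

from modelscope.trainers.multi_modal.team.team_trainer_utils import \
    val_transforms

dummy = Image.new('RGB', (640, 480))
print(val_transforms(dummy).shape)  # torch.Size([3, 224, 224])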
+def train_mapping(examples):
+    examples['pixel_values'] = [
+        train_transforms(Image.open(image).convert('RGB'))
+        for image in examples['image:FILE']
+    ]
+    examples['labels'] = [label for label in examples['label:LABEL']]
+    return examples
+
+
+def val_mapping(examples):
+    examples['pixel_values'] = [
+        val_transforms(Image.open(image).convert('RGB'))
+        for image in examples['image:FILE']
+    ]
+    examples['labels'] = [label for label in examples['label:LABEL']]
+    return examples
+def collate_fn(examples):
+    images = []
+    labels = []
+    for example in examples:
+        images.append(example['pixel_values'])
+        labels.append(example['labels'])
+    pixel_values = torch.stack(images)
+    labels = torch.tensor(labels)
+    return {'pixel_values': pixel_values, 'labels': labels}
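`collate_fn` stacks per-example CHW image tensors into a BCHW batch and collects the integer labels into a 1-D tensor, e.g.:

import torch

from modelscope.trainers.multi_modal.team.team_trainer_utils import collate_fn

batch = collate_fn([
    {'pixel_values': torch.zeros(3, 224, 224), 'labels': 5},
    {'pixel_values': torch.ones(3, 224, 224), 'labels': 7},
])
print(batch['pixel_values'].shape)  # torch.Size([2, 3, 224, 224])
print(batch['labels'])              # tensor([5, 7])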
+def get_params_groups(ddp_model, lr):
+    large_lr_params = []
+    small_lr_params = []
+    for name, param in ddp_model.named_parameters():
+        if not param.requires_grad:
+            continue
+        if 'encoder' in name:
+            small_lr_params.append(param)
+        elif 'classifier' in name:
+            large_lr_params.append(param)
+        else:
+            logger.info('skip param: {}'.format(name))
+    params_groups = [{
+        'params': small_lr_params,
+        'lr': lr / 10.0
+    }, {
+        'params': large_lr_params,
+        'lr': lr
+    }]
+    return params_groups
+def get_optimizer(ddp_model):
+    lr_init = 1e-3
+    betas = [0.9, 0.999]
+    weight_decay = 0.02
+    params_groups = get_params_groups(ddp_model, lr=lr_init)
+    return AdamW(
+        params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay)
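`get_params_groups` gives encoder parameters one tenth of the base learning rate, so the same optimizer setup keeps working if the encoder is later unfrozen for full fine-tuning (with the trainer above, the frozen encoder contributes no trainable parameters and the first group is simply empty). A quick demo with a stand-in model (hypothetical, trainable end to end):

from collections import OrderedDict

import torch.nn as nn

from modelscope.trainers.multi_modal.team.team_trainer_utils import \
    get_optimizer

model = nn.Sequential(
    OrderedDict([('encoder', nn.Linear(768, 768)),
                 ('classifier', nn.Linear(768, 101))]))
optimizer = get_optimizer(model)
print([group['lr'] for group in optimizer.param_groups])  # [0.0001, 0.001]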
@@ -0,0 +1,94 @@
+import os
+import unittest
+
+import json
+import requests
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.trainers.multi_modal.team.team_trainer_utils import (
+    collate_fn, train_mapping, val_mapping)
+from modelscope.utils.config import Config
+from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+
+def train_worker(device_id):
+    model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
+    ckpt_dir = './ckpt'
+    os.makedirs(ckpt_dir, exist_ok=True)
+
+    # Use epoch=1 for faster training here
+    cfg = Config({
+        'framework': 'pytorch',
+        'task': 'multi-modal-similarity',
+        'pipeline': {
+            'type': 'multi-modal-similarity'
+        },
+        'model': {
+            'type': 'team-multi-modal-similarity'
+        },
+        'dataset': {
+            'name': 'Caltech101',
+            'class_num': 101
+        },
+        'preprocessor': {},
+        'train': {
+            'epoch': 1,
+            'batch_size': 32,
+            'ckpt_dir': ckpt_dir
+        },
+        'evaluation': {
+            'batch_size': 64
+        }
+    })
+    cfg_file = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION)
+    cfg.dump(cfg_file)
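`cfg.dump` writes the config to `./ckpt/configuration.json`, the same file `TEAMImgClsTrainer.__init__` re-reads via `Config.from_file`, so test and trainer stay in sync:

reloaded = Config.from_file(cfg_file)
assert reloaded.train.epoch == 1
assert reloaded.dataset.class_num == 101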
+    train_dataset = MsDataset.load(
+        cfg.dataset.name,
+        namespace='modelscope',
+        split='train',
+        download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
+    train_dataset = train_dataset.with_transform(train_mapping)
+    val_dataset = MsDataset.load(
+        cfg.dataset.name,
+        namespace='modelscope',
+        split='validation',
+        download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
+    val_dataset = val_dataset.with_transform(val_mapping)
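`with_transform` registers `train_mapping`/`val_mapping` as on-the-fly formatting transforms: images are opened and augmented when a batch is accessed, not preprocessed and cached up front. The behavior in isolation (toy dataset, not Caltech101):

from datasets import Dataset

ds = Dataset.from_dict({'x': [1, 2, 3]})
ds = ds.with_transform(lambda batch: {'y': [v * 10 for v in batch['x']]})
print(ds[0])  # {'y': 10} -- computed at access time, nothing is cached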
+    default_args = dict(
+        cfg_file=cfg_file,
+        model=model_id,
+        device_id=device_id,
+        data_collator=collate_fn,
+        train_dataset=train_dataset,
+        val_dataset=val_dataset)
+    trainer = build_trainer(
+        name=Trainers.image_classification_team, default_args=default_args)
+    trainer.train()
+    trainer.evaluate()
+
+
+class TEAMTransferTrainerTest(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer(self):
+        if torch.cuda.device_count() > 0:
+            train_worker(device_id=0)
+        else:
+            train_worker(device_id=-1)
+        logger.info('Training done')
+
+
+if __name__ == '__main__':
+    unittest.main()