Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10525413master
| @@ -305,6 +305,7 @@ class Trainers(object): | |||||
| face_detection_scrfd = 'face-detection-scrfd' | face_detection_scrfd = 'face-detection-scrfd' | ||||
| card_detection_scrfd = 'card-detection-scrfd' | card_detection_scrfd = 'card-detection-scrfd' | ||||
| image_inpainting = 'image-inpainting' | image_inpainting = 'image-inpainting' | ||||
| image_classification_team = 'image-classification-team' | |||||
| # nlp trainers | # nlp trainers | ||||
| bert_sentiment_analysis = 'bert-sentiment-analysis' | bert_sentiment_analysis = 'bert-sentiment-analysis' | ||||
| @@ -5,9 +5,13 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .clip import CLIPTrainer | from .clip import CLIPTrainer | ||||
| from .team import TEAMImgClsTrainer | |||||
| else: | else: | ||||
| _import_structure = {'clip': ['CLIPTrainer']} | |||||
| _import_structure = { | |||||
| 'clip': ['CLIPTrainer'], | |||||
| 'team': ['TEAMImgClsTrainer'] | |||||
| } | |||||
| import sys | import sys | ||||
| @@ -0,0 +1,3 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from .team_trainer import TEAMImgClsTrainer | |||||
| @@ -0,0 +1,144 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| from collections import OrderedDict | |||||
| from typing import Callable, Dict, Optional | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torchvision.datasets as datasets | |||||
| import torchvision.transforms as transforms | |||||
| from sklearn.metrics import confusion_matrix | |||||
| from torch.optim import AdamW | |||||
| from torch.utils.data import DataLoader, Dataset | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.models.base import Model | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers.base import BaseTrainer | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.trainers.multi_modal.team.team_trainer_utils import ( | |||||
| get_optimizer, train_mapping, val_mapping) | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import DownloadMode, ModeKeys | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @TRAINERS.register_module(module_name=Trainers.image_classification_team) | |||||
| class TEAMImgClsTrainer(BaseTrainer): | |||||
| def __init__(self, cfg_file: str, model: str, device_id: int, | |||||
| data_collator: Callable, train_dataset: Dataset, | |||||
| val_dataset: Dataset, *args, **kwargs): | |||||
| super().__init__(cfg_file) | |||||
| self.cfg = Config.from_file(cfg_file) | |||||
| team_model = Model.from_pretrained(model) | |||||
| image_model = team_model.model.image_model.vision_transformer | |||||
| classification_model = nn.Sequential( | |||||
| OrderedDict([('encoder', image_model), | |||||
| ('classifier', | |||||
| nn.Linear(768, self.cfg.dataset.class_num))])) | |||||
| self.model = classification_model | |||||
| for pname, param in self.model.named_parameters(): | |||||
| if 'encoder' in pname: | |||||
| param.requires_grad = False | |||||
| self.device_id = device_id | |||||
| self.total_epoch = self.cfg.train.epoch | |||||
| self.train_batch_size = self.cfg.train.batch_size | |||||
| self.val_batch_size = self.cfg.evaluation.batch_size | |||||
| self.ckpt_dir = self.cfg.train.ckpt_dir | |||||
| self.collate_fn = data_collator | |||||
| self.train_dataset = train_dataset | |||||
| self.val_dataset = val_dataset | |||||
| self.criterion = nn.CrossEntropyLoss().to(self.device_id) | |||||
| def train(self, *args, **kwargs): | |||||
| self.model.train() | |||||
| self.model.to(self.device_id) | |||||
| optimizer = get_optimizer(self.model) | |||||
| for epoch in range(self.total_epoch): | |||||
| train_params = { | |||||
| 'pin_memory': True, | |||||
| 'collate_fn': self.collate_fn, | |||||
| 'batch_size': self.train_batch_size, | |||||
| 'shuffle': True, | |||||
| 'drop_last': True, | |||||
| 'num_workers': 8 | |||||
| } | |||||
| train_loader = DataLoader(self.train_dataset, **train_params) | |||||
| for batch_idx, data in enumerate(train_loader): | |||||
| img_tensor, label_tensor = data['pixel_values'], data['labels'] | |||||
| img_tensor = img_tensor.to(self.device_id, non_blocking=True) | |||||
| label_tensor = label_tensor.to( | |||||
| self.device_id, non_blocking=True) | |||||
| pred_logits = self.model(img_tensor) | |||||
| loss = self.criterion(pred_logits, label_tensor) | |||||
| optimizer.zero_grad() | |||||
| loss.backward() | |||||
| optimizer.step() | |||||
| if batch_idx % 10 == 0: | |||||
| logger.info( | |||||
| 'epoch: {}, train batch {}/{}, loss={:.5f}'.format( | |||||
| epoch, batch_idx, len(train_loader), loss.item())) | |||||
| os.makedirs(self.ckpt_dir, exist_ok=True) | |||||
| torch.save(self.model.state_dict(), | |||||
| '{}/epoch{}.pth'.format(self.ckpt_dir, epoch)) | |||||
| self.evaluate() | |||||
| def evaluate(self, | |||||
| checkpoint_path: Optional[str] = None, | |||||
| *args, | |||||
| **kwargs) -> Dict[str, float]: | |||||
| if checkpoint_path is not None: | |||||
| checkpoint_params = torch.load(checkpoint_path, 'cpu') | |||||
| self.model.load_state_dict(checkpoint_params) | |||||
| self.model.eval() | |||||
| self.model.to(self.device_id) | |||||
| val_params = { | |||||
| 'collate_fn': self.collate_fn, | |||||
| 'batch_size': self.val_batch_size, | |||||
| 'shuffle': False, | |||||
| 'drop_last': False, | |||||
| 'num_workers': 8 | |||||
| } | |||||
| val_loader = DataLoader(self.val_dataset, **val_params) | |||||
| tp_cnt, processed_cnt = 0, 0 | |||||
| all_pred_labels, all_gt_labels = [], [] | |||||
| with torch.no_grad(): | |||||
| for batch_idx, data in enumerate(val_loader): | |||||
| img_tensor, label_tensor = data['pixel_values'], data['labels'] | |||||
| img_tensor = img_tensor.to(self.device_id, non_blocking=True) | |||||
| label_tensor = label_tensor.to( | |||||
| self.device_id, non_blocking=True) | |||||
| pred_logits = self.model(img_tensor) | |||||
| pred_labels = torch.max(pred_logits, dim=1)[1] | |||||
| tp_cnt += torch.sum(pred_labels == label_tensor).item() | |||||
| processed_cnt += img_tensor.shape[0] | |||||
| logger.info('Accuracy: {:.3f}'.format(tp_cnt / processed_cnt)) | |||||
| all_pred_labels.extend(pred_labels.tolist()) | |||||
| all_gt_labels.extend(label_tensor.tolist()) | |||||
| conf_mat = confusion_matrix(all_gt_labels, all_pred_labels) | |||||
| acc_mean_per_class = np.mean(conf_mat.diagonal() | |||||
| / conf_mat.sum(axis=1)) | |||||
| logger.info( | |||||
| 'Accuracy mean per class: {:.3f}'.format(acc_mean_per_class)) | |||||
| @@ -0,0 +1,87 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import torch | |||||
| import torchvision.transforms as transforms | |||||
| from PIL import Image | |||||
| from torch.optim import AdamW | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| train_transforms = transforms.Compose([ | |||||
| transforms.RandomResizedCrop(224), | |||||
| transforms.RandomHorizontalFlip(), | |||||
| transforms.ToTensor(), | |||||
| transforms.Normalize((0.48145466, 0.4578275, 0.40821073), | |||||
| (0.26862954, 0.26130258, 0.27577711)), | |||||
| ]) | |||||
| val_transforms = transforms.Compose([ | |||||
| transforms.Resize(256), | |||||
| transforms.CenterCrop(224), | |||||
| transforms.ToTensor(), | |||||
| transforms.Normalize((0.48145466, 0.4578275, 0.40821073), | |||||
| (0.26862954, 0.26130258, 0.27577711)), | |||||
| ]) | |||||
| def train_mapping(examples): | |||||
| examples['pixel_values'] = [ | |||||
| train_transforms(Image.open(image).convert('RGB')) | |||||
| for image in examples['image:FILE'] | |||||
| ] | |||||
| examples['labels'] = [label for label in examples['label:LABEL']] | |||||
| return examples | |||||
| def val_mapping(examples): | |||||
| examples['pixel_values'] = [ | |||||
| val_transforms(Image.open(image).convert('RGB')) | |||||
| for image in examples['image:FILE'] | |||||
| ] | |||||
| examples['labels'] = [label for label in examples['label:LABEL']] | |||||
| return examples | |||||
| def collate_fn(examples): | |||||
| images = [] | |||||
| labels = [] | |||||
| for example in examples: | |||||
| images.append((example['pixel_values'])) | |||||
| labels.append(example['labels']) | |||||
| pixel_values = torch.stack(images) | |||||
| labels = torch.tensor(labels) | |||||
| return {'pixel_values': pixel_values, 'labels': labels} | |||||
| def get_params_groups(ddp_model, lr): | |||||
| large_lr_params = [] | |||||
| small_lr_params = [] | |||||
| for name, param in ddp_model.named_parameters(): | |||||
| if not param.requires_grad: | |||||
| continue | |||||
| if 'encoder' in name: | |||||
| small_lr_params.append(param) | |||||
| elif 'classifier' in name: | |||||
| large_lr_params.append(param) | |||||
| else: | |||||
| logger.info('skip param: {}'.format(name)) | |||||
| params_groups = [{ | |||||
| 'params': small_lr_params, | |||||
| 'lr': lr / 10.0 | |||||
| }, { | |||||
| 'params': large_lr_params, | |||||
| 'lr': lr | |||||
| }] | |||||
| return params_groups | |||||
| def get_optimizer(ddp_model): | |||||
| lr_init = 1e-3 | |||||
| betas = [0.9, 0.999] | |||||
| weight_decay = 0.02 | |||||
| params_groups = get_params_groups(ddp_model, lr=lr_init) | |||||
| return AdamW( | |||||
| params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay) | |||||
| @@ -0,0 +1,94 @@ | |||||
| import os | |||||
| import unittest | |||||
| import json | |||||
| import requests | |||||
| import torch | |||||
| import torch.distributed as dist | |||||
| import torch.multiprocessing as mp | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.trainers.multi_modal.team.team_trainer_utils import ( | |||||
| collate_fn, train_mapping, val_mapping) | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.test_utils import test_level | |||||
| logger = get_logger() | |||||
| def train_worker(device_id): | |||||
| model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity' | |||||
| ckpt_dir = './ckpt' | |||||
| os.makedirs(ckpt_dir, exist_ok=True) | |||||
| # Use epoch=1 for faster training here | |||||
| cfg = Config({ | |||||
| 'framework': 'pytorch', | |||||
| 'task': 'multi-modal-similarity', | |||||
| 'pipeline': { | |||||
| 'type': 'multi-modal-similarity' | |||||
| }, | |||||
| 'model': { | |||||
| 'type': 'team-multi-modal-similarity' | |||||
| }, | |||||
| 'dataset': { | |||||
| 'name': 'Caltech101', | |||||
| 'class_num': 101 | |||||
| }, | |||||
| 'preprocessor': {}, | |||||
| 'train': { | |||||
| 'epoch': 1, | |||||
| 'batch_size': 32, | |||||
| 'ckpt_dir': ckpt_dir | |||||
| }, | |||||
| 'evaluation': { | |||||
| 'batch_size': 64 | |||||
| } | |||||
| }) | |||||
| cfg_file = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION) | |||||
| cfg.dump(cfg_file) | |||||
| train_dataset = MsDataset.load( | |||||
| cfg.dataset.name, | |||||
| namespace='modelscope', | |||||
| split='train', | |||||
| download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset() | |||||
| train_dataset = train_dataset.with_transform(train_mapping) | |||||
| val_dataset = MsDataset.load( | |||||
| cfg.dataset.name, | |||||
| namespace='modelscope', | |||||
| split='validation', | |||||
| download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset() | |||||
| val_dataset = val_dataset.with_transform(val_mapping) | |||||
| default_args = dict( | |||||
| cfg_file=cfg_file, | |||||
| model=model_id, | |||||
| device_id=device_id, | |||||
| data_collator=collate_fn, | |||||
| train_dataset=train_dataset, | |||||
| val_dataset=val_dataset) | |||||
| trainer = build_trainer( | |||||
| name=Trainers.image_classification_team, default_args=default_args) | |||||
| trainer.train() | |||||
| trainer.evaluate() | |||||
| class TEAMTransferTrainerTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer(self): | |||||
| if torch.cuda.device_count() > 0: | |||||
| train_worker(device_id=0) | |||||
| else: | |||||
| train_worker(device_id=-1) | |||||
| logger.info('Training done') | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||