eniac.xcw yingda.chen 3 years ago
parent
commit
8886c3c1ae
6 changed files with 334 additions and 1 deletions
  1. +1
    -0
      modelscope/metainfo.py
  2. +5
    -1
      modelscope/trainers/multi_modal/__init__.py
  3. +3
    -0
      modelscope/trainers/multi_modal/team/__init__.py
  4. +144
    -0
      modelscope/trainers/multi_modal/team/team_trainer.py
  5. +87
    -0
      modelscope/trainers/multi_modal/team/team_trainer_utils.py
  6. +94
    -0
      tests/trainers/test_team_transfer_trainer.py

+ 1
- 0
modelscope/metainfo.py View File

@@ -305,6 +305,7 @@ class Trainers(object):
face_detection_scrfd = 'face-detection-scrfd'
card_detection_scrfd = 'card-detection-scrfd'
image_inpainting = 'image-inpainting'
image_classification_team = 'image-classification-team'

# nlp trainers
bert_sentiment_analysis = 'bert-sentiment-analysis'


+ 5
- 1
modelscope/trainers/multi_modal/__init__.py View File

@@ -5,9 +5,13 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .clip import CLIPTrainer
from .team import TEAMImgClsTrainer

else:
_import_structure = {'clip': ['CLIPTrainer']}
_import_structure = {
'clip': ['CLIPTrainer'],
'team': ['TEAMImgClsTrainer']
}

import sys



+ 3
- 0
modelscope/trainers/multi_modal/team/__init__.py View File

@@ -0,0 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .team_trainer import TEAMImgClsTrainer

+ 144
- 0
modelscope/trainers/multi_modal/team/team_trainer.py View File

@@ -0,0 +1,144 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from collections import OrderedDict
from typing import Callable, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.multi_modal.team.team_trainer_utils import (
get_optimizer, train_mapping, val_mapping)
from modelscope.utils.config import Config
from modelscope.utils.constant import DownloadMode, ModeKeys
from modelscope.utils.logger import get_logger

logger = get_logger()


@TRAINERS.register_module(module_name=Trainers.image_classification_team)
class TEAMImgClsTrainer(BaseTrainer):
    """Transfer-learning trainer that fits a linear classification head on
    top of a frozen TEAM image encoder.

    The pretrained vision transformer is used as a fixed feature extractor;
    only the appended ``nn.Linear`` head receives gradient updates.
    """

    def __init__(self, cfg_file: str, model: str, device_id: int,
                 data_collator: Callable, train_dataset: Dataset,
                 val_dataset: Dataset, *args, **kwargs):
        """
        Args:
            cfg_file: Path to the trainer configuration file.
            model: Model id (or local path) of the pretrained TEAM model.
            device_id: Device index for ``.to()``; -1 selects CPU.
            data_collator: Callable merging dataset examples into batches.
            train_dataset: Training split.
            val_dataset: Validation split.
        """
        super().__init__(cfg_file)

        self.cfg = Config.from_file(cfg_file)

        # Reuse the pretrained vision transformer as a frozen feature
        # extractor and attach a trainable linear classification head.
        # NOTE(review): 768 assumes the encoder's output dim — confirm
        # against the pretrained model config.
        team_model = Model.from_pretrained(model)
        image_model = team_model.model.image_model.vision_transformer
        self.model = nn.Sequential(
            OrderedDict([('encoder', image_model),
                         ('classifier',
                          nn.Linear(768, self.cfg.dataset.class_num))]))

        # Freeze every encoder parameter; only the classifier is trained.
        for pname, param in self.model.named_parameters():
            if 'encoder' in pname:
                param.requires_grad = False

        self.device_id = device_id
        self.total_epoch = self.cfg.train.epoch
        self.train_batch_size = self.cfg.train.batch_size
        self.val_batch_size = self.cfg.evaluation.batch_size
        self.ckpt_dir = self.cfg.train.ckpt_dir

        self.collate_fn = data_collator
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        self.criterion = nn.CrossEntropyLoss().to(self.device_id)

    def train(self, *args, **kwargs):
        """Train the classification head, saving a checkpoint and running
        evaluation after every epoch."""
        self.model.train()
        self.model.to(self.device_id)

        optimizer = get_optimizer(self.model)

        # The loader configuration is epoch-invariant, so build the
        # DataLoader once instead of re-creating it every epoch.
        train_params = {
            'pin_memory': True,
            'collate_fn': self.collate_fn,
            'batch_size': self.train_batch_size,
            'shuffle': True,
            'drop_last': True,
            'num_workers': 8
        }
        train_loader = DataLoader(self.train_dataset, **train_params)

        for epoch in range(self.total_epoch):
            for batch_idx, data in enumerate(train_loader):
                img_tensor, label_tensor = data['pixel_values'], data['labels']
                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
                label_tensor = label_tensor.to(
                    self.device_id, non_blocking=True)

                pred_logits = self.model(img_tensor)
                loss = self.criterion(pred_logits, label_tensor)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if batch_idx % 10 == 0:
                    logger.info(
                        'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
                            epoch, batch_idx, len(train_loader), loss.item()))

            os.makedirs(self.ckpt_dir, exist_ok=True)
            torch.save(self.model.state_dict(),
                       '{}/epoch{}.pth'.format(self.ckpt_dir, epoch))
            self.evaluate()

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluate the classifier on the validation split.

        Args:
            checkpoint_path: Optional checkpoint to load before evaluating;
                when None, the current in-memory weights are used.

        Returns:
            Dict with overall 'accuracy' and 'mean_per_class_accuracy'.
        """
        if checkpoint_path is not None:
            checkpoint_params = torch.load(checkpoint_path, 'cpu')
            self.model.load_state_dict(checkpoint_params)
        self.model.eval()
        self.model.to(self.device_id)

        val_params = {
            'collate_fn': self.collate_fn,
            'batch_size': self.val_batch_size,
            'shuffle': False,
            'drop_last': False,
            'num_workers': 8
        }
        val_loader = DataLoader(self.val_dataset, **val_params)

        tp_cnt, processed_cnt = 0, 0
        all_pred_labels, all_gt_labels = [], []
        with torch.no_grad():
            for batch_idx, data in enumerate(val_loader):
                img_tensor, label_tensor = data['pixel_values'], data['labels']
                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
                label_tensor = label_tensor.to(
                    self.device_id, non_blocking=True)

                pred_logits = self.model(img_tensor)
                pred_labels = torch.max(pred_logits, dim=1)[1]
                tp_cnt += torch.sum(pred_labels == label_tensor).item()
                processed_cnt += img_tensor.shape[0]
                logger.info('Accuracy: {:.3f}'.format(tp_cnt / processed_cnt))

                all_pred_labels.extend(pred_labels.tolist())
                all_gt_labels.extend(label_tensor.tolist())

        # Guard against an empty validation set.
        accuracy = tp_cnt / processed_cnt if processed_cnt else 0.0
        conf_mat = confusion_matrix(all_gt_labels, all_pred_labels)
        acc_mean_per_class = np.mean(conf_mat.diagonal()
                                     / conf_mat.sum(axis=1))
        logger.info(
            'Accuracy mean per class: {:.3f}'.format(acc_mean_per_class))
        # Fix: evaluate() previously returned None despite its
        # Dict[str, float] annotation; return the computed metrics.
        return {
            'accuracy': accuracy,
            'mean_per_class_accuracy': float(acc_mean_per_class)
        }

+ 87
- 0
modelscope/trainers/multi_modal/team/team_trainer_utils.py View File

@@ -0,0 +1,87 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.optim import AdamW

from modelscope.utils.logger import get_logger

logger = get_logger()

# Preprocessing pipelines for the image classifier.
# The normalization mean/std presumably match the statistics the
# pretrained TEAM image encoder was trained with (they look like CLIP's
# values) — NOTE(review): confirm against the pretrained preprocessor.
# Training pipeline: random crop/flip augmentation.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])
# Evaluation pipeline: deterministic resize + center crop, no augmentation.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])


def train_mapping(examples):
    """Map a batch of raw examples to training-ready model inputs.

    Args:
        examples: Batched examples with an 'image:FILE' column of image
            paths and a 'label:LABEL' column of class ids.

    Returns:
        The same dict with 'pixel_values' (augmented image tensors) and
        'labels' columns added.
    """
    examples['pixel_values'] = [
        train_transforms(Image.open(image).convert('RGB'))
        for image in examples['image:FILE']
    ]
    # list() is the idiomatic shallow copy (was a no-op comprehension).
    examples['labels'] = list(examples['label:LABEL'])
    return examples


def val_mapping(examples):
    """Map a batch of raw examples to evaluation-ready model inputs.

    Args:
        examples: Batched examples with an 'image:FILE' column of image
            paths and a 'label:LABEL' column of class ids.

    Returns:
        The same dict with 'pixel_values' (deterministically preprocessed
        image tensors) and 'labels' columns added.
    """
    examples['pixel_values'] = [
        val_transforms(Image.open(image).convert('RGB'))
        for image in examples['image:FILE']
    ]
    # list() is the idiomatic shallow copy (was a no-op comprehension).
    examples['labels'] = list(examples['label:LABEL'])
    return examples


def collate_fn(examples):
    """Stack per-example tensors into a single model batch.

    Args:
        examples: List of dicts, each with a 'pixel_values' image tensor
            and an integer 'labels' entry.

    Returns:
        Dict with a stacked 'pixel_values' tensor and a 1-D 'labels' tensor.
    """
    pixel_values = torch.stack([ex['pixel_values'] for ex in examples])
    labels = torch.tensor([ex['labels'] for ex in examples])
    return {'pixel_values': pixel_values, 'labels': labels}


def get_params_groups(ddp_model, lr):
    """Split trainable parameters into two optimizer groups.

    Encoder parameters get one tenth of the base learning rate; classifier
    parameters get the full rate. Frozen parameters are excluded, and any
    parameter matching neither name is logged and skipped.

    Args:
        ddp_model: Model whose named parameters are partitioned.
        lr: Base learning rate for the classifier group.

    Returns:
        List of two param-group dicts suitable for a torch optimizer.
    """
    encoder_params, head_params = [], []
    for name, param in ddp_model.named_parameters():
        if not param.requires_grad:
            continue
        if 'encoder' in name:
            encoder_params.append(param)
        elif 'classifier' in name:
            head_params.append(param)
        else:
            logger.info('skip param: {}'.format(name))

    return [
        {'params': encoder_params, 'lr': lr / 10.0},
        {'params': head_params, 'lr': lr},
    ]


def get_optimizer(ddp_model):
    """Build an AdamW optimizer over the model's two parameter groups.

    Uses a 1e-3 base learning rate (encoder group runs at 1e-4 via
    ``get_params_groups``), betas (0.9, 0.999) and 0.02 weight decay.
    """
    base_lr = 1e-3
    return AdamW(
        get_params_groups(ddp_model, lr=base_lr),
        lr=base_lr,
        betas=[0.9, 0.999],
        weight_decay=0.02)

+ 94
- 0
tests/trainers/test_team_transfer_trainer.py View File

@@ -0,0 +1,94 @@
import os
import unittest

import json
import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.trainers.multi_modal.team.team_trainer_utils import (
collate_fn, train_mapping, val_mapping)
from modelscope.utils.config import Config
from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


def train_worker(device_id):
    """Run one TEAM image-classification fine-tuning job on a single device.

    Args:
        device_id: Device index passed through to the trainer; -1 means CPU.
    """
    model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
    ckpt_dir = './ckpt'
    os.makedirs(ckpt_dir, exist_ok=True)

    # A single epoch keeps this training run short.
    cfg_dict = {
        'framework': 'pytorch',
        'task': 'multi-modal-similarity',
        'pipeline': {
            'type': 'multi-modal-similarity'
        },
        'model': {
            'type': 'team-multi-modal-similarity'
        },
        'dataset': {
            'name': 'Caltech101',
            'class_num': 101
        },
        'preprocessor': {},
        'train': {
            'epoch': 1,
            'batch_size': 32,
            'ckpt_dir': ckpt_dir
        },
        'evaluation': {
            'batch_size': 64
        }
    }
    cfg = Config(cfg_dict)
    cfg_file = os.path.join(ckpt_dir, ModelFile.CONFIGURATION)
    cfg.dump(cfg_file)

    def _load_split(split, mapping):
        # Download the split fresh and attach its preprocessing transform.
        dataset = MsDataset.load(
            cfg.dataset.name,
            namespace='modelscope',
            split=split,
            download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
        return dataset.with_transform(mapping)

    train_dataset = _load_split('train', train_mapping)
    val_dataset = _load_split('validation', val_mapping)

    trainer = build_trainer(
        name=Trainers.image_classification_team,
        default_args=dict(
            cfg_file=cfg_file,
            model=model_id,
            device_id=device_id,
            data_collator=collate_fn,
            train_dataset=train_dataset,
            val_dataset=val_dataset))
    trainer.train()
    trainer.evaluate()


class TEAMTransferTrainerTest(unittest.TestCase):
    """Smoke test for the TEAM transfer-learning trainer."""

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        # Use the first GPU when available, otherwise fall back to CPU (-1).
        gpu_available = torch.cuda.device_count() > 0
        train_worker(device_id=0 if gpu_available else -1)
        logger.info('Training done')


# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()

Loading…
Cancel
Save