diff --git a/data/test/images/coco_cn/train.json b/data/test/images/coco_cn/train.json new file mode 100644 index 00000000..634706bd --- /dev/null +++ b/data/test/images/coco_cn/train.json @@ -0,0 +1,51 @@ +[ + { + "image": "train/COCO_train2014_000000496606.jpg", + "caption": [ + "一只黄色的小狗趴在长椅上" + ] + }, + { + "image": "val/COCO_val2014_000000043734.jpg", + "caption": [ + "两只黑色的狗从水里翻着水花游过来" + ] + }, + { + "image": "train/COCO_train2014_000000404748.jpg", + "caption": [ + "两只长颈鹿站在岩石旁的草地上" + ] + }, + { + "image": "val/COCO_val2014_000000574392.jpg", + "caption": [ + "一个木制的公园长椅在森林里。" + ] + }, + { + "image": "train/COCO_train2014_000000563734.jpg", + "caption": [ + "许多公交车排成队在广场上停着。" + ] + }, + { + "image": "train/COCO_train2014_000000197406.jpg", + "caption": [ + "一个男人和一只长颈鹿站在沙滩上" + ] + }, + { + "image": "val/COCO_val2014_000000473869.jpg", + "caption": [ + "一个微笑的男人在厨房里做饭。" + ] + }, + { + "image": "train/COCO_train2014_000000021183.jpg", + "caption": [ + "一个年龄比较大,坐在街道旁座椅上的男人手里握着一个装着写有标语的板子的手推车", + "一个年老的男人坐在街道上的长椅上,手搭在面前放着告示牌的小推车上" + ] + } +] diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000021183.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000021183.jpg new file mode 100644 index 00000000..6d684e76 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000021183.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eff26436ee5ca4146a5c7218c8a1814a324574e92114736792dcc768ac1e566f +size 134292 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000177625.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000177625.jpg new file mode 100644 index 00000000..57d9c322 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000177625.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07fb36fb94301aa067c1c7f9ca4c8c04d6d7282b4a5494e392c54928d242a56b +size 149178 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000197406.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000197406.jpg new file mode 100644 index 00000000..fbc2aeea --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000197406.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33473d2a21e669196271e28eca437696625e4a5e11eb6efc5b57e7961f15cf0d +size 68914 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000275612.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000275612.jpg new file mode 100644 index 00000000..0b36fd13 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000275612.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c1f6dc406b0e08b43668e73f9700e63420eb4e384a53c539062e89315b64ad6 +size 84248 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000404748.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000404748.jpg new file mode 100644 index 00000000..43633b77 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000404748.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed36ab05878caee478d6532777c862af11a4c62182ba989dfb3bf32e41277c65 +size 239503 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000493952.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000493952.jpg new file mode 100644 index 00000000..f8f4a2e9 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000493952.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:95815c59443288b019e496d0c81cf8e734b347e8a31d996a9f1463eb506f3717 +size 177175 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000496606.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000496606.jpg new file mode 100644 index 00000000..292457aa --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000496606.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7482e789876cbdd18e1e5f0487d2a10f40be1cf4ce696d8e203da80418ec580b +size 195821 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000563734.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000563734.jpg new file mode 100644 index 00000000..e1083b01 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000563734.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df3cee336d965ca249b5e4acd9618d0e2d0e267267222408b6565bb331a5fb23 +size 198775 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000573854.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000573854.jpg new file mode 100644 index 00000000..bfdaeb4d --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000573854.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c03f9d1eb963d6b385e22f3a26d202bea7637d3effd347af53435f5ad9434d72 +size 179422 diff --git a/data/test/images/coco_cn/val.json b/data/test/images/coco_cn/val.json new file mode 100644 index 00000000..cab5adf0 --- /dev/null +++ b/data/test/images/coco_cn/val.json @@ -0,0 +1,52 @@ +[ + { + "image": "train/COCO_train2014_000000573854.jpg", + "caption": [ + "机场跑道的喷气式飞机正准备起飞。" + ] + }, + { + "image": "val/COCO_val2014_000000412975.jpg", + "caption": [ + "一个女孩走下台阶。" + ] + }, + { + "image": "val/COCO_val2014_000000341725.jpg", + "caption": [ + "窗台上蓝色的花瓶里有一束粉色的郁金香。" + ] + }, + { + "image": "val/COCO_val2014_000000163020.jpg", + "caption": [ + "一只海鸥在水面上飞翔。" + ] + }, + { + "image": "train/COCO_train2014_000000177625.jpg", + "caption": [ + "一男一女在聚会上玩电子游戏,男人的脚边趴着一只狗" + ] + }, + { + "image": "train/COCO_train2014_000000275612.jpg", + "caption": [ + "厕所中的一个马桶", + "浴室里,高档的马桶与各式洗浴用品一应俱全。" + ] + }, + { + "image": "train/COCO_train2014_000000493952.jpg", + "caption": [ + "一辆黑色轿车停在一栋大楼前。" + ] + }, + { + "image": "val/COCO_val2014_000000044723.jpg", + "caption": [ + "阴天下一张伦敦塔的照片。", + "一座大楼的顶端悬挂着钟表。" + ] + } +] diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000043734.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000043734.jpg new file mode 100644 index 00000000..b9293cce --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000043734.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c26dc7c54a1202744d50bc2186ea2a49865879a3a3a174099c4e9ecc1199a16a +size 93126 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000044723.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000044723.jpg new file mode 100644 index 00000000..afaf372d --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000044723.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24cd7ad56cf00d57a7b2d182957a8ad6b44d5eb55dfe3bc69ad5a292151d482e +size 122140 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000163020.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000163020.jpg new file mode 100644 index 00000000..de16ebc5 --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000163020.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:ee9c78c8c141d1bb3cd064f2a003f0786a19c0b2cc54e0cfa2ee2459daf7bebe +size 63796 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000341725.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000341725.jpg new file mode 100644 index 00000000..85f5d815 --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000341725.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2c08174e59610f65797f50e9eea968eec0ba092c5aca69574e70a6e98862da7 +size 92038 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000412975.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000412975.jpg new file mode 100644 index 00000000..5ba16dbd --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000412975.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a324c2f213442f8ab2fcc5f16f59d2d31ec08993b27b13a623b3a32dd4c408ac +size 182587 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000473869.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000473869.jpg new file mode 100644 index 00000000..e6ac2baa --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000473869.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2f1193dc4c0cd50e0a233810fde1875f7211e936c35d9a3754bf71c2c8da84e +size 109371 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000574392.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000574392.jpg new file mode 100644 index 00000000..b62feea9 --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000574392.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74f3626546a174ca28da8ce35eeea6d62d230da5ff74fd73d37211557c35d83e +size 377231 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index ad914618..31f37e76 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -111,6 +111,9 @@ class Trainers(object): default = 'Trainer' + # multi-modal tasks + clip_multi_modal_embedding = 'clip-multi-modal-embedding' + class Preprocessors(object): """ Names for different preprocessor. 
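
Note: the train.json and val.json files added above share a simple schema — each record pairs an "image" path, relative to the dataset root directory, with a "caption" list holding one or more Chinese captions. Below is a minimal sketch of reading one record under the test-data layout above; the ImageWithCaptionDataset introduced later in this diff consumes the same fields and samples one caption at random during training.

import json
import os
import random

# Load the COCO-CN style annotations added above; each entry carries an
# "image" path relative to the dataset root and a "caption" list.
root_dir = 'data/test/images/coco_cn'
with open(os.path.join(root_dir, 'train.json')) as f:
    records = json.load(f)

record = records[0]
img_path = os.path.join(root_dir, record['image'])
caption = random.choice(record['caption'])  # training picks one caption at random
print(img_path, caption)
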
diff --git a/modelscope/models/multi_modal/clip/clip_bert.py b/modelscope/models/multi_modal/clip/clip_bert.py index 50ddba99..24ccc1fa 100644 --- a/modelscope/models/multi_modal/clip/clip_bert.py +++ b/modelscope/models/multi_modal/clip/clip_bert.py @@ -4,9 +4,12 @@ from transformers import BertConfig, BertForMaskedLM class TextTransformer(nn.Module): - def __init__(self, config_dict, feat_dim=768): + def __init__(self, config_dict, feat_dim=768, use_grad_ckp=True): super(TextTransformer, self).__init__() bert_config = BertConfig.from_dict(config_dict) + if use_grad_ckp: + bert_config.gradient_checkpointing = True + self.bert = BertForMaskedLM(bert_config).bert self.projector = nn.Linear( diff --git a/modelscope/models/multi_modal/clip/clip_model.py b/modelscope/models/multi_modal/clip/clip_model.py index 8dd36acf..eafb3902 100644 --- a/modelscope/models/multi_modal/clip/clip_model.py +++ b/modelscope/models/multi_modal/clip/clip_model.py @@ -8,6 +8,8 @@ import torch.nn as nn import torch.nn.functional as F from PIL import Image from tokenizers import BertWordPieceTokenizer +from torch.distributed.nn.functional import \ + all_gather as all_gather_with_backprop from torchvision.transforms import Compose, Normalize, Resize, ToTensor from modelscope.metainfo import Models @@ -15,7 +17,7 @@ from modelscope.models.base import Model from modelscope.models.builder import MODELS from modelscope.models.multi_modal.clip.clip_bert import TextTransformer from modelscope.models.multi_modal.clip.clip_vit import VisionTransformer -from modelscope.utils.constant import Tasks +from modelscope.utils.constant import ModeKeys, Tasks from modelscope.utils.logger import get_logger logger = get_logger() @@ -40,13 +42,62 @@ class CLIPModel(nn.Module): width=vision_config['width'], layers=vision_config['layers'], heads=vision_config['heads'], - output_dim=vision_config['feat_dim']) + output_dim=vision_config['feat_dim'], + use_grad_ckp=True) # text encoder text_config = model_config['text_config'] self.text_encoder = TextTransformer( text_config['bert_config'], feat_dim=text_config['feat_dim']) + self.logit_scale = nn.Parameter(torch.ones([]) * 4.6) + + def contrastive_loss(self, logits, dim): + neg_ce = torch.diag(F.log_softmax(logits, dim=dim)) + return -neg_ce.mean() + + def clip_loss(self, t2i_sim, i2t_sim, img_idx=None, all_img_idx=None): + if img_idx is not None and all_img_idx is not None: + with torch.no_grad(): + false_neg_indicator = ( + img_idx[:, None] == all_img_idx[None, :]) + false_neg_indicator.fill_diagonal_(False) + t2i_sim.masked_fill_(false_neg_indicator, float('-inf')) + i2t_sim.masked_fill_(false_neg_indicator, float('-inf')) + caption_loss = self.contrastive_loss(t2i_sim, dim=1) + image_loss = self.contrastive_loss(i2t_sim, dim=1) + else: + caption_loss = self.contrastive_loss(t2i_sim, dim=1) + image_loss = self.contrastive_loss(i2t_sim, dim=1) + return (caption_loss + image_loss) / 2.0 + + def get_loss(self, img_tensor, text_ids_tensor, text_masks_tensor, + img_id_list): + img_feat = self.forward(img_tensor, input_type='img') + text_feat = self.forward((text_ids_tensor, text_masks_tensor), + input_type='text') + + global_img_feat = torch.cat(all_gather_with_backprop(img_feat), dim=0) + global_text_feat = torch.cat( + all_gather_with_backprop(text_feat), dim=0) + global_img_id_list = torch.cat( + all_gather_with_backprop(img_id_list), dim=0) + + t2i_sim_mat = text_feat @ global_img_feat.t() + i2t_sim_mat = img_feat @ global_text_feat.t() + + logit_scale = 
self.logit_scale.exp().clamp(max=100.0) + t2i_sim_mat_logits = t2i_sim_mat * logit_scale + i2t_sim_mat_logits = i2t_sim_mat * logit_scale + + loss = self.clip_loss( + t2i_sim_mat_logits, + i2t_sim_mat_logits, + img_idx=img_id_list, + all_img_idx=global_img_id_list) + + return loss + def forward(self, input_data, input_type): if input_type == 'img': img_embedding = self.vision_encoder(input_data) @@ -58,6 +109,8 @@ class CLIPModel(nn.Module): text_mask_tensor) text_embedding = F.normalize(text_embedding, p=2.0, dim=1) return text_embedding + elif input_type == ModeKeys.TRAIN: + return self.get_loss(*input_data) else: raise ValueError('Unknown input type') diff --git a/modelscope/models/multi_modal/clip/clip_vit.py b/modelscope/models/multi_modal/clip/clip_vit.py index 95bb1adc..cfe67426 100644 --- a/modelscope/models/multi_modal/clip/clip_vit.py +++ b/modelscope/models/multi_modal/clip/clip_vit.py @@ -6,6 +6,7 @@ from typing import Tuple, Union import numpy as np import torch import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint from torch import nn @@ -60,7 +61,8 @@ class Transformer(nn.Module): width: int, layers: int, heads: int, - attn_mask: torch.Tensor = None): + attn_mask: torch.Tensor = None, + use_grad_ckp: bool = True): super().__init__() self.width = width self.layers = layers @@ -69,14 +71,21 @@ class Transformer(nn.Module): for _ in range(layers) ]) + self.use_grad_ckp = use_grad_ckp + def forward(self, x: torch.Tensor): - return self.resblocks(x) + if self.use_grad_ckp: + for each_block in self.resblocks: + x = checkpoint.checkpoint(each_block, x) + return x + else: + return self.resblocks(x) class VisionTransformer(nn.Module): def __init__(self, input_resolution: int, patch_size: int, width: int, - layers: int, heads: int, output_dim: int): + layers: int, heads: int, output_dim: int, use_grad_ckp: bool): super().__init__() self.input_resolution = input_resolution self.output_dim = output_dim @@ -93,7 +102,8 @@ class VisionTransformer(nn.Module): (input_resolution // patch_size)**2 + 1, width)) self.ln_pre = LayerNorm(width) - self.transformer = Transformer(width, layers, heads) + self.transformer = Transformer( + width, layers, heads, use_grad_ckp=use_grad_ckp) self.ln_post = LayerNorm(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) diff --git a/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py index 43046e8c..d15970d2 100644 --- a/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py +++ b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py @@ -13,7 +13,7 @@ logger = get_logger() Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding) class MultiModalEmbeddingPipeline(Pipeline): - def __init__(self, model: str, device_id: int = -1): + def __init__(self, model: str, device: str = 'gpu'): """ use `model` and `preprocessor` to create a kws pipeline for prediction Args: diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py index e5dde881..f32a33c6 100644 --- a/modelscope/trainers/__init__.py +++ b/modelscope/trainers/__init__.py @@ -1,5 +1,6 @@ from .base import DummyTrainer from .builder import build_trainer from .cv import ImageInstanceSegmentationTrainer +from .multi_modal import CLIPTrainer from .nlp import SequenceClassificationTrainer from .trainer import EpochBasedTrainer diff --git a/modelscope/trainers/multi_modal/__init__.py 
b/modelscope/trainers/multi_modal/__init__.py new file mode 100644 index 00000000..7d386349 --- /dev/null +++ b/modelscope/trainers/multi_modal/__init__.py @@ -0,0 +1 @@ +from .clip import CLIPTrainer diff --git a/modelscope/trainers/multi_modal/clip/__init__.py b/modelscope/trainers/multi_modal/clip/__init__.py new file mode 100644 index 00000000..87f1040c --- /dev/null +++ b/modelscope/trainers/multi_modal/clip/__init__.py @@ -0,0 +1 @@ +from .clip_trainer import CLIPTrainer diff --git a/modelscope/trainers/multi_modal/clip/clip_trainer.py b/modelscope/trainers/multi_modal/clip/clip_trainer.py new file mode 100644 index 00000000..cccf4296 --- /dev/null +++ b/modelscope/trainers/multi_modal/clip/clip_trainer.py @@ -0,0 +1,167 @@ +import os +from typing import Dict, Optional + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModeKeys +from modelscope.utils.logger import get_logger +from .clip_trainer_utils import ImageWithCaptionDataset, get_optimizer + +logger = get_logger() + + +@TRAINERS.register_module(module_name=Trainers.clip_multi_modal_embedding) +class CLIPTrainer(BaseTrainer): + + def __init__(self, cfg_file: str, model: str, device_id: int, *args, + **kwargs): + super().__init__(cfg_file) + + self.cfg = Config.from_file(cfg_file) + self.model = Model.from_pretrained(model) + self.device_id = device_id + self.total_epoch = self.cfg.train.epoch + self.train_batch_size = self.cfg.train.batch_size + self.val_batch_size = self.cfg.evaluation.batch_size + self.ckpt_dir = self.cfg.train.ckpt_dir + + self.train_dataset = ImageWithCaptionDataset( + json_file='{}/{}'.format(self.cfg.dataset.root_dir, + self.cfg.dataset.train_set), + img_dir=self.cfg.dataset.root_dir, + phase=ModeKeys.TRAIN) + self.val_dataset = ImageWithCaptionDataset( + json_file='{}/{}'.format(self.cfg.dataset.root_dir, + self.cfg.dataset.val_set), + img_dir=self.cfg.dataset.root_dir, + phase=ModeKeys.EVAL) + + def train(self, *args, **kwargs): + assert dist.is_initialized() + + self.model.clip_model.train() + self.model.clip_model.to(self.device_id) + ddp_model = torch.nn.parallel.DistributedDataParallel( + self.model.clip_model, device_ids=[ + self.device_id, + ]) + + optimizer = get_optimizer(ddp_model) + + for epoch in range(self.total_epoch): + train_sampler = DistributedSampler( + dataset=self.train_dataset, shuffle=True) + train_sampler.set_epoch(epoch) + + train_params = { + 'pin_memory': True, + 'collate_fn': None, + 'batch_size': self.train_batch_size, + 'shuffle': False, + 'drop_last': True, + 'sampler': train_sampler, + 'num_workers': 8 + } + + train_loader = DataLoader(self.train_dataset, **train_params) + + for batch_idx, (img_tensor, text_str_list, + img_id_list) in enumerate(train_loader): + text_info_list = [ + self.model.tokenize_text(tmp) for tmp in text_str_list + ] + text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list], + dim=0) + text_masks_tensor = torch.cat( + [tmp[1] for tmp in text_info_list], dim=0) + + img_tensor = img_tensor.to(self.device_id, non_blocking=True) + img_id_list = img_id_list.to(self.device_id, non_blocking=True) + text_ids_tensor = text_ids_tensor.to( + self.device_id, non_blocking=True) + 
text_masks_tensor = text_masks_tensor.to( + self.device_id, non_blocking=True) + + loss = ddp_model((img_tensor, text_ids_tensor, + text_masks_tensor, img_id_list), + ModeKeys.TRAIN) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if batch_idx % 10 == 0: + logger.info( + 'epoch: {}, train batch {}/{}, loss={:.5f}, logit_scale={:.5f}' + .format(epoch, batch_idx, len(train_loader), + loss.item(), + ddp_model.module.logit_scale.exp().item())) + if dist.get_rank() == 0: + os.makedirs(self.ckpt_dir, exist_ok=True) + torch.save(ddp_model.module.state_dict(), + '{}/epoch{}.pth'.format(self.ckpt_dir, epoch)) + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + if checkpoint_path is not None: + checkpoint_params = torch.load(checkpoint_path, 'cpu') + self.model.clip_model.load_state_dict(checkpoint_params) + self.model.clip_model.eval() + self.model.clip_model.to(self.device_id) + + val_params = { + 'collate_fn': None, + 'batch_size': self.val_batch_size, + 'shuffle': False, + 'drop_last': False, + 'num_workers': 8 + } + val_loader = DataLoader(self.val_dataset, **val_params) + + tp_cnt_per_batch = [] + processed_cnt = 0 + with torch.no_grad(): + for batch_idx, (img_tensor, text_str_list, + img_id_list) in enumerate(val_loader): + text_info_list = [ + self.model.tokenize_text(tmp) for tmp in text_str_list + ] + text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list], + dim=0) + text_masks_tensor = torch.cat( + [tmp[1] for tmp in text_info_list], dim=0) + + img_tensor = img_tensor.to(self.device_id, non_blocking=True) + img_id_list = img_id_list.to(self.device_id, non_blocking=True) + text_ids_tensor = text_ids_tensor.to( + self.device_id, non_blocking=True) + text_masks_tensor = text_masks_tensor.to( + self.device_id, non_blocking=True) + + img_feat = self.model.clip_model(img_tensor, input_type='img') + text_feat = self.model.clip_model( + (text_ids_tensor, text_masks_tensor), input_type='text') + + sim_mat = text_feat @ img_feat.t() + text_cnt, img_cnt = sim_mat.shape + top1_scores, match_ids = torch.max(sim_mat, dim=1) + + match_ids = match_ids.int() + gt_ids = torch.tensor(range(0, text_cnt)).to( + self.device_id, non_blocking=True).int() + error_cnt = torch.nonzero(match_ids - gt_ids) + processed_cnt += text_cnt + + tp_cnt_per_batch.append(text_cnt - 1.0 * error_cnt.numel()) + logger.info('current acc: {:.3f}'.format( + sum(tp_cnt_per_batch) / processed_cnt)) diff --git a/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py b/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py new file mode 100644 index 00000000..1391a4fd --- /dev/null +++ b/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py @@ -0,0 +1,92 @@ +import os +import random + +import json +import torch +import torch.nn.functional as F +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +from modelscope.utils.constant import ModeKeys + +train_transform = transforms.Compose([ + transforms.RandomResizedCrop( + 224, scale=(0.5, 1.0), interpolation=Image.BICUBIC), + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], + p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)) +]) + +val_transform = transforms.Compose([ + transforms.Resize((224, 224), interpolation=Image.BICUBIC), + transforms.ToTensor(), + 
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)) +]) + + +class ImageWithCaptionDataset(Dataset): + + def __init__(self, json_file, img_dir, phase): + self.annotations = json.load(open(json_file)) + self.img_dir = img_dir + if phase == ModeKeys.TRAIN: + self.transform = train_transform + elif phase == ModeKeys.EVAL: + self.transform = val_transform + + self.img_name2img_id = {} + for anno_dict in self.annotations: + img_name = anno_dict['image'] + if img_name not in self.img_name2img_id: + self.img_name2img_id[img_name] = len(self.img_name2img_id) + + def __len__(self): + return len(self.annotations) + + def __getitem__(self, index): + anno_dict = self.annotations[index] + + img_path = os.path.join(self.img_dir, anno_dict['image']) + img_pil = Image.open(img_path).convert('RGB') + img_th = self.transform(img_pil) + img_id = self.img_name2img_id[anno_dict['image']] + + text_str = random.choice(anno_dict['caption']) + + return img_th, text_str, img_id + + +def get_params_groups(ddp_model, weight_decay): + decay = [] + no_decay = [] + for name, param in ddp_model.named_parameters(): + if not param.requires_grad: + continue + if len(param.shape) == 1 or name.endswith('.bias'): + no_decay.append(param) + else: + decay.append(param) + params_groups = [{ + 'params': no_decay, + 'weight_decay': 0. + }, { + 'params': decay, + 'weight_decay': weight_decay + }] + return params_groups + + +def get_optimizer(ddp_model): + from torch.optim import AdamW + lr_init = 1e-5 + betas = [0.9, 0.999] + weight_decay = 0.02 + params_groups = get_params_groups(ddp_model, weight_decay=weight_decay) + return AdamW( + params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay) diff --git a/tests/trainers/test_clip_multi_modal_embedding_trainer.py b/tests/trainers/test_clip_multi_modal_embedding_trainer.py new file mode 100644 index 00000000..c1b51ec6 --- /dev/null +++ b/tests/trainers/test_clip_multi_modal_embedding_trainer.py @@ -0,0 +1,60 @@ +import os +import tempfile +import unittest + +import requests +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +def clip_train_worker(local_rank, ngpus, node_size, node_rank): + global_rank = local_rank + node_rank * ngpus + dist_world_size = node_size * ngpus + + dist.init_process_group( + backend='nccl', world_size=dist_world_size, rank=global_rank) + + model_id = 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding' + local_model_dir = snapshot_download(model_id) + + default_args = dict( + cfg_file='{}/{}'.format(local_model_dir, ModelFile.CONFIGURATION), + model=model_id, + device_id=local_rank) + trainer = build_trainer( + name=Trainers.clip_multi_modal_embedding, default_args=default_args) + + trainer.train() + trainer.evaluate() + + +class CLIPMultiModalEmbeddingTrainerTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer(self): + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '2001' + NODE_SIZE, NODE_RANK = 1, 0 + logger.info('Train clip with {} machines'.format(NODE_SIZE)) + ngpus = torch.cuda.device_count() + logger.info('Machine: {} has {} 
GPUs'.format(NODE_RANK, ngpus)) + mp.spawn( + clip_train_worker, + nprocs=ngpus, + args=(ngpus, NODE_SIZE, NODE_RANK)) + logger.info('Training done') + + +if __name__ == '__main__': + unittest.main() + ...
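
Note: the CLIPTrainer added in this diff pulls its hyper-parameters from the model's configuration.json via Config.from_file; that file ships with the ModelScope model repository and is not part of this change. The dict below is only a sketch of the keys the trainer actually reads (train.*, evaluation.batch_size, dataset.*), with placeholder values; the dataset entries point at the COCO-CN test data added at the top of this diff.

# Sketch of the configuration keys accessed in clip_trainer.py. Values are
# placeholders; the real configuration.json is distributed with the model.
clip_trainer_cfg = {
    'train': {
        'epoch': 1,                 # self.cfg.train.epoch
        'batch_size': 8,            # self.cfg.train.batch_size
        'ckpt_dir': './clip_ckpt',  # epoch{n}.pth checkpoints are written here
    },
    'evaluation': {
        'batch_size': 8,            # self.cfg.evaluation.batch_size
    },
    'dataset': {
        'root_dir': 'data/test/images/coco_cn',  # images and annotation files
        'train_set': 'train.json',  # loaded with ModeKeys.TRAIN transforms
        'val_set': 'val.json',      # loaded with ModeKeys.EVAL transforms
    },
}
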