Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9497065 (master)
@@ -0,0 +1,51 @@
[
    {
        "image": "train/COCO_train2014_000000496606.jpg",
        "caption": [
            "一只黄色的小狗趴在长椅上"
        ]
    },
    {
        "image": "val/COCO_val2014_000000043734.jpg",
        "caption": [
            "两只黑色的狗从水里翻着水花游过来"
        ]
    },
    {
        "image": "train/COCO_train2014_000000404748.jpg",
        "caption": [
            "两只长颈鹿站在岩石旁的草地上"
        ]
    },
    {
        "image": "val/COCO_val2014_000000574392.jpg",
        "caption": [
            "一个木制的公园长椅在森林里。"
        ]
    },
    {
        "image": "train/COCO_train2014_000000563734.jpg",
        "caption": [
            "许多公交车排成队在广场上停着。"
        ]
    },
    {
        "image": "train/COCO_train2014_000000197406.jpg",
        "caption": [
            "一个男人和一只长颈鹿站在沙滩上"
        ]
    },
    {
        "image": "val/COCO_val2014_000000473869.jpg",
        "caption": [
            "一个微笑的男人在厨房里做饭。"
        ]
    },
    {
        "image": "train/COCO_train2014_000000021183.jpg",
        "caption": [
            "一个年龄比较大,坐在街道旁座椅上的男人手里握着一个装着写有标语的板子的手推车",
            "一个年老的男人坐在街道上的长椅上,手搭在面前放着告示牌的小推车上"
        ]
    }
]
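For reference, a minimal sketch of how one record of the annotation file above is consumed; the root directory and file name are hypothetical, and the field access mirrors the `ImageWithCaptionDataset` added later in this review (`image` is a path relative to the dataset root, and one entry of `caption` is sampled per item):

```python
import json
import os
import random

# Minimal sketch, assuming the schema shown above; directory/file names are illustrative.
root_dir = '/path/to/dataset'                           # hypothetical dataset root
with open(os.path.join(root_dir, 'train.json')) as f:   # hypothetical file name
    annotations = json.load(f)

record = annotations[0]
img_path = os.path.join(root_dir, record['image'])  # e.g. train/COCO_train2014_000000496606.jpg
caption = random.choice(record['caption'])          # one caption is sampled per access
print(img_path, caption)
```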
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eff26436ee5ca4146a5c7218c8a1814a324574e92114736792dcc768ac1e566f
size 134292

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:07fb36fb94301aa067c1c7f9ca4c8c04d6d7282b4a5494e392c54928d242a56b
size 149178

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33473d2a21e669196271e28eca437696625e4a5e11eb6efc5b57e7961f15cf0d
size 68914

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1c1f6dc406b0e08b43668e73f9700e63420eb4e384a53c539062e89315b64ad6
size 84248

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ed36ab05878caee478d6532777c862af11a4c62182ba989dfb3bf32e41277c65
size 239503

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95815c59443288b019e496d0c81cf8e734b347e8a31d996a9f1463eb506f3717
size 177175

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7482e789876cbdd18e1e5f0487d2a10f40be1cf4ce696d8e203da80418ec580b
size 195821

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df3cee336d965ca249b5e4acd9618d0e2d0e267267222408b6565bb331a5fb23
size 198775

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c03f9d1eb963d6b385e22f3a26d202bea7637d3effd347af53435f5ad9434d72
size 179422
@@ -0,0 +1,52 @@
[
    {
        "image": "train/COCO_train2014_000000573854.jpg",
        "caption": [
            "机场跑道的喷气式飞机正准备起飞。"
        ]
    },
    {
        "image": "val/COCO_val2014_000000412975.jpg",
        "caption": [
            "一个女孩走下台阶。"
        ]
    },
    {
        "image": "val/COCO_val2014_000000341725.jpg",
        "caption": [
            "窗台上蓝色的花瓶里有一束粉色的郁金香。"
        ]
    },
    {
        "image": "val/COCO_val2014_000000163020.jpg",
        "caption": [
            "一只海鸥在水面上飞翔。"
        ]
    },
    {
        "image": "train/COCO_train2014_000000177625.jpg",
        "caption": [
            "一男一女在聚会上玩电子游戏,男人的脚边趴着一只狗"
        ]
    },
    {
        "image": "train/COCO_train2014_000000275612.jpg",
        "caption": [
            "厕所中的一个马桶",
            "浴室里,高档的马桶与各式洗浴用品一应俱全。"
        ]
    },
    {
        "image": "train/COCO_train2014_000000493952.jpg",
        "caption": [
            "一辆黑色轿车停在一栋大楼前。"
        ]
    },
    {
        "image": "val/COCO_val2014_000000044723.jpg",
        "caption": [
            "阴天下一张伦敦塔的照片。",
            "一座大楼的顶端悬挂着钟表。"
        ]
    }
]
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c26dc7c54a1202744d50bc2186ea2a49865879a3a3a174099c4e9ecc1199a16a
size 93126

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:24cd7ad56cf00d57a7b2d182957a8ad6b44d5eb55dfe3bc69ad5a292151d482e
size 122140

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ee9c78c8c141d1bb3cd064f2a003f0786a19c0b2cc54e0cfa2ee2459daf7bebe
size 63796

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e2c08174e59610f65797f50e9eea968eec0ba092c5aca69574e70a6e98862da7
size 92038

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a324c2f213442f8ab2fcc5f16f59d2d31ec08993b27b13a623b3a32dd4c408ac
size 182587

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d2f1193dc4c0cd50e0a233810fde1875f7211e936c35d9a3754bf71c2c8da84e
size 109371

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74f3626546a174ca28da8ce35eeea6d62d230da5ff74fd73d37211557c35d83e
size 377231
@@ -111,6 +111,9 @@ class Trainers(object):
     default = 'Trainer'
 
+    # multi-modal tasks
+    clip_multi_modal_embedding = 'clip-multi-modal-embedding'
+
 
 class Preprocessors(object):
     """ Names for different preprocessor.
@@ -4,9 +4,12 @@ from transformers import BertConfig, BertForMaskedLM
 class TextTransformer(nn.Module):
 
-    def __init__(self, config_dict, feat_dim=768):
+    def __init__(self, config_dict, feat_dim=768, use_grad_ckp=True):
         super(TextTransformer, self).__init__()
         bert_config = BertConfig.from_dict(config_dict)
+        if use_grad_ckp:
+            bert_config.gradient_checkpointing = True
+
         self.bert = BertForMaskedLM(bert_config).bert
         self.projector = nn.Linear(
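For context, `bert_config.gradient_checkpointing = True` is the config-level switch that asks the Hugging Face BERT encoder to recompute activations during the backward pass. A small standalone sketch, with arbitrary toy config sizes (newer transformers releases also expose `model.gradient_checkpointing_enable()` for the same effect):

```python
from transformers import BertConfig, BertForMaskedLM

# Sketch only: tiny, arbitrary config sizes for illustration.
config = BertConfig(
    hidden_size=128, num_hidden_layers=2, num_attention_heads=2,
    intermediate_size=256)
config.gradient_checkpointing = True  # same switch the diff flips when use_grad_ckp is set

bert = BertForMaskedLM(config).bert   # the diff likewise keeps only the .bert encoder
```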
@@ -8,6 +8,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
 from tokenizers import BertWordPieceTokenizer
+from torch.distributed.nn.functional import \
+    all_gather as all_gather_with_backprop
 from torchvision.transforms import Compose, Normalize, Resize, ToTensor
 
 from modelscope.metainfo import Models
@@ -15,7 +17,7 @@ from modelscope.models.base import Model
 from modelscope.models.builder import MODELS
 from modelscope.models.multi_modal.clip.clip_bert import TextTransformer
 from modelscope.models.multi_modal.clip.clip_vit import VisionTransformer
-from modelscope.utils.constant import Tasks
+from modelscope.utils.constant import ModeKeys, Tasks
 from modelscope.utils.logger import get_logger
 
 logger = get_logger()
@@ -40,13 +42,62 @@ class CLIPModel(nn.Module):
             width=vision_config['width'],
             layers=vision_config['layers'],
             heads=vision_config['heads'],
-            output_dim=vision_config['feat_dim'])
+            output_dim=vision_config['feat_dim'],
+            use_grad_ckp=True)
 
         # text encoder
         text_config = model_config['text_config']
         self.text_encoder = TextTransformer(
             text_config['bert_config'], feat_dim=text_config['feat_dim'])
 
+        self.logit_scale = nn.Parameter(torch.ones([]) * 4.6)
+
+    def contrastive_loss(self, logits, dim):
+        neg_ce = torch.diag(F.log_softmax(logits, dim=dim))
+        return -neg_ce.mean()
+
+    def clip_loss(self, t2i_sim, i2t_sim, img_idx=None, all_img_idx=None):
+        if img_idx is not None and all_img_idx is not None:
+            with torch.no_grad():
+                false_neg_indicator = (
+                    img_idx[:, None] == all_img_idx[None, :])
+                false_neg_indicator.fill_diagonal_(False)
+            t2i_sim.masked_fill_(false_neg_indicator, float('-inf'))
+            i2t_sim.masked_fill_(false_neg_indicator, float('-inf'))
+            caption_loss = self.contrastive_loss(t2i_sim, dim=1)
+            image_loss = self.contrastive_loss(i2t_sim, dim=1)
+        else:
+            caption_loss = self.contrastive_loss(t2i_sim, dim=1)
+            image_loss = self.contrastive_loss(i2t_sim, dim=1)
+        return (caption_loss + image_loss) / 2.0
+
+    def get_loss(self, img_tensor, text_ids_tensor, text_masks_tensor,
+                 img_id_list):
+        img_feat = self.forward(img_tensor, input_type='img')
+        text_feat = self.forward((text_ids_tensor, text_masks_tensor),
+                                 input_type='text')
+
+        global_img_feat = torch.cat(all_gather_with_backprop(img_feat), dim=0)
+        global_text_feat = torch.cat(
+            all_gather_with_backprop(text_feat), dim=0)
+        global_img_id_list = torch.cat(
+            all_gather_with_backprop(img_id_list), dim=0)
+
+        t2i_sim_mat = text_feat @ global_img_feat.t()
+        i2t_sim_mat = img_feat @ global_text_feat.t()
+
+        logit_scale = self.logit_scale.exp().clamp(max=100.0)
+        t2i_sim_mat_logits = t2i_sim_mat * logit_scale
+        i2t_sim_mat_logits = i2t_sim_mat * logit_scale
+
+        loss = self.clip_loss(
+            t2i_sim_mat_logits,
+            i2t_sim_mat_logits,
+            img_idx=img_id_list,
+            all_img_idx=global_img_id_list)
+
+        return loss
+
     def forward(self, input_data, input_type):
         if input_type == 'img':
             img_embedding = self.vision_encoder(input_data)
@@ -58,6 +109,8 @@ class CLIPModel(nn.Module):
                                                 text_mask_tensor)
             text_embedding = F.normalize(text_embedding, p=2.0, dim=1)
             return text_embedding
+        elif input_type == ModeKeys.TRAIN:
+            return self.get_loss(*input_data)
         else:
             raise ValueError('Unknown input type')
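The objective added above is the standard symmetric contrastive (InfoNCE) loss over the text-to-image and image-to-text similarity matrices, with one twist: since a caption is sampled per image, two samples in the global batch can refer to the same image, and such pairs are masked out as false negatives before the softmax. Below is a toy, single-process sketch of that mechanic (random features, illustrative names); in the diff itself the column dimension spans features gathered from all ranks via `all_gather_with_backprop`, so every rank sees the full global batch of negatives:

```python
import torch
import torch.nn.functional as F


def contrastive_loss(logits, dim):
    # mean of -log p(positive) read off the diagonal, as in the diff
    return -torch.diag(F.log_softmax(logits, dim=dim)).mean()


torch.manual_seed(0)
text_feat = F.normalize(torch.randn(4, 8), dim=1)  # stand-ins for CLIP text features
img_feat = F.normalize(torch.randn(4, 8), dim=1)   # stand-ins for CLIP image features
logit_scale = 100.0                                 # logit_scale.exp() is clamped to 100

t2i = text_feat @ img_feat.t() * logit_scale
i2t = img_feat @ text_feat.t() * logit_scale

# samples 2 and 3 share the same image id, so they are false negatives for each other
img_idx = torch.tensor([0, 1, 2, 2])
false_neg = img_idx[:, None] == img_idx[None, :]
false_neg.fill_diagonal_(False)                    # keep the true positive on the diagonal

t2i = t2i.masked_fill(false_neg, float('-inf'))
i2t = i2t.masked_fill(false_neg, float('-inf'))

loss = (contrastive_loss(t2i, dim=1) + contrastive_loss(i2t, dim=1)) / 2.0
print(loss.item())
```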
@@ -6,6 +6,7 @@ from typing import Tuple, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
 from torch import nn
@@ -60,7 +61,8 @@ class Transformer(nn.Module):
                  width: int,
                  layers: int,
                  heads: int,
-                 attn_mask: torch.Tensor = None):
+                 attn_mask: torch.Tensor = None,
+                 use_grad_ckp: bool = True):
         super().__init__()
         self.width = width
         self.layers = layers
@@ -69,14 +71,21 @@ class Transformer(nn.Module):
             for _ in range(layers)
         ])
 
+        self.use_grad_ckp = use_grad_ckp
+
     def forward(self, x: torch.Tensor):
-        return self.resblocks(x)
+        if self.use_grad_ckp:
+            for each_block in self.resblocks:
+                x = checkpoint.checkpoint(each_block, x)
+            return x
+        else:
+            return self.resblocks(x)
 
 
 class VisionTransformer(nn.Module):
 
     def __init__(self, input_resolution: int, patch_size: int, width: int,
-                 layers: int, heads: int, output_dim: int):
+                 layers: int, heads: int, output_dim: int, use_grad_ckp: bool):
         super().__init__()
         self.input_resolution = input_resolution
         self.output_dim = output_dim
@@ -93,7 +102,8 @@ class VisionTransformer(nn.Module):
                            (input_resolution // patch_size)**2 + 1, width))
         self.ln_pre = LayerNorm(width)
 
-        self.transformer = Transformer(width, layers, heads)
+        self.transformer = Transformer(
+            width, layers, heads, use_grad_ckp=use_grad_ckp)
 
         self.ln_post = LayerNorm(width)
         self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
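The vision-tower change is the usual activation-checkpointing pattern: each residual block is wrapped in `torch.utils.checkpoint.checkpoint`, so its activations are recomputed during backward instead of being stored, trading one extra forward pass per block for memory. A self-contained sketch with a stand-in block (the diff wraps each `ResidualAttentionBlock` the same way; sizes here are arbitrary):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

# Stand-in blocks; shapes are arbitrary and only illustrate the wrapping pattern.
blocks = nn.ModuleList(
    [nn.Sequential(nn.Linear(16, 16), nn.GELU()) for _ in range(4)])

x = torch.randn(2, 16, requires_grad=True)
for blk in blocks:
    x = checkpoint.checkpoint(blk, x)  # blk's activations are recomputed in backward
x.sum().backward()                     # gradients are identical, peak memory is lower
```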
@@ -13,7 +13,7 @@ logger = get_logger()
     Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding)
 class MultiModalEmbeddingPipeline(Pipeline):
 
-    def __init__(self, model: str, device_id: int = -1):
+    def __init__(self, model: str, device: str = 'gpu'):
         """
         use `model` and `preprocessor` to create a kws pipeline for prediction
         Args:
@@ -1,5 +1,6 @@
 from .base import DummyTrainer
 from .builder import build_trainer
 from .cv import ImageInstanceSegmentationTrainer
+from .multi_modal import CLIPTrainer
 from .nlp import SequenceClassificationTrainer
 from .trainer import EpochBasedTrainer

@@ -0,0 +1 @@
from .clip import CLIPTrainer

@@ -0,0 +1 @@
from .clip_trainer import CLIPTrainer
@@ -0,0 +1,167 @@
import os
from typing import Dict, Optional

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModeKeys
from modelscope.utils.logger import get_logger
from .clip_trainer_utils import ImageWithCaptionDataset, get_optimizer

logger = get_logger()


@TRAINERS.register_module(module_name=Trainers.clip_multi_modal_embedding)
class CLIPTrainer(BaseTrainer):

    def __init__(self, cfg_file: str, model: str, device_id: int, *args,
                 **kwargs):
        super().__init__(cfg_file)

        self.cfg = Config.from_file(cfg_file)
        self.model = Model.from_pretrained(model)
        self.device_id = device_id
        self.total_epoch = self.cfg.train.epoch
        self.train_batch_size = self.cfg.train.batch_size
        self.val_batch_size = self.cfg.evaluation.batch_size
        self.ckpt_dir = self.cfg.train.ckpt_dir

        self.train_dataset = ImageWithCaptionDataset(
            json_file='{}/{}'.format(self.cfg.dataset.root_dir,
                                     self.cfg.dataset.train_set),
            img_dir=self.cfg.dataset.root_dir,
            phase=ModeKeys.TRAIN)
        self.val_dataset = ImageWithCaptionDataset(
            json_file='{}/{}'.format(self.cfg.dataset.root_dir,
                                     self.cfg.dataset.val_set),
            img_dir=self.cfg.dataset.root_dir,
            phase=ModeKeys.EVAL)

    def train(self, *args, **kwargs):
        assert dist.is_initialized()

        self.model.clip_model.train()
        self.model.clip_model.to(self.device_id)
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            self.model.clip_model, device_ids=[
                self.device_id,
            ])

        optimizer = get_optimizer(ddp_model)

        for epoch in range(self.total_epoch):
            train_sampler = DistributedSampler(
                dataset=self.train_dataset, shuffle=True)
            train_sampler.set_epoch(epoch)
            train_params = {
                'pin_memory': True,
                'collate_fn': None,
                'batch_size': self.train_batch_size,
                'shuffle': False,
                'drop_last': True,
                'sampler': train_sampler,
                'num_workers': 8
            }
            train_loader = DataLoader(self.train_dataset, **train_params)

            for batch_idx, (img_tensor, text_str_list,
                            img_id_list) in enumerate(train_loader):
                text_info_list = [
                    self.model.tokenize_text(tmp) for tmp in text_str_list
                ]
                text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list],
                                            dim=0)
                text_masks_tensor = torch.cat(
                    [tmp[1] for tmp in text_info_list], dim=0)

                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
                img_id_list = img_id_list.to(self.device_id, non_blocking=True)
                text_ids_tensor = text_ids_tensor.to(
                    self.device_id, non_blocking=True)
                text_masks_tensor = text_masks_tensor.to(
                    self.device_id, non_blocking=True)

                loss = ddp_model((img_tensor, text_ids_tensor,
                                  text_masks_tensor, img_id_list),
                                 ModeKeys.TRAIN)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if batch_idx % 10 == 0:
                    logger.info(
                        'epoch: {}, train batch {}/{}, loss={:.5f}, logit_scale={:.5f}'
                        .format(epoch, batch_idx, len(train_loader),
                                loss.item(),
                                ddp_model.module.logit_scale.exp().item()))

            if dist.get_rank() == 0:
                os.makedirs(self.ckpt_dir, exist_ok=True)
                torch.save(ddp_model.module.state_dict(),
                           '{}/epoch{}.pth'.format(self.ckpt_dir, epoch))

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        if checkpoint_path is not None:
            checkpoint_params = torch.load(checkpoint_path, 'cpu')
            self.model.clip_model.load_state_dict(checkpoint_params)

        self.model.clip_model.eval()
        self.model.clip_model.to(self.device_id)

        val_params = {
            'collate_fn': None,
            'batch_size': self.val_batch_size,
            'shuffle': False,
            'drop_last': False,
            'num_workers': 8
        }
        val_loader = DataLoader(self.val_dataset, **val_params)

        tp_cnt_per_batch = []
        processed_cnt = 0
        with torch.no_grad():
            for batch_idx, (img_tensor, text_str_list,
                            img_id_list) in enumerate(val_loader):
                text_info_list = [
                    self.model.tokenize_text(tmp) for tmp in text_str_list
                ]
                text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list],
                                            dim=0)
                text_masks_tensor = torch.cat(
                    [tmp[1] for tmp in text_info_list], dim=0)

                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
                img_id_list = img_id_list.to(self.device_id, non_blocking=True)
                text_ids_tensor = text_ids_tensor.to(
                    self.device_id, non_blocking=True)
                text_masks_tensor = text_masks_tensor.to(
                    self.device_id, non_blocking=True)

                img_feat = self.model.clip_model(img_tensor, input_type='img')
                text_feat = self.model.clip_model(
                    (text_ids_tensor, text_masks_tensor), input_type='text')

                sim_mat = text_feat @ img_feat.t()
                text_cnt, img_cnt = sim_mat.shape
                top1_scores, match_ids = torch.max(sim_mat, dim=1)
                match_ids = match_ids.int()
                gt_ids = torch.tensor(range(0, text_cnt)).to(
                    self.device_id, non_blocking=True).int()
                error_cnt = torch.nonzero(match_ids - gt_ids)
                processed_cnt += text_cnt
                tp_cnt_per_batch.append(text_cnt - 1.0 * error_cnt.numel())

                logger.info('current acc: {:.3f}'.format(
                    sum(tp_cnt_per_batch) / processed_cnt))
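The `evaluate` loop above reports a running text-to-image top-1 retrieval accuracy: the i-th caption in a batch is counted correct if its best-matching image (by dot product of the normalized features) is the i-th image. A toy sketch of that per-batch computation with random stand-in features; note the number is batch-local, since images from other batches are not candidates, so it is a fast proxy rather than full-corpus recall@1:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
text_feat = F.normalize(torch.randn(8, 16), dim=1)  # stand-ins for CLIP embeddings
img_feat = F.normalize(torch.randn(8, 16), dim=1)

sim_mat = text_feat @ img_feat.t()        # row i: caption i vs. every image in the batch
match_ids = sim_mat.argmax(dim=1)
gt_ids = torch.arange(sim_mat.shape[0])   # caption i is paired with image i
acc = (match_ids == gt_ids).float().mean().item()
print('batch top-1 acc: {:.3f}'.format(acc))
```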
@@ -0,0 +1,92 @@
import os
import random

import json
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from modelscope.utils.constant import ModeKeys

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(
        224, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
                           p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711))
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711))
])


class ImageWithCaptionDataset(Dataset):

    def __init__(self, json_file, img_dir, phase):
        self.annotations = json.load(open(json_file))
        self.img_dir = img_dir
        if phase == ModeKeys.TRAIN:
            self.transform = train_transform
        elif phase == ModeKeys.EVAL:
            self.transform = val_transform

        self.img_name2img_id = {}
        for anno_dict in self.annotations:
            img_name = anno_dict['image']
            if img_name not in self.img_name2img_id:
                self.img_name2img_id[img_name] = len(self.img_name2img_id)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        anno_dict = self.annotations[index]

        img_path = os.path.join(self.img_dir, anno_dict['image'])
        img_pil = Image.open(img_path).convert('RGB')
        img_th = self.transform(img_pil)

        img_id = self.img_name2img_id[anno_dict['image']]
        text_str = random.choice(anno_dict['caption'])

        return img_th, text_str, img_id


def get_params_groups(ddp_model, weight_decay):
    decay = []
    no_decay = []
    for name, param in ddp_model.named_parameters():
        if not param.requires_grad:
            continue
        if len(param.shape) == 1 or name.endswith('.bias'):
            no_decay.append(param)
        else:
            decay.append(param)
    params_groups = [{
        'params': no_decay,
        'weight_decay': 0.
    }, {
        'params': decay,
        'weight_decay': weight_decay
    }]
    return params_groups


def get_optimizer(ddp_model):
    from torch.optim import AdamW
    lr_init = 1e-5
    betas = [0.9, 0.999]
    weight_decay = 0.02
    params_groups = get_params_groups(ddp_model, weight_decay=weight_decay)
    return AdamW(
        params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay)
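A quick standalone check of the parameter grouping used by `get_optimizer` above: biases and 1-D tensors (e.g. LayerNorm weights) go to the no-decay group, while everything else receives the configured weight decay. The model here is a stand-in:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))  # stand-in model
decay, no_decay = [], []
for name, param in model.named_parameters():
    if len(param.shape) == 1 or name.endswith('.bias'):
        no_decay.append(name)   # LayerNorm weight/bias and the Linear bias
    else:
        decay.append(name)      # the Linear weight matrix

print('decay:', decay)          # ['0.weight']
print('no_decay:', no_decay)    # ['0.bias', '1.weight', '1.bias']
```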
@@ -0,0 +1,60 @@
import os
import tempfile
import unittest

import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


def clip_train_worker(local_rank, ngpus, node_size, node_rank):
    global_rank = local_rank + node_rank * ngpus
    dist_world_size = node_size * ngpus
    dist.init_process_group(
        backend='nccl', world_size=dist_world_size, rank=global_rank)

    model_id = 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'
    local_model_dir = snapshot_download(model_id)

    default_args = dict(
        cfg_file='{}/{}'.format(local_model_dir, ModelFile.CONFIGURATION),
        model=model_id,
        device_id=local_rank)
    trainer = build_trainer(
        name=Trainers.clip_multi_modal_embedding, default_args=default_args)

    trainer.train()
    trainer.evaluate()


class CLIPMultiModalEmbeddingTrainerTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer(self):
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '2001'

        NODE_SIZE, NODE_RANK = 1, 0
        logger.info('Train clip with {} machines'.format(NODE_SIZE))

        ngpus = torch.cuda.device_count()
        logger.info('Machine: {} has {} GPUs'.format(NODE_RANK, ngpus))

        mp.spawn(
            clip_train_worker,
            nprocs=ngpus,
            args=(ngpus, NODE_SIZE, NODE_RANK))

        logger.info('Training done')


if __name__ == '__main__':
    unittest.main()