Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10572842 (target branch: master)
@@ -389,6 +389,7 @@ class Preprocessors(object):
     # multi-modal preprocessor
     ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'
+    clip_preprocessor = 'clip-preprocessor'
     mplug_tasks_preprocessor = 'mplug-tasks-preprocessor'

     # science preprocessor
@@ -428,6 +429,8 @@ class Metrics(object):
     image_inpainting_metric = 'image-inpainting-metric'
     # metric for ocr
     NED = 'ned'
+    # metric for cross-modal retrieval
+    inbatch_recall = 'inbatch_recall'
     # metric for referring-video-object-segmentation task
     referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric'
@@ -474,6 +477,9 @@ class Hooks(object):
     # Compression
     SparsityHook = 'SparsityHook'

+    # CLIP logit_scale clamp
+    ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'
+

 class LR_Schedulers(object):
     """learning rate scheduler is defined here
@@ -24,6 +24,7 @@ class MetricKeys(object):
     ROUGE_1 = 'rouge-1'
     ROUGE_L = 'rouge-l'
     NED = 'ned'  # ocr metric
+    BatchAcc = 'inbatch_t2i_recall_at_1'

 task_default_metrics = {
@@ -0,0 +1,55 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Dict
+
+import numpy as np
+import torch
+
+from modelscope.metainfo import Metrics
+from modelscope.outputs import OutputKeys
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(
+    group_key=default_group, module_name=Metrics.inbatch_recall)
+class InbatchRecallMetric(Metric):
+    """The metric computation class for in-batch retrieval classes.
+
+    This metric class calculates in-batch image recall@1 for each input batch.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.inbatch_t2i_hitcnts = []
+        self.batch_sizes = []
+
+    def add(self, outputs: Dict, inputs: Dict):
+        image_features = outputs[OutputKeys.IMG_EMBEDDING]
+        text_features = outputs[OutputKeys.TEXT_EMBEDDING]
+
+        assert type(image_features) == torch.Tensor and type(
+            text_features) == torch.Tensor
+
+        with torch.no_grad():
+            logits_per_image = image_features @ text_features.t()
+            logits_per_text = logits_per_image.t()
+
+            batch_size = logits_per_image.shape[0]
+            ground_truth = torch.arange(batch_size).long()
+            ground_truth = ground_truth.to(image_features.device)
+
+            inbatch_t2i_hitcnt = (logits_per_text.argmax(-1) == ground_truth
+                                  ).sum().float().item()
+
+            self.inbatch_t2i_hitcnts.append(inbatch_t2i_hitcnt)
+            self.batch_sizes.append(batch_size)
+
+    def evaluate(self):
+        assert len(self.inbatch_t2i_hitcnts) == len(
+            self.batch_sizes) and len(self.batch_sizes) > 0
+        return {
+            MetricKeys.BatchAcc:
+            sum(self.inbatch_t2i_hitcnts) / sum(self.batch_sizes)
+        }
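A minimal usage sketch of the new metric, assuming the base `Metric` class can be constructed without arguments and that random embeddings stand in for real model outputs; the import path of the new file is an assumption.

```python
import torch
import torch.nn.functional as F

from modelscope.metrics.inbatch_recall_metric import InbatchRecallMetric  # path assumed
from modelscope.outputs import OutputKeys

metric = InbatchRecallMetric()
for _ in range(2):
    # dummy L2-normalized image/text embeddings for a batch of 8
    img = F.normalize(torch.randn(8, 512), dim=-1)
    txt = F.normalize(torch.randn(8, 512), dim=-1)
    metric.add({OutputKeys.IMG_EMBEDDING: img,
                OutputKeys.TEXT_EMBEDDING: txt}, inputs={})
print(metric.evaluate())  # {'inbatch_t2i_recall_at_1': <total hits / total samples>}
```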
@@ -15,15 +15,13 @@
 import os
 from collections import OrderedDict
-from typing import Any, Dict, Iterable, List, Tuple, Union
+from typing import Any, Dict, Tuple, Union

 import json
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from PIL import Image
-from torchvision.transforms import Compose, Normalize, Resize, ToTensor

 from modelscope.metainfo import Models
 from modelscope.models import TorchModel
@@ -506,21 +504,6 @@ def convert_weights(model: nn.Module):
     model.apply(_convert_weights_to_fp16)


-def _convert_to_rgb(image):
-    return image.convert('RGB')
-
-
-def image_transform(image_size=224):
-    transform = Compose([
-        _convert_to_rgb,
-        Resize((image_size, image_size)),
-        ToTensor(),
-        Normalize((0.48145466, 0.4578275, 0.40821073),
-                  (0.26862954, 0.26130258, 0.27577711)),
-    ])
-    return transform
-
-
 @MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip)
 class CLIPForMultiModalEmbedding(TorchModel):
@@ -540,72 +523,40 @@ class CLIPForMultiModalEmbedding(TorchModel):
         with open(vision_model_config_file,
                   'r') as fv, open(text_model_config_file, 'r') as ft:
-            model_info = json.load(fv)
+            self.model_info = json.load(fv)
             for k, v in json.load(ft).items():
-                model_info[k] = v
-
-        # image preprocess
-        self.img_preprocess = image_transform(model_info['image_resolution'])
+                self.model_info[k] = v

+        # text tokenizer
         vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}'
         self.tokenizer = FullTokenizer(vocab_file=vocab_file)

         # initialize the model
-        self.clip_model = CLIP(**model_info, tokenizer=self.tokenizer)
+        self.clip_model = CLIP(**self.model_info, tokenizer=self.tokenizer)
         convert_weights(self.clip_model)

         # restore the pretrained weight
         checkpoint = torch.load(
             f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}', 'cpu')
-        sd = checkpoint['state_dict']
+        sd = checkpoint[
+            'state_dict'] if 'state_dict' in checkpoint else checkpoint
         if next(iter(sd.items()))[0].startswith('module'):
             sd = {k[len('module.'):]: v for k, v in sd.items()}
+        # support the finetuned model
+        if next(iter(sd.items()))[0].startswith('clip_model'):
+            sd = {k[len('clip_model.'):]: v for k, v in sd.items()}
         self.clip_model.load_state_dict(sd)
         self.clip_model.eval()

         # place the model
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        if self.device == 'cuda':
+        self.device = 'cuda:{}'.format(int(os.environ.get(
+            'LOCAL_RANK', 0))) if torch.cuda.is_available() else 'cpu'
+        if torch.cuda.is_available():
             self.clip_model.to(self.device)
-            logger.info('Use GPU for inference')
+            logger.info('Use GPU {} for finetuning & inference'.format(
+                int(os.environ.get('LOCAL_RANK', 0))))
         else:
             self.clip_model.float()
-            logger.info('Use CPU for inference')
-
-    def tokenize(self,
-                 texts: Union[str, List[str]],
-                 context_length: int = 52) -> torch.LongTensor:
-        """
-        Returns the tokenized representation of given input string(s)
-        Parameters
-        ----------
-        texts : Union[str, List[str]]
-            An input string or a list of input strings to tokenize
-        context_length : int
-            The context length to use; all baseline models use 24 as the context length
-        Returns
-        -------
-        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
-        """
-        if isinstance(texts, str):
-            texts = [texts]
-
-        all_tokens = []
-        for text in texts:
-            all_tokens.append(
-                [self.tokenizer.vocab['[CLS]']]
-                + self.tokenizer.convert_tokens_to_ids(
-                    self.tokenizer.tokenize(text))[:context_length - 2]
-                + [self.tokenizer.vocab['[SEP]']])
-
-        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
-
-        for i, tokens in enumerate(all_tokens):
-            assert len(tokens) <= context_length
-            result[i, :len(tokens)] = torch.tensor(tokens)
-
-        return result
+            logger.info('Use CPU for finetuning & inference')

     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
         from modelscope.outputs import OutputKeys
@@ -613,75 +564,36 @@ class CLIPForMultiModalEmbedding(TorchModel):
             OutputKeys.IMG_EMBEDDING: None,
             OutputKeys.TEXT_EMBEDDING: None
         }
-        if 'img' in input and input['img'] is not None:
-            image_input = input['img']
-
-            # single image input
-            if isinstance(image_input, Image.Image):
-                image_tensor = self.img_preprocess(image_input).unsqueeze(0)
-            # multi images input
-            elif isinstance(image_input, list):
-                if all([isinstance(elem, Image.Image)
-                        for elem in image_input]):
-                    image_tensor = torch.stack(
-                        [self.img_preprocess(elem) for elem in image_input],
-                        dim=0)
-                else:
-                    unsupported_elem_type = [
-                        type(elem) for elem in image_input
-                        if not isinstance(elem, Image.Image)
-                    ][0]
-                    raise TypeError(
-                        f'img should be PIL.Image or List[PIL.Image], \
-                        but got a List containing one {unsupported_elem_type}'
-                    )
-            # others
-            else:
-                raise TypeError(
-                    f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}'
-                )
-
-            image_tensor = image_tensor.to(self.device)
-
-            with torch.no_grad():
+        mode = input.get('mode', ModeKeys.INFERENCE)
+
+        # encode the image
+        if 'img' in input and isinstance(input['img'], torch.Tensor):
+            image_tensor = input['img'].to(self.device)
+            if image_tensor.dim() == 5 and image_tensor.shape[1] == 1:
+                image_tensor = image_tensor.squeeze(1)
+            with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
                 image_features = self.clip_model.encode_image(image_tensor)
                 image_features /= image_features.norm(
                     dim=-1, keepdim=True)  # l2-normalize
                 output[OutputKeys.IMG_EMBEDDING] = image_features

-        if 'text' in input and input['text'] is not None:
-            text_input = input['text']
-
-            # single text input
-            if isinstance(text_input, str):
-                text_tensor = self.tokenize(text_input)
-            # multi texts input
-            elif isinstance(text_input, list):
-                if all([isinstance(elem, str) for elem in text_input]):
-                    text_tensor = self.tokenize(text_input)
-                else:
-                    unsupported_elem_type = [
-                        type(elem) for elem in text_input
-                        if not isinstance(elem, str)
-                    ][0]
-                    raise TypeError(
-                        f'text should be str or List[str], but got a List containing one {unsupported_elem_type}'
-                    )
-            # others
-            else:
-                raise TypeError(
-                    f'text should be str or List[str], but got {type(text_input)}'
-                )
-
-            text_tensor = text_tensor.to(self.device)
-
-            with torch.no_grad():
+        if 'text' in input and isinstance(input['text'], torch.Tensor):
+            text_tensor = input['text'].to(self.device)
+            if text_tensor.dim() == 3 and text_tensor.shape[1] == 1:
+                text_tensor = text_tensor.squeeze(1)
+            with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
                 text_features = self.clip_model.encode_text(text_tensor)
                 text_features /= text_features.norm(
                     dim=-1, keepdim=True)  # l2-normalize
                 output[OutputKeys.TEXT_EMBEDDING] = text_features

+        if mode == ModeKeys.TRAIN:
+            output['logit_scale'] = (self.clip_model.logit_scale
+                                     * 1.0).exp().mean()
+
         return output

     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
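With PIL handling and tokenization moved out of the model, `forward` now expects tensors that a `CLIPPreprocessor` (added further down in this review) has already produced. A rough sketch of the new calling contract; the 224 resolution, 52-token context length and dummy token ids are assumptions, and the model id is taken from the trainer test below.

```python
import torch

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModeKeys

clip_model = Model.from_pretrained('damo/multi-modal_clip-vit-base-patch16_zh')
batch = {
    'img': torch.randn(4, 3, 224, 224),      # already-preprocessed image batch
    'text': torch.randint(0, 100, (4, 52)),  # dummy [CLS] ... [SEP] token ids
    'mode': ModeKeys.TRAIN,                  # enables grad and the 'logit_scale' output
}
out = clip_model.forward(batch)
print(out[OutputKeys.IMG_EMBEDDING].shape, out['logit_scale'])
```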
@@ -1,10 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Any, Dict
+from typing import Any, Dict, Optional, Union

 from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal.clip.model import CLIPForMultiModalEmbedding
 from modelscope.pipelines.base import Input, Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors.multi_modal import CLIPPreprocessor, Preprocessor
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
@@ -17,7 +19,10 @@ logger = get_logger()
     Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding)
 class MultiModalEmbeddingPipeline(Pipeline):

-    def __init__(self, model: str, device: str = 'gpu'):
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
         """
         use `model` and `preprocessor` to create a kws pipeline for prediction

         Args:
@@ -29,14 +34,17 @@ class MultiModalEmbeddingPipeline(Pipeline):
             pipe_model = model
         else:
             raise NotImplementedError('model must be a single str')
-        super().__init__(model=pipe_model)
-
-    def preprocess(self, input: Input) -> Dict[str, Any]:
-        return input
+        pipe_model.eval()
+        if preprocessor is None:
+            if isinstance(pipe_model, CLIPForMultiModalEmbedding):
+                preprocessor = CLIPPreprocessor(pipe_model.model_dir)
+            else:
+                raise NotImplementedError
+        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
-        return self.model(input)
+        return self.model(self.preprocess(input))

     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         return inputs
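Since the pipeline now builds a `CLIPPreprocessor` itself when none is passed, inference can be sketched roughly as below; the model id is reused from the trainer test later in this review, and `.forward` mirrors the updated unit tests.

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.multi_modal_embedding,
                model='damo/multi-modal_clip-vit-base-patch16_zh')
# text-only query; an 'img' entry with a PIL.Image would yield IMG_EMBEDDING as well
emb = pipe.forward({'text': '一张风景照片'})[OutputKeys.TEXT_EMBEDDING]
print(emb.shape)  # (1, embedding_dim)
```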
@@ -3,8 +3,11 @@ import os.path as osp
 from io import BytesIO
 from typing import Any, Dict, List, Tuple, Union

+import json
 import torch
 from PIL import Image
+from timm.data import create_transform
+from torchvision.transforms import Compose, Normalize, Resize, ToTensor

 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Preprocessors
@@ -107,6 +110,180 @@ class OfaPreprocessor(Preprocessor):
             eos_idx=self.tokenizer.eos_token_id)


+def _convert_to_rgb(image):
+    return image.convert('RGB')
+
+
+@PREPROCESSORS.register_module(
+    Fields.multi_modal, module_name=Preprocessors.clip_preprocessor)
+class CLIPPreprocessor(Preprocessor):
+
+    def __init__(self,
+                 model_dir: str,
+                 mode=ModeKeys.INFERENCE,
+                 *args,
+                 **kwargs):
+        """preprocess the data
+
+        Args:
+            model_dir (str): model path
+            mode: preprocessor mode (model mode)
+        """
+        super().__init__(*args, **kwargs)
+        model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
+            model_dir)
+        self.mode = mode
+        # text tokenizer
+        from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer
+        if 'tokenizer' in kwargs and isinstance(kwargs['tokenizer'],
+                                                FullTokenizer):
+            self.tokenizer = kwargs['tokenizer']
+        else:
+            vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}'
+            self.tokenizer = FullTokenizer(vocab_file=vocab_file)
+        # image preprocessor
+        if 'resolution' in kwargs and isinstance(kwargs['resolution'], int):
+            self.image_resolution = kwargs['resolution']
+        else:
+            self.image_resolution = json.load(
+                open('{}/vision_model_config.json'.format(
+                    model_dir)))['image_resolution']
+        self.img_preprocess = self._build_image_transform()
+        # key mapping
+        # specify the input keys, compatible with training and inference whose key names may be different
+        self.input_keys = {'img': 'img', 'text': 'text'}
+
+    def _build_image_transform(self):
+        if self.mode == ModeKeys.TRAIN:
+            transform = create_transform(
+                input_size=self.image_resolution,
+                scale=(0.9, 1.0),
+                is_training=True,
+                color_jitter=None,
+                auto_augment='original',
+                interpolation='bicubic',
+                mean=(0.48145466, 0.4578275, 0.40821073),
+                std=(0.26862954, 0.26130258, 0.27577711),
+            )
+            transform = Compose(transform.transforms[:-3] + [_convert_to_rgb]
+                                + transform.transforms[-3:])
+        else:
+            transform = Compose([
+                Resize((self.image_resolution, self.image_resolution),
+                       interpolation=Image.BICUBIC),
+                _convert_to_rgb,
+                ToTensor(),
+                Normalize((0.48145466, 0.4578275, 0.40821073),
+                          (0.26862954, 0.26130258, 0.27577711)),
+            ])
+        return transform
+
+    def tokenize(self,
+                 texts: Union[str, List[str]],
+                 context_length: int = 52) -> torch.LongTensor:
+        """
+        Returns the tokenized representation of given input string(s)
+        Parameters
+        ----------
+        texts : Union[str, List[str]]
+            An input string or a list of input strings to tokenize
+        context_length : int
+            The context length to use; all baseline models use 24 as the context length
+        Returns
+        -------
+        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+        """
+        if isinstance(texts, str):
+            texts = [texts]
+
+        all_tokens = []
+        for text in texts:
+            all_tokens.append(
+                [self.tokenizer.vocab['[CLS]']]
+                + self.tokenizer.convert_tokens_to_ids(
+                    self.tokenizer.tokenize(text))[:context_length - 2]
+                + [self.tokenizer.vocab['[SEP]']])
+
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+        for i, tokens in enumerate(all_tokens):
+            assert len(tokens) <= context_length
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
+
+    def set_input_img_key(self, new_key: str):
+        self.input_keys['img'] = new_key
+
+    def set_input_text_key(self, new_key: str):
+        self.input_keys['text'] = new_key
+
+    def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
+                 **kwargs) -> Dict[str, Any]:
+        output = {}
+        # preprocess the image input
+        input_img_key = self.input_keys['img']
+        if input_img_key in input and input[input_img_key] is not None:
+            image_input = input[input_img_key]
+            # single image input
+            if isinstance(image_input, Image.Image):
+                image_tensor = self.img_preprocess(image_input).unsqueeze(0)
+            # multi images input
+            elif isinstance(image_input, list):
+                if all([isinstance(elem, Image.Image)
+                        for elem in image_input]):
+                    image_tensor = torch.stack(
+                        [self.img_preprocess(elem)
+                         for elem in image_input],  # noqa
+                        dim=0)  # noqa
+                else:
+                    unsupported_elem_type = [
+                        type(elem) for elem in image_input
+                        if not isinstance(elem, Image.Image)
+                    ][0]
+                    raise TypeError(
+                        f'img should be PIL.Image or List[PIL.Image], \
+                        but got a List containing one {unsupported_elem_type}'
+                    )
+            # others
+            else:
+                raise TypeError(
+                    f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}'
+                )
+            output['img'] = image_tensor
+        # preprocess the text input
+        input_text_key = self.input_keys['text']
+        if input_text_key in input and input[input_text_key] is not None:
+            text_input = input[input_text_key]
+            # single text input
+            if isinstance(text_input, str):
+                text_tensor = self.tokenize(text_input)
+            # multi texts input
+            elif isinstance(text_input, list):
+                if all([isinstance(elem, str) for elem in text_input]):
+                    text_tensor = self.tokenize(text_input)
+                else:
+                    unsupported_elem_type = [
+                        type(elem) for elem in text_input
+                        if not isinstance(elem, str)
+                    ][0]
+                    raise TypeError(
+                        f'text should be str or List[str], but got a List containing one {unsupported_elem_type}'
+                    )
+            # others
+            else:
+                raise TypeError(
+                    f'text should be str or List[str], but got {type(text_input)}'
+                )
+            output['text'] = text_tensor
+        return output
+
+
 @PREPROCESSORS.register_module(
     Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor)
 class MPlugPreprocessor(Preprocessor):
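A hedged sketch of what the new preprocessor returns, assuming a 224-resolution vision config and the default context length of 52; the model id is reused from elsewhere in this review and is resolved via snapshot_download when no local path exists.

```python
from PIL import Image

from modelscope.preprocessors.multi_modal import CLIPPreprocessor

preprocessor = CLIPPreprocessor('damo/multi-modal_clip-vit-base-patch16_zh')
sample = {'img': Image.new('RGB', (640, 480)), 'text': ['一只猫', '一只狗']}
out = preprocessor(sample)
print(out['img'].shape)   # torch.Size([1, 3, 224, 224]) for a 224-resolution config
print(out['text'].shape)  # torch.Size([2, 52]); context_length defaults to 52
```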
@@ -0,0 +1,18 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+
+from modelscope.metainfo import Hooks
+from modelscope.trainers.multi_modal.clip.clip_trainer import CLIPTrainer
+from .builder import HOOKS
+from .hook import Hook
+
+
+@HOOKS.register_module(module_name=Hooks.ClipClampLogitScaleHook)
+class ClipClampLogitScaleHook(Hook):
+    """ClipClampLogitScaleHook hook which performs clamp on CLIP logit scale parameter after update"""
+
+    def after_train_iter(self, trainer: CLIPTrainer):
+        """Called after every training iter to evaluate the results."""
+        unwrapped_model = getattr(trainer.model, 'module', trainer.model)
+        logit_scale = unwrapped_model.clip_model.logit_scale
+        logit_scale.data = torch.clamp(logit_scale.data, 0, 4.6052)
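The clamp ceiling 4.6052 is presumably ln(100), i.e. after `exp()` the logit scale stays in [1, 100], the bound used by the original CLIP training recipe; a quick check:

```python
import math

print(math.log(100.0))   # 4.605170185988092
print(math.exp(4.6052))  # ~100.0, so exp(logit_scale) is capped at roughly 100
```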
@@ -1,169 +1,206 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import math
 import os
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional, Tuple, Union

 import torch
-import torch.distributed as dist
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
+from torch import distributed as dist
+from torch import nn
+from torch.utils.data import Dataset

 from modelscope.metainfo import Trainers
-from modelscope.models.base import Model
-from modelscope.trainers.base import BaseTrainer
+from modelscope.models.base import Model, TorchModel
+from modelscope.models.multi_modal.clip.model import convert_models_to_fp32
+from modelscope.msdatasets.ms_dataset import MsDataset
+from modelscope.preprocessors.base import Preprocessor
+from modelscope.preprocessors.multi_modal import CLIPPreprocessor
+from modelscope.trainers import EpochBasedTrainer
 from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.optimizer.builder import build_optimizer
 from modelscope.utils.config import Config
-from modelscope.utils.constant import ModeKeys
-from modelscope.utils.logger import get_logger
-from .clip_trainer_utils import ImageWithCaptionDataset, get_optimizer
+from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys,
+                                       ModeKeys)
+from .clip_trainer_utils import get_loss, get_optimizer_params, get_schedule

-logger = get_logger()
+
+def exclude(n):
+    return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n
+
+
+def include(n):
+    return not exclude(n)


 @TRAINERS.register_module(module_name=Trainers.clip_multi_modal_embedding)
-class CLIPTrainer(BaseTrainer):
-
-    def __init__(self, cfg_file: str, model: str, device_id: int, *args,
-                 **kwargs):
-        super().__init__(cfg_file)
-        self.cfg = Config.from_file(cfg_file)
-        self.model = Model.from_pretrained(model)
-        self.device_id = device_id
-        self.total_epoch = self.cfg.train.epoch
-        self.train_batch_size = self.cfg.train.batch_size
-        self.val_batch_size = self.cfg.evaluation.batch_size
-        self.ckpt_dir = self.cfg.train.ckpt_dir
-        self.train_dataset = ImageWithCaptionDataset(
-            json_file='{}/{}'.format(self.cfg.dataset.root_dir,
-                                     self.cfg.dataset.train_set),
-            img_dir=self.cfg.dataset.root_dir,
-            phase=ModeKeys.TRAIN)
-        self.val_dataset = ImageWithCaptionDataset(
-            json_file='{}/{}'.format(self.cfg.dataset.root_dir,
-                                     self.cfg.dataset.val_set),
-            img_dir=self.cfg.dataset.root_dir,
-            phase=ModeKeys.EVAL)
-
-    def train(self, *args, **kwargs):
-        assert dist.is_initialized()
-
-        self.model.clip_model.train()
-        self.model.clip_model.to(self.device_id)
-        ddp_model = torch.nn.parallel.DistributedDataParallel(
-            self.model.clip_model, device_ids=[
-                self.device_id,
-            ])
-
-        optimizer = get_optimizer(ddp_model)
-
-        for epoch in range(self.total_epoch):
-            train_sampler = DistributedSampler(
-                dataset=self.train_dataset, shuffle=True)
-            train_sampler.set_epoch(epoch)
-            train_params = {
-                'pin_memory': True,
-                'collate_fn': None,
-                'batch_size': self.train_batch_size,
-                'shuffle': False,
-                'drop_last': True,
-                'sampler': train_sampler,
-                'num_workers': 8
+class CLIPTrainer(EpochBasedTrainer):
+
+    def __init__(
+            self,
+            model: Optional[Union[TorchModel, nn.Module, str]] = None,
+            cfg_file: Optional[str] = None,
+            arg_parse_fn: Optional[Callable] = None,
+            data_collator: Optional[Union[Callable, Dict[str,
+                                                         Callable]]] = None,
+            train_dataset: Optional[Union[MsDataset, Dataset]] = None,
+            eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
+            preprocessor: Optional[Union[Preprocessor,
+                                         Dict[str, Preprocessor]]] = None,
+            optimizers: Tuple[torch.optim.Optimizer,
+                              torch.optim.lr_scheduler._LRScheduler] = (None,
+                                                                        None),
+            model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
+            seed: int = 42,
+            **kwargs):
+        model = Model.from_pretrained(model, revision=model_revision)
+        # for training & eval, we convert the model from FP16 back to FP32
+        # to compatible with modelscope amp training
+        convert_models_to_fp32(model)
+
+        cfg = Config.from_file(cfg_file)
+        if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
+            work_dir = cfg.train.work_dir
+        else:
+            work_dir = kwargs['work_dir']
+
+        # fetch the model name of CLIP model (base, large or large-336)
+        model_name = cfg.pretrained_model.model_name
+
+        # world size
+        world_size = int(os.environ.get('WORLD_SIZE', 1))
+
+        # train step, optimizer and lr_scheduler
+        epoch_steps = math.ceil(
+            len(train_dataset) /  # noqa
+            (cfg.train.dataloader.batch_size_per_gpu * world_size))  # noqa
+        cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
+
+        if optimizers[0] is None:
+            named_parameters = list(model.named_parameters())
+            gain_or_bias_params = [
+                p for n, p in named_parameters
+                if exclude(n) and p.requires_grad
+            ]
+            rest_params = [
+                p for n, p in named_parameters
+                if include(n) and p.requires_grad
+            ]
+            optimizer_hparams = get_optimizer_params(
+                model_name, cfg)  # lr, wd, beta1, beta2, eps
+            optimizer_args = {
+                'params': [
+                    {
+                        'params': gain_or_bias_params,
+                        'weight_decay': 0.
+                    },
+                    {
+                        'params': rest_params,
+                        'weight_decay': optimizer_hparams['weight_decay']
+                    },
+                ],
+                'lr':
+                optimizer_hparams['lr'],
+                'betas':
+                (optimizer_hparams['beta1'], optimizer_hparams['beta2']),
+                'eps':
+                optimizer_hparams['eps'],
+            }
+            optimizer = build_optimizer(
+                model, cfg=cfg.train.optimizer, default_args=optimizer_args)
+        else:
+            optimizer = optimizers[0]
+        if optimizers[1] is None:
+            lr_scheduler = get_schedule(optimizer, cfg.train.lr_scheduler)
+        else:
+            lr_scheduler = optimizers[1]
+        optimizers = (optimizer, lr_scheduler)
+
+        # loss module
+        loss_img = nn.CrossEntropyLoss()
+        loss_txt = nn.CrossEntropyLoss()
+        self.loss_img = loss_img.cuda(int(os.environ.get('LOCAL_RANK', 0)))
+        self.loss_txt = loss_txt.cuda(int(os.environ.get('LOCAL_RANK', 0)))
+        self.loss_cfg = cfg.train.loss_cfg
+
+        # launcher and use_fp16
+        if 'launcher' not in kwargs and cfg.train.get('launcher', None):
+            kwargs['launcher'] = cfg.train.launcher
+        if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False):
+            kwargs['use_fp16'] = cfg.train.use_fp16
+
+        # preprocessor
+        if preprocessor is None:
+            preprocessor = {
+                ConfigKeys.train:
+                CLIPPreprocessor(
+                    model_dir=work_dir,
+                    mode=ModeKeys.TRAIN,
+                    tokenizer=model.tokenizer,
+                    resolution=model.model_info['image_resolution']),
+                ConfigKeys.val:
+                CLIPPreprocessor(
+                    model_dir=work_dir,
+                    mode=ModeKeys.EVAL,
+                    tokenizer=model.tokenizer,
+                    resolution=model.model_info['image_resolution']),
             }
-            train_loader = DataLoader(self.train_dataset, **train_params)
-
-            for batch_idx, (img_tensor, text_str_list,
-                            img_id_list) in enumerate(train_loader):
-                text_info_list = [
-                    self.model.tokenize_text(tmp) for tmp in text_str_list
-                ]
-                text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list],
-                                            dim=0)
-                text_masks_tensor = torch.cat(
-                    [tmp[1] for tmp in text_info_list], dim=0)
-                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
-                img_id_list = img_id_list.to(self.device_id, non_blocking=True)
-                text_ids_tensor = text_ids_tensor.to(
-                    self.device_id, non_blocking=True)
-                text_masks_tensor = text_masks_tensor.to(
-                    self.device_id, non_blocking=True)
-
-                loss = ddp_model((img_tensor, text_ids_tensor,
-                                  text_masks_tensor, img_id_list),
-                                 ModeKeys.TRAIN)
-
-                optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()
-
-                if batch_idx % 10 == 0:
-                    logger.info(
-                        'epoch: {}, train batch {}/{}, loss={:.5f}, logit_scale={:.5f}'
-                        .format(epoch, batch_idx, len(train_loader),
-                                loss.item(),
-                                ddp_model.module.logit_scale.exp().item()))
-            if dist.get_rank() == 0:
-                os.makedirs(self.ckpt_dir, exist_ok=True)
-                torch.save(ddp_model.module.state_dict(),
-                           '{}/epoch{}.pth'.format(self.ckpt_dir, epoch))
-
-    def evaluate(self,
-                 checkpoint_path: Optional[str] = None,
-                 *args,
-                 **kwargs) -> Dict[str, float]:
-        if checkpoint_path is not None:
-            checkpoint_params = torch.load(checkpoint_path, 'cpu')
-            self.model.clip_model.load_state_dict(checkpoint_params)
-        self.model.clip_model.eval()
-        self.model.clip_model.to(self.device_id)
-
-        val_params = {
-            'collate_fn': None,
-            'batch_size': self.val_batch_size,
-            'shuffle': False,
-            'drop_last': False,
-            'num_workers': 8
-        }
-        val_loader = DataLoader(self.val_dataset, **val_params)
-
-        tp_cnt_per_batch = []
-        processed_cnt = 0
-        with torch.no_grad():
-            for batch_idx, (img_tensor, text_str_list,
-                            img_id_list) in enumerate(val_loader):
-                text_info_list = [
-                    self.model.tokenize_text(tmp) for tmp in text_str_list
-                ]
-                text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list],
-                                            dim=0)
-                text_masks_tensor = torch.cat(
-                    [tmp[1] for tmp in text_info_list], dim=0)
-                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
-                img_id_list = img_id_list.to(self.device_id, non_blocking=True)
-                text_ids_tensor = text_ids_tensor.to(
-                    self.device_id, non_blocking=True)
-                text_masks_tensor = text_masks_tensor.to(
-                    self.device_id, non_blocking=True)
-
-                img_feat = self.model.clip_model(img_tensor, input_type='img')
-                text_feat = self.model.clip_model(
-                    (text_ids_tensor, text_masks_tensor), input_type='text')
-
-                sim_mat = text_feat @ img_feat.t()
-                text_cnt, img_cnt = sim_mat.shape
-                top1_scores, match_ids = torch.max(sim_mat, dim=1)
-                match_ids = match_ids.int()
-                gt_ids = torch.tensor(range(0, text_cnt)).to(
-                    self.device_id, non_blocking=True).int()
-                error_cnt = torch.nonzero(match_ids - gt_ids)
-                processed_cnt += text_cnt
-                tp_cnt_per_batch.append(text_cnt - 1.0 * error_cnt.numel())
-                logger.info('current acc: {:.3f}'.format(
-                    sum(tp_cnt_per_batch) / processed_cnt))
+
+        # dataset related
+        self.dataset_cfg = cfg.dataset
+        if hasattr(self.dataset_cfg, 'column_map'):
+            # cases where dataset key names are not "img" and "text"
+            img_key_name = getattr(self.dataset_cfg.column_map, 'img', 'img')
+            preprocessor[ConfigKeys.train].set_input_img_key(img_key_name)
+            preprocessor[ConfigKeys.val].set_input_img_key(img_key_name)
+            text_key_name = getattr(self.dataset_cfg.column_map, 'text',
+                                    'text')
+            preprocessor[ConfigKeys.train].set_input_text_key(text_key_name)
+            preprocessor[ConfigKeys.val].set_input_text_key(text_key_name)
+
+        self.global_batch_size = cfg.train.dataloader.batch_size_per_gpu * world_size
+
+        super().__init__(
+            model=model,
+            cfg_file=cfg_file,
+            arg_parse_fn=arg_parse_fn,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            preprocessor=preprocessor,
+            optimizers=optimizers,
+            seed=seed,
+            **kwargs,
+        )
+
+    def train_step(self, model, inputs):
+        model.train()
+        inputs['mode'] = ModeKeys.TRAIN
+        model_outputs = model.forward(
+            inputs
+        )  # {OutputKeys.IMG_EMBEDDING: Tensor(batch_size, dim), OutputKeys.TEXT_EMBEDDING: Tensor(batch_size, dim)}
+        loss = get_loss(model_outputs, self.loss_img, self.loss_txt,
+                        self.loss_cfg)
+        train_outputs = {'loss': loss}
+        # add model output info to log
+        if 'log_vars' not in train_outputs:
+            default_keys_pattern = ['loss']
+            match_keys = set([])
+            for key_p in default_keys_pattern:
+                match_keys.update(
+                    [key for key in train_outputs.keys() if key_p in key])
+            log_vars = {}
+            for key in match_keys:
+                value = train_outputs.get(key, None)
+                if value is not None:
+                    if dist.is_available() and dist.is_initialized():
+                        value = value.data.clone()
+                        dist.all_reduce(value.div_(dist.get_world_size()))
+                    log_vars.update({key: value.item()})
+            unwrapped_model = getattr(model, 'module', model)
+            log_vars[
+                'logit_scale'] = unwrapped_model.clip_model.logit_scale.data.clone(
+                ).item()  # noqa
+            log_vars['global_batch_size'] = int(self.global_batch_size)
+            self.log_buffer.update(log_vars)
+        else:
+            self.log_buffer.update(train_outputs['log_vars'])
+        self.train_outputs = train_outputs
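The `exclude`/`include` helpers above split parameters into a no-weight-decay group (batch/layer norms, biases, `logit_scale`) and everything else before `build_optimizer` is called. A small self-contained sketch with a stand-in module (not the real CLIP model) showing how that split behaves:

```python
import torch
from torch import nn


def exclude(n):
    # same rule as in the trainer above
    return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n


class TinyTower(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 8)         # weight decays, bias does not
        self.ln_final = nn.LayerNorm(8)      # 'ln' in the name -> no decay
        self.logit_scale = nn.Parameter(torch.ones([]))


named = list(TinyTower().named_parameters())
no_decay = [n for n, p in named if exclude(n) and p.requires_grad]
decay = [n for n, p in named if not exclude(n) and p.requires_grad]
print(no_decay)  # ['logit_scale', 'proj.bias', 'ln_final.weight', 'ln_final.bias']
print(decay)     # ['proj.weight']
```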
@@ -1,94 +1,125 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
+# Copyright 2022 The OFA-Sys Team.
+# All rights reserved.
+# This source code is licensed under the Apache 2.0 license
+# found in the LICENSE file in the root directory.
+import math
 import os
-import random
+from functools import partial
+from inspect import unwrap

-import json
 import torch
-import torch.nn.functional as F
-from PIL import Image
-from torch.utils.data import Dataset
-from torchvision import transforms
-
-from modelscope.utils.constant import ModeKeys
-
-
-train_transform = transforms.Compose([
-    transforms.RandomResizedCrop(
-        224, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
-    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
-                           p=0.8),
-    transforms.RandomGrayscale(p=0.2),
-    transforms.RandomHorizontalFlip(),
-    transforms.ToTensor(),
-    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
-                         (0.26862954, 0.26130258, 0.27577711))
-])
-
-val_transform = transforms.Compose([
-    transforms.Resize((224, 224), interpolation=Image.BICUBIC),
-    transforms.ToTensor(),
-    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
-                         (0.26862954, 0.26130258, 0.27577711))
-])
-
-
-class ImageWithCaptionDataset(Dataset):
-
-    def __init__(self, json_file, img_dir, phase):
-        self.annotations = json.load(open(json_file))
-        self.img_dir = img_dir
-        if phase == ModeKeys.TRAIN:
-            self.transform = train_transform
-        elif phase == ModeKeys.EVAL:
-            self.transform = val_transform
-
-        self.img_name2img_id = {}
-        for anno_dict in self.annotations:
-            img_name = anno_dict['image']
-            if img_name not in self.img_name2img_id:
-                self.img_name2img_id[img_name] = len(self.img_name2img_id)
-
-    def __len__(self):
-        return len(self.annotations)
-
-    def __getitem__(self, index):
-        anno_dict = self.annotations[index]
-        img_path = os.path.join(self.img_dir, anno_dict['image'])
-        img_pil = Image.open(img_path).convert('RGB')
-        img_th = self.transform(img_pil)
-        img_id = self.img_name2img_id[anno_dict['image']]
-        text_str = random.choice(anno_dict['caption'])
-        return img_th, text_str, img_id
-
-
-def get_params_groups(ddp_model, weight_decay):
-    decay = []
-    no_decay = []
-    for name, param in ddp_model.named_parameters():
-        if not param.requires_grad:
-            continue
-        if len(param.shape) == 1 or name.endswith('.bias'):
-            no_decay.append(param)
-        else:
-            decay.append(param)
-    params_groups = [{
-        'params': no_decay,
-        'weight_decay': 0.
-    }, {
-        'params': decay,
-        'weight_decay': weight_decay
-    }]
-    return params_groups
-
-
-def get_optimizer(ddp_model):
-    from torch.optim import AdamW
-    lr_init = 1e-5
-    betas = [0.9, 0.999]
-    weight_decay = 0.02
-    params_groups = get_params_groups(ddp_model, weight_decay=weight_decay)
-    return AdamW(
-        params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay)
+import torch.distributed as dist
+from torch.optim.lr_scheduler import LambdaLR
+
+from modelscope.outputs import OutputKeys
+
+
+def get_optimizer_params(model_name, cfg):
+    # get default params
+    # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
+    # base model
+    if model_name in ['damo/multi-modal_clip-vit-base-patch16_zh']:
+        params = {
+            'lr': 5.0e-4,
+            'beta1': 0.9,
+            'beta2': 0.98,
+            'eps': 1.0e-6,
+            'weight_decay': 0.0
+        }
+    # large models
+    elif model_name in [
+            'damo/multi-modal_clip-vit-large-patch14_zh',
+            'damo/multi-modal_clip-vit-large-patch14_336_zh'
+    ]:
+        params = {
+            'lr': 4.0e-4,
+            'beta1': 0.9,
+            'beta2': 0.98,
+            'eps': 1.0e-6,
+            'weight_decay': 0.0
+        }
+    else:
+        params = {
+            'lr': 5.0e-4,
+            'beta1': 0.9,
+            'beta2': 0.999,
+            'eps': 1.0e-8,
+            'weight_decay': 0.0
+        }
+    # override with config params
+    for key in ['lr', 'beta1', 'beta2', 'eps', 'weight_decay']:
+        if hasattr(cfg.train, 'optimizer_hparams'):
+            params[key] = getattr(cfg.train.optimizer_hparams, key,
+                                  params[key])
+    return params
+
+
+def get_loss(model_outputs, loss_img, loss_txt, loss_cfg):
+    image_features = model_outputs[OutputKeys.IMG_EMBEDDING]
+    text_features = model_outputs[OutputKeys.TEXT_EMBEDDING]
+    logit_scale = model_outputs['logit_scale']
+    logit_scale = logit_scale.mean()
+    if loss_cfg.aggregate and int(os.environ.get('WORLD_SIZE', 1)) > 1:
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        # We gather tensors from all gpus to get more negatives to contrast with.
+        gathered_image_features = [
+            torch.zeros_like(image_features) for _ in range(world_size)
+        ]
+        gathered_text_features = [
+            torch.zeros_like(text_features) for _ in range(world_size)
+        ]
+        dist.all_gather(gathered_image_features, image_features)
+        dist.all_gather(gathered_text_features, text_features)
+
+        all_image_features = torch.cat([image_features]
+                                       + gathered_image_features[:rank]
+                                       + gathered_image_features[rank + 1:])
+        all_text_features = torch.cat([text_features]
+                                      + gathered_text_features[:rank]
+                                      + gathered_text_features[rank + 1:])
+
+        # this is needed to send gradients back everywhere.
+        logits_per_image = logit_scale * all_image_features @ all_text_features.t(
+        )
+        logits_per_text = logits_per_image.t()
+    else:
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logit_scale * text_features @ image_features.t()
+
+    ground_truth = torch.arange(len(logits_per_image)).long()
+    ground_truth = ground_truth.cuda(
+        int(os.environ.get('LOCAL_RANK', 0)), non_blocking=True)
+
+    total_loss = (loss_img(logits_per_image, ground_truth)
+                  + loss_txt(logits_per_text, ground_truth)) / 2
+
+    return total_loss
+
+
+def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps))
+    return max(
+        0.0,
+        0.5 *  # noqa
+        (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))  # noqa
+
+
+def get_schedule(optimizer,
+                 scheduler,
+                 num_cycles: float = 0.5,
+                 last_epoch: int = -1):
+    num_warmup_steps = int(scheduler.warmup_proportion
+                           * scheduler.num_train_steps)
+    num_training_steps = scheduler.num_train_steps
+
+    return LambdaLR(
+        optimizer,
+        partial(lr_lambda, num_warmup_steps, num_training_steps, num_cycles),
+        last_epoch)
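`lr_lambda` implements linear warmup followed by half-cosine decay. A quick sketch of the multiplier it produces (warmup and step counts here are made-up values), wired into `LambdaLR` the same way `get_schedule` does:

```python
import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step):
    # copy of the helper above
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))


optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-5)
scheduler = LambdaLR(optimizer, partial(lr_lambda, 10, 100, 0.5))
print([round(lr_lambda(10, 100, 0.5, step), 3) for step in (0, 5, 10, 55, 100)])
# [0.0, 0.5, 1.0, 0.5, 0.0] -- ramp up over 10 steps, then cosine decay to zero at step 100
```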
@@ -24,7 +24,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
     def test_run(self):
         pipeline_multi_modal_embedding = pipeline(
             Tasks.multi_modal_embedding, model=self.model_id)
-        text_embedding = pipeline_multi_modal_embedding(
+        text_embedding = pipeline_multi_modal_embedding.forward(
             self.test_input)[OutputKeys.TEXT_EMBEDDING]
         print('l1-norm: {}'.format(
             torch.norm(text_embedding, p=1, dim=-1).item()))
@@ -36,7 +36,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
         model = Model.from_pretrained(self.model_id)
         pipeline_multi_modal_embedding = pipeline(
             task=Tasks.multi_modal_embedding, model=model)
-        text_embedding = pipeline_multi_modal_embedding(
+        text_embedding = pipeline_multi_modal_embedding.forward(
             self.test_input)[OutputKeys.TEXT_EMBEDDING]
         print('l1-norm: {}'.format(
             torch.norm(text_embedding, p=1, dim=-1).item()))
@@ -47,7 +47,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
     def test_run_with_default_model(self):
         pipeline_multi_modal_embedding = pipeline(
             task=Tasks.multi_modal_embedding)
-        text_embedding = pipeline_multi_modal_embedding(
+        text_embedding = pipeline_multi_modal_embedding.forward(
             self.test_input)[OutputKeys.TEXT_EMBEDDING]
         print('l1-norm: {}'.format(
             torch.norm(text_embedding, p=1, dim=-1).item()))
@@ -0,0 +1,83 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import unittest
+
+import json
+
+from modelscope.metainfo import Metrics, Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestClipTrainer(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.finetune_cfg = \
+            {'framework': 'pytorch',
+             'task': 'multi-modal-embedding',
+             'pipeline': {'type': 'multi-modal-embedding'},
+             'pretrained_model': {'model_name': 'damo/multi-modal_clip-vit-base-patch16_zh'},
+             'dataset': {'column_map': {'img': 'image', 'text': 'query'}},
+             'train': {'work_dir': './workspace/ckpts/clip',
+                       # 'launcher': 'pytorch',
+                       'max_epochs': 1,
+                       'use_fp16': True,
+                       'dataloader': {'batch_size_per_gpu': 8,
+                                      'workers_per_gpu': 0,
+                                      'shuffle': True,
+                                      'drop_last': True},
+                       'lr_scheduler': {'name': 'cosine',
+                                        'warmup_proportion': 0.01},
+                       'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
+                       'optimizer': {'type': 'AdamW'},
+                       'optimizer_hparams': {'lr': 5e-05, 'weight_decay': 0.01},
+                       'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
+                                          'cumulative_iters': 1,
+                                          'loss_keys': 'loss'},
+                       'loss_cfg': {'aggregate': True},
+                       'hooks': [{'type': 'BestCkptSaverHook',
+                                  'metric_key': 'inbatch_t2i_recall_at_1',
+                                  'interval': 100},
+                                 {'type': 'TextLoggerHook', 'interval': 1},
+                                 {'type': 'IterTimerHook'},
+                                 {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1},
+                                 {'type': 'ClipClampLogitScaleHook'}]},
+             'evaluation': {'dataloader': {'batch_size_per_gpu': 8,
+                                           'workers_per_gpu': 0,
+                                           'shuffle': True,
+                                           'drop_last': True},
+                            'metrics': [{'type': 'inbatch_recall'}]},
+             'preprocessor': []}
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_std(self):
+        WORKSPACE = './workspace/ckpts/clip'
+        os.makedirs(WORKSPACE, exist_ok=True)
+        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
+        with open(config_file, 'w') as writer:
+            json.dump(self.finetune_cfg, writer)
+
+        pretrained_model = 'damo/multi-modal_clip-vit-base-patch16_zh'
+        args = dict(
+            model=pretrained_model,
+            work_dir=WORKSPACE,
+            train_dataset=MsDataset.load(
+                'muge', namespace='modelscope', split='train[:200]'),
+            eval_dataset=MsDataset.load(
+                'muge', namespace='modelscope', split='validation[:100]'),
+            metrics=[Metrics.inbatch_recall],
+            cfg_file=config_file)
+        trainer = build_trainer(
+            name=Trainers.clip_multi_modal_embedding, default_args=args)
+        trainer.train()
+        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
+                      os.listdir(os.path.join(WORKSPACE, 'output')))
+        shutil.rmtree(WORKSPACE)
+
+
+if __name__ == '__main__':
+    unittest.main()