From cb12a7c6f8d34859a8d8d85c120fda1e781d5afd Mon Sep 17 00:00:00 2001
From: "yichang.zyc"
Date: Wed, 3 Aug 2022 21:25:16 +0800
Subject: [PATCH] [to #42322933] fea: support Chinese, eg: visual-grounding

https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9627026

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9627026
---
 modelscope/models/multi_modal/ofa/__init__.py |   3 +-
 .../multi_modal/ofa/configuration_ofa.py      |   4 +
 .../models/multi_modal/ofa/modeling_ofa.py    |  41 ++-
 .../multi_modal/ofa/tokenization_ofa.py       | 322 ++++++++++++++++++
 .../multi_modal/ofa/tokenization_ofa_fast.py  | 157 ++++++++-
 .../models/multi_modal/ofa_for_all_tasks.py   |  18 +-
 modelscope/preprocessors/ofa/base.py          |  12 +-
 tests/pipelines/test_ofa_tasks.py             |  35 ++
 8 files changed, 563 insertions(+), 29 deletions(-)

diff --git a/modelscope/models/multi_modal/ofa/__init__.py b/modelscope/models/multi_modal/ofa/__init__.py
index 433e8266..16de7fff 100644
--- a/modelscope/models/multi_modal/ofa/__init__.py
+++ b/modelscope/models/multi_modal/ofa/__init__.py
@@ -1,2 +1,3 @@
 from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel
-from .tokenization_ofa import OFATokenizer
+from .tokenization_ofa import OFATokenizer, OFATokenizerZH
+from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast
diff --git a/modelscope/models/multi_modal/ofa/configuration_ofa.py b/modelscope/models/multi_modal/ofa/configuration_ofa.py
index 4d28dcc5..4899f416 100644
--- a/modelscope/models/multi_modal/ofa/configuration_ofa.py
+++ b/modelscope/models/multi_modal/ofa/configuration_ofa.py
@@ -134,6 +134,8 @@ class OFAConfig(PretrainedConfig):
                  code_layernorm_embedding=True,
                  code_image_size=128,
                  entangle_position_embedding=False,
+                 interpolate_position=False,
+                 orig_patch_image_size=224,
                  **kwargs):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -173,6 +175,8 @@ class OFAConfig(PretrainedConfig):
         self.code_layernorm_embedding = code_layernorm_embedding
         self.code_image_size = code_image_size
         self.entangle_position_embedding = entangle_position_embedding
+        self.interpolate_position = interpolate_position
+        self.orig_patch_image_size = orig_patch_image_size

         super().__init__(
             pad_token_id=pad_token_id,
diff --git a/modelscope/models/multi_modal/ofa/modeling_ofa.py b/modelscope/models/multi_modal/ofa/modeling_ofa.py
index b0350d1d..01cc02f9 100755
--- a/modelscope/models/multi_modal/ofa/modeling_ofa.py
+++ b/modelscope/models/multi_modal/ofa/modeling_ofa.py
@@ -311,7 +311,6 @@ class OFAAttention(nn.Module):
             self.head_dim * num_heads == self.embed_dim
         ), f'embed_dim must be divisible by num_heads ' \
             f'(got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads}).'
-        # self.scaling = self.head_dim ** -0.5  # 1. difference
         scale_factor = 2
         self.scaling = float(self.head_dim * scale_factor)**-0.5
@@ -913,7 +912,6 @@ class OFAEncoder(OFAPreTrainedModel):
         else:
             raise NotImplementedError
-        # self.image_proj = nn.Linear(1024, embed_dim)
         self.image_proj = Linear(1024, embed_dim)
         if config.resnet_model_path:
@@ -1075,7 +1073,25 @@ class OFAEncoder(OFAPreTrainedModel):
             image_num_patches = sample_patch_num
             image_padding_mask = image_padding_mask.gather(1, patch_orders)
             image_position_ids = image_position_ids.gather(1, patch_orders)
-        image_pos_embed = self.embed_image_positions(image_position_ids)
+        orig_num_patches = (self.config.orig_patch_image_size // 16)**2
+        orig_hw = self.config.orig_patch_image_size // 16
+        if self.config.interpolate_position and image_num_patches > orig_num_patches:
+            old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand(orig_hw, orig_hw) + \
+                torch.arange(orig_hw).unsqueeze(1) * \
+                self.config.image_bucket_size + 1  # noqa
+            old_image_position_ids = old_image_position_ids.to(device)
+            old_image_pos_embed = self.embed_image_positions(
+                old_image_position_ids)
+            old_image_pos_embed = old_image_pos_embed.reshape(
+                1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2)
+            image_pos_embed = F.interpolate(
+                old_image_pos_embed, size=(h, w), mode='bilinear')
+            image_pos_embed = image_pos_embed.permute(0, 2, 3, 1).reshape(
+                1, image_num_patches, -1)
+            image_pos_embed = image_pos_embed.expand(
+                patch_images.size(0), -1, -1)
+        else:
+            image_pos_embed = self.embed_image_positions(image_position_ids)

         return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
@@ -1250,7 +1266,6 @@ class OFAEncoder(OFAPreTrainedModel):
             position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
                 positional embeddings of the input image and tokens.
""" - image_embed = None image_embed_2 = None image_pos_embed = None @@ -1258,14 +1273,7 @@ class OFAEncoder(OFAPreTrainedModel): if patch_images is not None: image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \ self.get_patch_images_info(patch_images, sample_patch_num, input_ids.device) - # print("patch_masks.shape") - # print(patch_masks.shape) - # print(patch_masks) - # print("image_padding_mask.shape") - # print(image_padding_mask.shape) - # print(image_padding_mask) image_padding_mask[~patch_masks] = True - # print(image_padding_mask) if patch_images_2 is not None: image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \ self.get_patch_images_info(patch_images_2, sample_patch_num, input_ids.device) @@ -1313,10 +1321,6 @@ class OFAEncoder(OFAPreTrainedModel): encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - # if output_hidden_states: - # # encoder_states.append(x) - # encoder_states += (x,) - # encoder layers for idx, layer in enumerate(self.layers): if output_hidden_states: @@ -1645,7 +1649,6 @@ class OFADecoder(OFAPreTrainedModel): def reorder_incremental_state_scripting( self, - # incremental_state: Dict[str, Dict[str, Optional[Tensor]]], past_key_values: Optional[torch.Tensor], new_order: Tensor, ): @@ -1799,15 +1802,12 @@ class OFADecoder(OFAPreTrainedModel): self_attn_bias = self_abs_pos_bias.clone() if code_masks is None or not code_masks.any(): - # print("code_masks is None or not code_masks.any()") self_attn_bias += self.get_rel_pos_bias( all_prev_output_tokens, idx).unsqueeze(0) elif code_masks is not None and code_masks.all(): - # print("code_masks is not None and code_masks.all()") self_attn_bias += self.get_image_rel_pos_bias( all_prev_output_tokens, idx).unsqueeze(0) else: - # print("else") self_attn_bias[~code_masks] += self.get_rel_pos_bias( all_prev_output_tokens, idx).unsqueeze(0) self_attn_bias[code_masks] += self.get_image_rel_pos_bias( @@ -1921,7 +1921,7 @@ class OFAModel(OFAPreTrainedModel): output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC, ) - # 新增函数以适配fairseq的generator + # an adaptor for fairseq generator def max_decoder_positions(self): """Maximum length supported by the decoder.""" return self.decoder.max_positions() @@ -2062,7 +2062,6 @@ class OFAModel(OFAPreTrainedModel): return Seq2SeqLMOutput( logits=decoder_outputs.last_hidden_state, - # last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa.py b/modelscope/models/multi_modal/ofa/tokenization_ofa.py index e40436b6..158905eb 100644 --- a/modelscope/models/multi_modal/ofa/tokenization_ofa.py +++ b/modelscope/models/multi_modal/ofa/tokenization_ofa.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tokenization classes for OFA.""" +import collections +import os +from typing import List, Optional, Tuple + +from transformers import PreTrainedTokenizer from transformers.models.bart.tokenization_bart import BartTokenizer +from transformers.models.bert.tokenization_bert import (BasicTokenizer, + WordpieceTokenizer) from transformers.utils import logging logger = logging.get_logger(__name__) @@ -26,12 +33,37 @@ PRETRAINED_VOCAB_FILES_MAP = { 'merges_file': { 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt', }, + # OFA models are implemented to be compatible with both huggingface + # and modelscope frameworks. For all OFA models available on huggingface, + # please refer to https://huggingface.co/models?filter=ofa } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'ofa-base': 1024, } +VOCAB_FILES_NAMES_ZH = {'vocab_file': 'vocab.txt'} + +PRETRAINED_VOCAB_FILES_MAP_ZH = { + 'vocab_file': { + 'bert-base-chinese': + 'https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt', + } + # OFA models are implemented to be compatible with both huggingface + # and modelscope frameworks. For all OFA models available on huggingface, + # please refer to https://huggingface.co/models?filter=ofa +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH = { + 'ofa-base': 1024, +} + +PRETRAINED_INIT_CONFIGURATION_ZH = { + 'bert-base-chinese': { + 'do_lower_case': True + }, +} + class OFATokenizer(BartTokenizer): """ @@ -46,3 +78,293 @@ class OFATokenizer(BartTokenizer): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, 'r', encoding='utf-8') as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip('\n') + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class OFATokenizerZH(PreTrainedTokenizer): + r""" + Construct a OFA tokenizer. Based on WordPiece. + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. 
diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py b/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py
index 235d1b34..03d2d71e 100644
--- a/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py
+++ b/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py
@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for OFA."""
+from typing import List, Optional, Tuple
+
+import json
+from tokenizers import normalizers
+from transformers import PreTrainedTokenizerFast
 from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast
 from transformers.utils import logging

-from .tokenization_ofa import OFATokenizer
+from .tokenization_ofa import OFATokenizer, OFATokenizerZH

 logger = logging.get_logger(__name__)
@@ -36,12 +41,37 @@ PRETRAINED_VOCAB_FILES_MAP = {
         'ofa-base':
         'https://huggingface.co/ofa-base/resolve/main/tokenizer.json',
     },
+    # OFA models are implemented to be compatible with both huggingface
+    # and modelscope frameworks. For all OFA models available on huggingface,
+    # please refer to https://huggingface.co/models?filter=ofa
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'ofa-base': 1024,
 }

+VOCAB_FILES_NAMES_ZH = {'vocab_file': 'vocab.txt'}
+
+PRETRAINED_VOCAB_FILES_MAP_ZH = {
+    'vocab_file': {
+        'bert-base-chinese':
+        'https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt',
+    }
+    # OFA models are implemented to be compatible with both huggingface
+    # and modelscope frameworks. For all OFA models available on huggingface,
+    # please refer to https://huggingface.co/models?filter=ofa
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH = {
+    'ofa-base': 1024,
+}
+
+PRETRAINED_INIT_CONFIGURATION_ZH = {
+    'bert-base-chinese': {
+        'do_lower_case': True
+    },
+}
+

 class OFATokenizerFast(BartTokenizerFast):
     r"""
@@ -57,3 +87,128 @@ class OFATokenizerFast(BartTokenizerFast):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     slow_tokenizer_class = OFATokenizer
+
+
+class OFATokenizerZHFast(PreTrainedTokenizerFast):
+    r"""
+    Construct a "fast" OFA tokenizer (backed by HuggingFace's *tokenizers* library).
+
+    [`~OFATokenizerZHFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization: punctuation
+    splitting and wordpiece.
+
+    Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
+    """
+    vocab_files_names = VOCAB_FILES_NAMES_ZH
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_ZH
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION_ZH
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH
+    slow_tokenizer_class = OFATokenizerZH
+
+    def __init__(self,
+                 vocab_file=None,
+                 tokenizer_file=None,
+                 do_lower_case=True,
+                 bos_token='<s>',
+                 eos_token='</s>',
+                 sep_token='</s>',
+                 cls_token='<s>',
+                 unk_token='<unk>',
+                 pad_token='<pad>',
+                 mask_token='<mask>',
+                 tokenize_chinese_chars=True,
+                 strip_accents=None,
+                 **kwargs):
+        super().__init__(
+            vocab_file,
+            tokenizer_file=tokenizer_file,
+            do_lower_case=do_lower_case,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            tokenize_chinese_chars=tokenize_chinese_chars,
+            strip_accents=strip_accents,
+            **kwargs,
+        )
+
+        normalizer_state = json.loads(
+            self.backend_tokenizer.normalizer.__getstate__())
+        if (normalizer_state.get('lowercase', do_lower_case) != do_lower_case
+                or normalizer_state.get('strip_accents', strip_accents)
+                != strip_accents or normalizer_state.get(
+                    'handle_chinese_chars',
+                    tokenize_chinese_chars) != tokenize_chinese_chars):
+            normalizer_class = getattr(normalizers,
+                                       normalizer_state.pop('type'))
+            normalizer_state['lowercase'] = do_lower_case
+            normalizer_state['strip_accents'] = strip_accents
+            normalizer_state['handle_chinese_chars'] = tokenize_chinese_chars
+            self.backend_tokenizer.normalizer = normalizer_class(
+                **normalizer_state)
+
+        self.do_lower_case = do_lower_case
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A BERT sequence has the following format:
+
+        - single sequence: `[CLS] X [SEP]`
+        - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+
+        if token_ids_1:
+            output += token_ids_1 + [self.sep_token_id]
+
+        return output
+
+    def create_token_type_ids_from_sequences(
+            self,
+            token_ids_0: List[int],
+            token_ids_1: Optional[List[int]] = None) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
+        pair mask has the following format:
+
+        ```
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        ```
+
+        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1
+                                                        + sep) * [1]
+
+    def save_vocabulary(self,
+                        save_directory: str,
+                        filename_prefix: Optional[str] = None) -> Tuple[str]:
+        files = self._tokenizer.model.save(
+            save_directory, name=filename_prefix)
+        return tuple(files)
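Note: the normalizer check in `OFATokenizerZHFast.__init__` keeps the serialized backend in sync with the arguments passed at load time; otherwise a tokenizer.json produced with different casing or accent settings would silently win. The sketch below isolates that idea (`sync_normalizer` and `backend` are illustrative names; the calls mirror the `tokenizers` API used above):

    import json

    from tokenizers import normalizers

    def sync_normalizer(backend, do_lower_case=True, strip_accents=None,
                        tokenize_chinese_chars=True):
        # Rebuild the serialized BertNormalizer if it disagrees with the
        # requested settings (illustrative helper, not part of the patch).
        state = json.loads(backend.normalizer.__getstate__())
        if (state.get('lowercase') != do_lower_case
                or state.get('strip_accents') != strip_accents
                or state.get('handle_chinese_chars') != tokenize_chinese_chars):
            normalizer_class = getattr(normalizers, state.pop('type'))
            state['lowercase'] = do_lower_case
            state['strip_accents'] = strip_accents
            state['handle_chinese_chars'] = tokenize_chinese_chars
            backend.normalizer = normalizer_class(**state)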
diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py
index 0ec87d66..363d552d 100644
--- a/modelscope/models/multi_modal/ofa_for_all_tasks.py
+++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py
@@ -16,7 +16,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_tokens
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile
 from modelscope.utils.trie import Trie
-from .ofa import OFAModel, OFATokenizer
+from .ofa import OFAModel, OFATokenizer, OFATokenizerZH
 from .ofa.generate import sequence_generator as sg
 from .ofa.generate.utils import move_to_device
 from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks
@@ -41,11 +41,21 @@ class OfaForAllTasks(TorchModel):
         self.cfg = Config.from_file(
             osp.join(model_dir, ModelFile.CONFIGURATION))
         self.model = model.module if hasattr(model, 'module') else model
-        self.tokenizer = OFATokenizer.from_pretrained(model_dir)
+        self.language = self.cfg.model.get('language', 'en')
+        if self.language == 'en':
+            self.tokenizer = OFATokenizer.from_pretrained(model_dir)
+        elif self.language in ['zh', 'cn']:
+            self.tokenizer = OFATokenizerZH.from_pretrained(model_dir)
+        else:
+            raise NotImplementedError
+        # Unlike the original OFA code, the `use_bpe` parameter is not
+        # needed here.
         self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
         self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
         self.cfg.update({'num_bins': 1000, 'num_codes': 8192})
         self.batch_size = self.cfg.model.get('batch_size', 1)
+        self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
+        self.max_image_size = self.cfg.model.get('max_image_size', 512)
         self.val_batch_size = self.cfg.model.get('valid_batch_size',
                                                  self.batch_size)
         self.gen_type = self.cfg.model.get('gen_type', 'generation')
@@ -129,8 +139,8 @@ class OfaForAllTasks(TorchModel):
                 - len(self.tokenizer.get_vocab().items()) + self.cfg.num_bins)
         region_tensor = torch.stack(region_coord_l, dim=0)
-        region_tensor = region_tensor / (
-            self.cfg.num_bins - 1) * self.cfg.model.get('max_image_size', 512)
+        region_tensor = region_tensor / (self.cfg.num_bins
+                                         - 1) * self.max_image_size
         region_tensor[:, ::2] /= input['w_resize_ratios']
         region_tensor[:, 1::2] /= input['h_resize_ratios']
         return {
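Note: the region post-processing above maps generated `<bin_*>` ids back to pixel coordinates: each bin index is scaled by `max_image_size / (num_bins - 1)`, then x values are divided by the width resize ratio and y values by the height resize ratio. A worked example with the defaults from this file (the resize ratios and bin ids are made up):

    num_bins, max_image_size = 1000, 512
    w_resize_ratio, h_resize_ratio = 1.6, 1.2    # assumed preprocessing ratios

    bin_ids = [250, 125, 750, 875]               # hypothetical x0, y0, x1, y1 bins
    coords = [b / (num_bins - 1) * max_image_size for b in bin_ids]
    box = [c / (w_resize_ratio if i % 2 == 0 else h_resize_ratio)
           for i, c in enumerate(coords)]
    print([round(v, 1) for v in box])            # [80.1, 53.4, 240.2, 373.7]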
+ """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ( + [0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + + sep) * [1] + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, + (filename_prefix + '-' if filename_prefix else '') + + VOCAB_FILES_NAMES['vocab_file']) + else: + vocab_file = (filename_prefix + + '-' if filename_prefix else '') + save_directory + with open(vocab_file, 'w', encoding='utf-8') as writer: + for token, token_index in sorted( + self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f'Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive.' + ' Please check that the vocabulary is not corrupted!') + index = token_index + writer.write(token + '\n') + index += 1 + return (vocab_file, ) diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py b/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py index 235d1b34..03d2d71e 100644 --- a/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py +++ b/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tokenization classes for OFA.""" +from typing import List, Optional, Tuple + +import json +from tokenizers import normalizers +from transformers import PreTrainedTokenizerFast from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast from transformers.utils import logging -from .tokenization_ofa import OFATokenizer +from .tokenization_ofa import OFATokenizer, OFATokenizerZH logger = logging.get_logger(__name__) @@ -36,12 +41,37 @@ PRETRAINED_VOCAB_FILES_MAP = { 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/tokenizer.json', }, + # OFA models are implemented to be compatible with both huggingface + # and modelscope frameworks. 
diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py
index 1dc7d303..5cba86b1 100644
--- a/tests/pipelines/test_ofa_tasks.py
+++ b/tests/pipelines/test_ofa_tasks.py
@@ -1,17 +1,33 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import unittest
+from os import path as osp

+import cv2
+import numpy as np
 from PIL import Image

 from modelscope.models import Model
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines import pipeline
+from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level


 class OfaTasksTest(unittest.TestCase):

+    def setUp(self) -> None:
+        self.output_dir = 'unittest_output'
+        os.makedirs(self.output_dir, exist_ok=True)
+
+    def save_img(self, image_in, box, image_out):
+        image = load_image(image_in)
+        img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        cv2.rectangle(img, (int(box[0]), int(box[1])),
+                      (int(box[2]), int(box[3])), (0, 255, 0), 3)
+        cv2.imwrite(osp.join(self.output_dir, image_out), img)
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_image_captioning_with_model(self):
         model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en')
@@ -132,6 +148,9 @@ class OfaTasksTest(unittest.TestCase):
         input = {'image': image, 'text': text}
         result = ofa_pipe(input)
         print(result)
+        image_name = image.split('/')[-2]
+        self.save_img(image, result[OutputKeys.BOXES],
+                      osp.join('large_en_model_' + image_name + '.png'))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_visual_grounding_with_name(self):
@@ -143,6 +162,22 @@ class OfaTasksTest(unittest.TestCase):
         input = {'image': image, 'text': text}
         result = ofa_pipe(input)
         print(result)
+        image_name = image.split('/')[-2]
+        self.save_img(image, result[OutputKeys.BOXES],
+                      osp.join('large_en_name_' + image_name + '.png'))
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_visual_grounding_zh_with_name(self):
+        model = 'damo/ofa_visual-grounding_refcoco_large_zh'
+        ofa_pipe = pipeline(Tasks.visual_grounding, model=model)
+        image = 'data/test/images/visual_grounding.png'
+        text = '一个圆头的蓝色宝可梦'  # 'a blue, round-headed Pokémon'
+        input = {'image': image, 'text': text}
+        result = ofa_pipe(input)
+        print(result)
+        image_name = image.split('/')[-1]
+        self.save_img(image, result[OutputKeys.BOXES],
+                      osp.join('large_zh_name_' + image_name))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_visual_question_answering_with_model(self):