diff --git a/modelscope/models/multi_modal/ofa/__init__.py b/modelscope/models/multi_modal/ofa/__init__.py
index 433e8266..16de7fff 100644
--- a/modelscope/models/multi_modal/ofa/__init__.py
+++ b/modelscope/models/multi_modal/ofa/__init__.py
@@ -1,2 +1,3 @@
from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel
-from .tokenization_ofa import OFATokenizer
+from .tokenization_ofa import OFATokenizer, OFATokenizerZH
+from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast
diff --git a/modelscope/models/multi_modal/ofa/configuration_ofa.py b/modelscope/models/multi_modal/ofa/configuration_ofa.py
index 4d28dcc5..4899f416 100644
--- a/modelscope/models/multi_modal/ofa/configuration_ofa.py
+++ b/modelscope/models/multi_modal/ofa/configuration_ofa.py
@@ -134,6 +134,8 @@ class OFAConfig(PretrainedConfig):
code_layernorm_embedding=True,
code_image_size=128,
entangle_position_embedding=False,
+ interpolate_position=False,
+ orig_patch_image_size=224,
**kwargs):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
@@ -173,6 +175,8 @@ class OFAConfig(PretrainedConfig):
self.code_layernorm_embedding = code_layernorm_embedding
self.code_image_size = code_image_size
self.entangle_position_embedding = entangle_position_embedding
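+        # `interpolate_position` enables bilinear interpolation of the image positional
+        # embeddings when the input patch grid is larger than the grid used at
+        # pre-training time; `orig_patch_image_size` records that original resolution.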
+ self.interpolate_position = interpolate_position
+ self.orig_patch_image_size = orig_patch_image_size
super().__init__(
pad_token_id=pad_token_id,
diff --git a/modelscope/models/multi_modal/ofa/modeling_ofa.py b/modelscope/models/multi_modal/ofa/modeling_ofa.py
index b0350d1d..01cc02f9 100755
--- a/modelscope/models/multi_modal/ofa/modeling_ofa.py
+++ b/modelscope/models/multi_modal/ofa/modeling_ofa.py
@@ -311,7 +311,6 @@ class OFAAttention(nn.Module):
self.head_dim * num_heads == self.embed_dim
), f'embed_dim must be divisible by num_heads ' \
f'(got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads}).'
- # self.scaling = self.head_dim ** -0.5
# 1. difference
scale_factor = 2
self.scaling = float(self.head_dim * scale_factor)**-0.5
@@ -913,7 +912,6 @@ class OFAEncoder(OFAPreTrainedModel):
else:
raise NotImplementedError
- # self.image_proj = nn.Linear(1024, embed_dim)
self.image_proj = Linear(1024, embed_dim)
if config.resnet_model_path:
@@ -1075,7 +1073,25 @@ class OFAEncoder(OFAPreTrainedModel):
image_num_patches = sample_patch_num
image_padding_mask = image_padding_mask.gather(1, patch_orders)
image_position_ids = image_position_ids.gather(1, patch_orders)
- image_pos_embed = self.embed_image_positions(image_position_ids)
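+            # Number and side length (in patches) of the grid the positional
+            # embeddings were originally trained on (patch size 16).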
+ orig_num_patches = (self.config.orig_patch_image_size // 16)**2
+ orig_hw = self.config.orig_patch_image_size // 16
+ if self.config.interpolate_position and image_num_patches > orig_num_patches:
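+                # Rebuild the original position-id grid, embed it, and bilinearly
+                # resize the (1, dim, orig_hw, orig_hw) map to the current (h, w) grid.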
+ old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand(orig_hw, orig_hw) + \
+ torch.arange(orig_hw).unsqueeze(1) * \
+ self.config.image_bucket_size + 1 # noqa
+ old_image_position_ids = old_image_position_ids.to(device)
+ old_image_pos_embed = self.embed_image_positions(
+ old_image_position_ids)
+ old_image_pos_embed = old_image_pos_embed.reshape(
+ 1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2)
+ image_pos_embed = F.interpolate(
+ old_image_pos_embed, size=(h, w), mode='bilinear')
+ image_pos_embed = image_pos_embed.permute(0, 2, 3, 1).reshape(
+ 1, image_num_patches, -1)
+ image_pos_embed = image_pos_embed.expand(
+ patch_images.size(0), -1, -1)
+ else:
+ image_pos_embed = self.embed_image_positions(image_position_ids)
return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
@@ -1250,7 +1266,6 @@ class OFAEncoder(OFAPreTrainedModel):
position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
positional embeddings of the input image and tokens.
"""
-
image_embed = None
image_embed_2 = None
image_pos_embed = None
@@ -1258,14 +1273,7 @@ class OFAEncoder(OFAPreTrainedModel):
if patch_images is not None:
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
self.get_patch_images_info(patch_images, sample_patch_num, input_ids.device)
- # print("patch_masks.shape")
- # print(patch_masks.shape)
- # print(patch_masks)
- # print("image_padding_mask.shape")
- # print(image_padding_mask.shape)
- # print(image_padding_mask)
image_padding_mask[~patch_masks] = True
- # print(image_padding_mask)
if patch_images_2 is not None:
image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \
self.get_patch_images_info(patch_images_2, sample_patch_num, input_ids.device)
@@ -1313,10 +1321,6 @@ class OFAEncoder(OFAPreTrainedModel):
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
- # if output_hidden_states:
- # # encoder_states.append(x)
- # encoder_states += (x,)
-
# encoder layers
for idx, layer in enumerate(self.layers):
if output_hidden_states:
@@ -1645,7 +1649,6 @@ class OFADecoder(OFAPreTrainedModel):
def reorder_incremental_state_scripting(
self,
- # incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
past_key_values: Optional[torch.Tensor],
new_order: Tensor,
):
@@ -1799,15 +1802,12 @@ class OFADecoder(OFAPreTrainedModel):
self_attn_bias = self_abs_pos_bias.clone()
if code_masks is None or not code_masks.any():
- # print("code_masks is None or not code_masks.any()")
self_attn_bias += self.get_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
elif code_masks is not None and code_masks.all():
- # print("code_masks is not None and code_masks.all()")
self_attn_bias += self.get_image_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
else:
- # print("else")
self_attn_bias[~code_masks] += self.get_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
self_attn_bias[code_masks] += self.get_image_rel_pos_bias(
@@ -1921,7 +1921,7 @@ class OFAModel(OFAPreTrainedModel):
output_type=Seq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
- # 新增函数以适配fairseq的generator
+    # an adapter added for compatibility with the fairseq generator
def max_decoder_positions(self):
"""Maximum length supported by the decoder."""
return self.decoder.max_positions()
@@ -2062,7 +2062,6 @@ class OFAModel(OFAPreTrainedModel):
return Seq2SeqLMOutput(
logits=decoder_outputs.last_hidden_state,
- # last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa.py b/modelscope/models/multi_modal/ofa/tokenization_ofa.py
index e40436b6..158905eb 100644
--- a/modelscope/models/multi_modal/ofa/tokenization_ofa.py
+++ b/modelscope/models/multi_modal/ofa/tokenization_ofa.py
@@ -12,7 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OFA."""
+import collections
+import os
+from typing import List, Optional, Tuple
+
+from transformers import PreTrainedTokenizer
from transformers.models.bart.tokenization_bart import BartTokenizer
+from transformers.models.bert.tokenization_bert import (BasicTokenizer,
+ WordpieceTokenizer)
from transformers.utils import logging
logger = logging.get_logger(__name__)
@@ -26,12 +33,37 @@ PRETRAINED_VOCAB_FILES_MAP = {
'merges_file': {
'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt',
},
+ # OFA models are implemented to be compatible with both huggingface
+ # and modelscope frameworks. For all OFA models available on huggingface,
+ # please refer to https://huggingface.co/models?filter=ofa
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'ofa-base': 1024,
}
+VOCAB_FILES_NAMES_ZH = {'vocab_file': 'vocab.txt'}
+
+PRETRAINED_VOCAB_FILES_MAP_ZH = {
+ 'vocab_file': {
+ 'bert-base-chinese':
+ 'https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt',
+ }
+ # OFA models are implemented to be compatible with both huggingface
+ # and modelscope frameworks. For all OFA models available on huggingface,
+ # please refer to https://huggingface.co/models?filter=ofa
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH = {
+ 'ofa-base': 1024,
+}
+
+PRETRAINED_INIT_CONFIGURATION_ZH = {
+ 'bert-base-chinese': {
+ 'do_lower_case': True
+ },
+}
+
class OFATokenizer(BartTokenizer):
"""
@@ -46,3 +78,293 @@ class OFATokenizer(BartTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, 'r', encoding='utf-8') as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip('\n')
+ vocab[token] = index
+ return vocab
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class OFATokenizerZH(PreTrainedTokenizer):
+ r"""
+    Construct an OFA tokenizer for Chinese. Based on WordPiece.
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ File containing the vocabulary.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ do_basic_tokenize (`bool`, *optional*, defaults to `True`):
+ Whether or not to do basic tokenization before WordPiece.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original BERT).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES_ZH
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_ZH
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION_ZH
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH
+
+ def __init__(self,
+ vocab_file,
+ do_lower_case=True,
+ do_basic_tokenize=True,
+ never_split=None,
+                 bos_token='<s>',
+                 eos_token='</s>',
+                 sep_token='</s>',
+                 cls_token='<s>',
+                 unk_token='<unk>',
+                 pad_token='<pad>',
+                 mask_token='<mask>',
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ **kwargs):
+ super().__init__(
+ do_lower_case=do_lower_case,
+ do_basic_tokenize=do_basic_tokenize,
+ never_split=never_split,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ **kwargs,
+ )
+
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
+ 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`'
+ )
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict([
+ (ids, tok) for tok, ids in self.vocab.items()
+ ])
+ self.do_basic_tokenize = do_basic_tokenize
+ if do_basic_tokenize:
+ self.basic_tokenizer = BasicTokenizer(
+ do_lower_case=do_lower_case,
+ never_split=never_split,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ )
+ self.wordpiece_tokenizer = WordpieceTokenizer(
+ vocab=self.vocab, unk_token=self.unk_token)
+
+ @property
+ def do_lower_case(self):
+ return self.basic_tokenizer.do_lower_case
+
+ @property
+ def vocab_size(self):
+ return len(self.vocab)
+
+ def get_vocab(self):
+ return dict(self.vocab, **self.added_tokens_encoder)
+
+ def _tokenize(self, text):
+ split_tokens = []
+ if self.do_basic_tokenize:
+ for token in self.basic_tokenizer.tokenize(
+ text, never_split=self.all_special_tokens):
+
+ # If the token is part of the never_split set
+ if token in self.basic_tokenizer.never_split:
+ split_tokens.append(token)
+ else:
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
+ else:
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
+ return split_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.ids_to_tokens.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ out_string = ' '.join(tokens).replace(' ##', '').strip()
+ return out_string
+
+ def build_inputs_with_special_tokens(
+ self,
+ token_ids_0: List[int],
+ token_ids_1: Optional[List[int]] = None) -> List[int]:
+ """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+ adding special tokens. A BERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self,
+ token_ids_0: List[int],
+ token_ids_1: Optional[List[int]] = None,
+ already_has_special_tokens: bool = False) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0,
+ token_ids_1=token_ids_1,
+ already_has_special_tokens=True)
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + (
+ [0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self,
+ token_ids_0: List[int],
+ token_ids_1: Optional[List[int]] = None) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
+ pair mask has the following format:
+
+ ```
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ ```
+
+ If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1
+ + sep) * [1]
+
+ def save_vocabulary(self,
+ save_directory: str,
+ filename_prefix: Optional[str] = None) -> Tuple[str]:
+ index = 0
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory,
+ (filename_prefix + '-' if filename_prefix else '')
+                + VOCAB_FILES_NAMES_ZH['vocab_file'])
+ else:
+ vocab_file = (filename_prefix
+ + '-' if filename_prefix else '') + save_directory
+ with open(vocab_file, 'w', encoding='utf-8') as writer:
+ for token, token_index in sorted(
+ self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f'Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive.'
+ ' Please check that the vocabulary is not corrupted!')
+ index = token_index
+ writer.write(token + '\n')
+ index += 1
+ return (vocab_file, )
diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py b/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py
index 235d1b34..03d2d71e 100644
--- a/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py
+++ b/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OFA."""
+from typing import List, Optional, Tuple
+
+import json
+from tokenizers import normalizers
+from transformers import PreTrainedTokenizerFast
from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast
from transformers.utils import logging
-from .tokenization_ofa import OFATokenizer
+from .tokenization_ofa import OFATokenizer, OFATokenizerZH
logger = logging.get_logger(__name__)
@@ -36,12 +41,37 @@ PRETRAINED_VOCAB_FILES_MAP = {
'ofa-base':
'https://huggingface.co/ofa-base/resolve/main/tokenizer.json',
},
+ # OFA models are implemented to be compatible with both huggingface
+ # and modelscope frameworks. For all OFA models available on huggingface,
+ # please refer to https://huggingface.co/models?filter=ofa
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'ofa-base': 1024,
}
+VOCAB_FILES_NAMES_ZH = {'vocab_file': 'vocab.txt'}
+
+PRETRAINED_VOCAB_FILES_MAP_ZH = {
+ 'vocab_file': {
+ 'bert-base-chinese':
+ 'https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt',
+ }
+    # OFA models are implemented to be compatible with both huggingface
+ # and modelscope frameworks. For all OFA models available on huggingface,
+ # please refer to https://huggingface.co/models?filter=ofa
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH = {
+ 'ofa-base': 1024,
+}
+
+PRETRAINED_INIT_CONFIGURATION_ZH = {
+ 'bert-base-chinese': {
+ 'do_lower_case': True
+ },
+}
+
class OFATokenizerFast(BartTokenizerFast):
r"""
@@ -57,3 +87,128 @@ class OFATokenizerFast(BartTokenizerFast):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
slow_tokenizer_class = OFATokenizer
+
+
+class OFATokenizerZHFast(PreTrainedTokenizerFast):
+ r"""
+ Construct a "fast" OFA tokenizer (backed by HuggingFace's *tokenizers* library).
+
+ [`~OFATokenizerFast`] is identical to [`BartTokenizerFast`] and runs end-to-end tokenization: punctuation splitting
+ and wordpiece.
+
+ Refer to superclass [`BartTokenizerFast`] for usage examples and documentation concerning parameters.
+ """
+ vocab_files_names = VOCAB_FILES_NAMES_ZH
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_ZH
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION_ZH
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH
+ slow_tokenizer_class = OFATokenizerZH
+
+ def __init__(self,
+ vocab_file=None,
+ tokenizer_file=None,
+ do_lower_case=True,
+                 bos_token='<s>',
+                 eos_token='</s>',
+                 sep_token='</s>',
+                 cls_token='<s>',
+                 unk_token='<unk>',
+                 pad_token='<pad>',
+                 mask_token='<mask>',
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ **kwargs):
+ super().__init__(
+ vocab_file,
+ tokenizer_file=tokenizer_file,
+ do_lower_case=do_lower_case,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ **kwargs,
+ )
+
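+        # Keep the fast backend's normalizer consistent with the options passed in
+        # (lowercasing, accent stripping and Chinese-character handling).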
+ normalizer_state = json.loads(
+ self.backend_tokenizer.normalizer.__getstate__())
+ if (normalizer_state.get('lowercase', do_lower_case) != do_lower_case
+ or normalizer_state.get('strip_accents', strip_accents)
+ != strip_accents or normalizer_state.get(
+ 'handle_chinese_chars',
+ tokenize_chinese_chars) != tokenize_chinese_chars):
+ normalizer_class = getattr(normalizers,
+ normalizer_state.pop('type'))
+ normalizer_state['lowercase'] = do_lower_case
+ normalizer_state['strip_accents'] = strip_accents
+ normalizer_state['handle_chinese_chars'] = tokenize_chinese_chars
+ self.backend_tokenizer.normalizer = normalizer_class(
+ **normalizer_state)
+
+ self.do_lower_case = do_lower_case
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+ adding special tokens. A BERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+
+ if token_ids_1:
+ output += token_ids_1 + [self.sep_token_id]
+
+ return output
+
+ def create_token_type_ids_from_sequences(
+ self,
+ token_ids_0: List[int],
+ token_ids_1: Optional[List[int]] = None) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
+ pair mask has the following format:
+
+ ```
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ ```
+
+ If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1
+ + sep) * [1]
+
+ def save_vocabulary(self,
+ save_directory: str,
+ filename_prefix: Optional[str] = None) -> Tuple[str]:
+ files = self._tokenizer.model.save(
+ save_directory, name=filename_prefix)
+ return tuple(files)
diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py
index 0ec87d66..363d552d 100644
--- a/modelscope/models/multi_modal/ofa_for_all_tasks.py
+++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py
@@ -16,7 +16,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_tokens
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.trie import Trie
-from .ofa import OFAModel, OFATokenizer
+from .ofa import OFAModel, OFATokenizer, OFATokenizerZH
from .ofa.generate import sequence_generator as sg
from .ofa.generate.utils import move_to_device
from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks
@@ -41,11 +41,21 @@ class OfaForAllTasks(TorchModel):
self.cfg = Config.from_file(
osp.join(model_dir, ModelFile.CONFIGURATION))
self.model = model.module if hasattr(model, 'module') else model
- self.tokenizer = OFATokenizer.from_pretrained(model_dir)
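+        # Pick the tokenizer by the configured language: BPE-based OFATokenizer for
+        # English, WordPiece-based OFATokenizerZH for Chinese.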
+ self.language = self.cfg.model.get('language', 'en')
+ if self.language == 'en':
+ self.tokenizer = OFATokenizer.from_pretrained(model_dir)
+ elif self.language in ['zh', 'cn']:
+ self.tokenizer = OFATokenizerZH.from_pretrained(model_dir)
+ else:
+ raise NotImplementedError
+        # Unlike the original OFA code, the `use_bpe` parameter is not needed here.
        self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
        self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
self.cfg.update({'num_bins': 1000, 'num_codes': 8192})
self.batch_size = self.cfg.model.get('batch_size', 1)
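+        # Image-size settings from the configuration, with defaults of 480 for
+        # patch_image_size and 512 for max_image_size (used to rescale region boxes).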
+ self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
+ self.max_image_size = self.cfg.model.get('max_image_size', 512)
self.val_batch_size = self.cfg.model.get('valid_batch_size',
self.batch_size)
self.gen_type = self.cfg.model.get('gen_type', 'generation')
@@ -129,8 +139,8 @@ class OfaForAllTasks(TorchModel):
- len(self.tokenizer.get_vocab().items())
+ self.cfg.num_bins)
region_tensor = torch.stack(region_coord_l, dim=0)
- region_tensor = region_tensor / (
- self.cfg.num_bins - 1) * self.cfg.model.get('max_image_size', 512)
+ region_tensor = region_tensor / (self.cfg.num_bins
+ - 1) * self.max_image_size
region_tensor[:, ::2] /= input['w_resize_ratios']
region_tensor[:, 1::2] /= input['h_resize_ratios']
return {
diff --git a/modelscope/preprocessors/ofa/base.py b/modelscope/preprocessors/ofa/base.py
index 8f53dbf7..fb9d06cd 100644
--- a/modelscope/preprocessors/ofa/base.py
+++ b/modelscope/preprocessors/ofa/base.py
@@ -6,7 +6,7 @@ import json
import numpy as np
import torch
-from modelscope.models.multi_modal.ofa import OFATokenizer
+from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH
from modelscope.utils.trie import Trie
from .utils.random_help import set_torch_seed
@@ -21,7 +21,15 @@ class OfaBasePreprocessor:
model_dir (str): model path
"""
self.cfg = cfg
- tokenizer = OFATokenizer.from_pretrained(model_dir)
+ self.language = self.cfg.model.get('language', 'en')
+ if self.language == 'en':
+ tokenizer = OFATokenizer.from_pretrained(model_dir)
+ elif self.language in ['zh', 'cn']:
+ tokenizer = OFATokenizerZH.from_pretrained(model_dir)
+ else:
+ raise NotImplementedError
+        # Unlike the original OFA code, the `use_bpe` parameter is not needed here.
        tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
        tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
self.tokenizer = tokenizer
diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py
index 1dc7d303..5cba86b1 100644
--- a/tests/pipelines/test_ofa_tasks.py
+++ b/tests/pipelines/test_ofa_tasks.py
@@ -1,17 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
import unittest
+from os import path as osp
+import cv2
+import numpy as np
from PIL import Image
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
+from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class OfaTasksTest(unittest.TestCase):
+ def setUp(self) -> None:
+ self.output_dir = 'unittest_output'
+ os.makedirs(self.output_dir, exist_ok=True)
+
+ def save_img(self, image_in, box, image_out):
+ image = load_image(image_in)
+ img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+ cv2.rectangle(img, (int(box[0]), int(box[1])),
+ (int(box[2]), int(box[3])), (0, 255, 0), 3)
+ cv2.imwrite(osp.join(self.output_dir, image_out), img)
+
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_captioning_with_model(self):
model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en')
@@ -132,6 +148,9 @@ class OfaTasksTest(unittest.TestCase):
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)
+ image_name = image.split('/')[-2]
+ self.save_img(image, result[OutputKeys.BOXES],
+ osp.join('large_en_model_' + image_name + '.png'))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_grounding_with_name(self):
@@ -143,6 +162,22 @@ class OfaTasksTest(unittest.TestCase):
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)
+ image_name = image.split('/')[-2]
+ self.save_img(image, result[OutputKeys.BOXES],
+ osp.join('large_en_name_' + image_name + '.png'))
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_run_with_visual_grounding_zh_with_name(self):
+ model = 'damo/ofa_visual-grounding_refcoco_large_zh'
+ ofa_pipe = pipeline(Tasks.visual_grounding, model=model)
+ image = 'data/test/images/visual_grounding.png'
+ text = '一个圆头的蓝色宝可梦'
+ input = {'image': image, 'text': text}
+ result = ofa_pipe(input)
+ print(result)
+ image_name = image.split('/')[-1]
+ self.save_img(image, result[OutputKeys.BOXES],
+ osp.join('large_zh_name_' + image_name))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_question_answering_with_model(self):