Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9627026 (target branch: master)
@@ -1,2 +1,3 @@
from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel
from .tokenization_ofa import OFATokenizer
from .tokenization_ofa import OFATokenizer, OFATokenizerZH
from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast
@@ -134,6 +134,8 @@ class OFAConfig(PretrainedConfig):
                 code_layernorm_embedding=True,
                 code_image_size=128,
                 entangle_position_embedding=False,
                 interpolate_position=False,
                 orig_patch_image_size=224,
                 **kwargs):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
@@ -173,6 +175,8 @@ class OFAConfig(PretrainedConfig):
        self.code_layernorm_embedding = code_layernorm_embedding
        self.code_image_size = code_image_size
        self.entangle_position_embedding = entangle_position_embedding
        self.interpolate_position = interpolate_position
        self.orig_patch_image_size = orig_patch_image_size
        super().__init__(
            pad_token_id=pad_token_id,
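Reviewer note: the two new config fields let a checkpoint trained on 224x224 patch images run at a larger `patch_image_size`. A minimal sketch of how they would be set; the import path is an assumption (the diff only shows `class OFAConfig(PretrainedConfig)`):

```python
# Assumed module path; adjust to wherever OFAConfig actually lives.
from modelscope.models.multi_modal.ofa.configuration_ofa import OFAConfig

config = OFAConfig(
    interpolate_position=True,   # resize the learned position table at run time
    orig_patch_image_size=224,   # resolution the table was trained at (default)
)
assert config.interpolate_position and config.orig_patch_image_size == 224
```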
@@ -311,7 +311,6 @@ class OFAAttention(nn.Module):
            self.head_dim * num_heads == self.embed_dim
        ), f'embed_dim must be divisible by num_heads ' \
            f'(got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads}).'
        # self.scaling = self.head_dim ** -0.5
        # 1. difference
        scale_factor = 2
        self.scaling = float(self.head_dim * scale_factor)**-0.5
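Reviewer note: the change only deletes the commented-out standard scaling; OFA keeps its `scale_factor = 2`, i.e. query vectors are scaled by 1/sqrt(2·head_dim) rather than the usual 1/sqrt(head_dim). A quick numeric comparison (the head_dim value is illustrative):

```python
head_dim = 64

standard_scaling = head_dim ** -0.5           # 1/sqrt(64)  = 0.125
ofa_scaling = float(head_dim * 2) ** -0.5     # 1/sqrt(128) ~= 0.0884

# Queries are multiplied by `scaling` before Q @ K^T, so OFA's attention
# logits are a factor of sqrt(2) smaller than with standard scaling.
print(standard_scaling, ofa_scaling)
```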
@@ -913,7 +912,6 @@ class OFAEncoder(OFAPreTrainedModel):
        else:
            raise NotImplementedError
        # self.image_proj = nn.Linear(1024, embed_dim)
        self.image_proj = Linear(1024, embed_dim)
        if config.resnet_model_path:
@@ -1075,7 +1073,25 @@ class OFAEncoder(OFAPreTrainedModel):
            image_num_patches = sample_patch_num
            image_padding_mask = image_padding_mask.gather(1, patch_orders)
            image_position_ids = image_position_ids.gather(1, patch_orders)
        image_pos_embed = self.embed_image_positions(image_position_ids)
        orig_num_patches = (self.config.orig_patch_image_size // 16)**2
        orig_hw = self.config.orig_patch_image_size // 16
        if self.config.interpolate_position and image_num_patches > orig_num_patches:
            old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand(orig_hw, orig_hw) + \
                torch.arange(orig_hw).unsqueeze(1) * \
                self.config.image_bucket_size + 1  # noqa
            old_image_position_ids = old_image_position_ids.to(device)
            old_image_pos_embed = self.embed_image_positions(
                old_image_position_ids)
            old_image_pos_embed = old_image_pos_embed.reshape(
                1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2)
            image_pos_embed = F.interpolate(
                old_image_pos_embed, size=(h, w), mode='bilinear')
            image_pos_embed = image_pos_embed.permute(0, 2, 3, 1).reshape(
                1, image_num_patches, -1)
            image_pos_embed = image_pos_embed.expand(
                patch_images.size(0), -1, -1)
        else:
            image_pos_embed = self.embed_image_positions(image_position_ids)
        return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
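Reviewer note: the new branch treats the learned image position embeddings as an `orig_hw x orig_hw` grid and bilinearly resizes it to the current `h x w` patch grid, so a model trained at 224x224 can consume larger inputs. A self-contained sketch of the core operation, with made-up sizes and a plain `nn.Embedding` standing in for `embed_image_positions`:

```python
import torch
import torch.nn.functional as F
from torch import nn

# Assumed sizes: trained on 224x224 (14x14 patches), run at 480x480 (30x30).
orig_hw, new_h, new_w, embed_dim = 14, 30, 30, 256
embed_image_positions = nn.Embedding(orig_hw * orig_hw, embed_dim)  # stand-in

old_ids = torch.arange(orig_hw * orig_hw)
old_pos = embed_image_positions(old_ids)                                # (196, 256)
old_pos = old_pos.reshape(1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2)  # (1, 256, 14, 14)

new_pos = F.interpolate(old_pos, size=(new_h, new_w), mode='bilinear')  # (1, 256, 30, 30)
new_pos = new_pos.permute(0, 2, 3, 1).reshape(1, new_h * new_w, -1)     # (1, 900, 256)
print(new_pos.shape)  # torch.Size([1, 900, 256])
```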
@@ -1250,7 +1266,6 @@ class OFAEncoder(OFAPreTrainedModel):
            position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
                positional embeddings of the input image and tokens.
        """
        image_embed = None
        image_embed_2 = None
        image_pos_embed = None
@@ -1258,14 +1273,7 @@ class OFAEncoder(OFAPreTrainedModel):
        if patch_images is not None:
            image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
                self.get_patch_images_info(patch_images, sample_patch_num, input_ids.device)
            # print("patch_masks.shape")
            # print(patch_masks.shape)
            # print(patch_masks)
            # print("image_padding_mask.shape")
            # print(image_padding_mask.shape)
            # print(image_padding_mask)
            image_padding_mask[~patch_masks] = True
            # print(image_padding_mask)
        if patch_images_2 is not None:
            image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \
                self.get_patch_images_info(patch_images_2, sample_patch_num, input_ids.device)
@@ -1313,10 +1321,6 @@ class OFAEncoder(OFAPreTrainedModel):
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # if output_hidden_states:
        #     # encoder_states.append(x)
        #     encoder_states += (x,)
        # encoder layers
        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
@@ -1645,7 +1649,6 @@ class OFADecoder(OFAPreTrainedModel):
    def reorder_incremental_state_scripting(
            self,
            # incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
            past_key_values: Optional[torch.Tensor],
            new_order: Tensor,
    ):
@@ -1799,15 +1802,12 @@ class OFADecoder(OFAPreTrainedModel):
        self_attn_bias = self_abs_pos_bias.clone()
        if code_masks is None or not code_masks.any():
            # print("code_masks is None or not code_masks.any()")
            self_attn_bias += self.get_rel_pos_bias(
                all_prev_output_tokens, idx).unsqueeze(0)
        elif code_masks is not None and code_masks.all():
            # print("code_masks is not None and code_masks.all()")
            self_attn_bias += self.get_image_rel_pos_bias(
                all_prev_output_tokens, idx).unsqueeze(0)
        else:
            # print("else")
            self_attn_bias[~code_masks] += self.get_rel_pos_bias(
                all_prev_output_tokens, idx).unsqueeze(0)
            self_attn_bias[code_masks] += self.get_image_rel_pos_bias(
@@ -1921,7 +1921,7 @@ class OFAModel(OFAPreTrainedModel):
        output_type=Seq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # a function added to adapt to the fairseq generator
    # an adaptor for the fairseq generator
    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()
@@ -2062,7 +2062,6 @@ class OFAModel(OFAPreTrainedModel):
        return Seq2SeqLMOutput(
            logits=decoder_outputs.last_hidden_state,
            # last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
@@ -12,7 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OFA."""
import collections
import os
from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizer
from transformers.models.bart.tokenization_bart import BartTokenizer
from transformers.models.bert.tokenization_bert import (BasicTokenizer,
                                                        WordpieceTokenizer)
from transformers.utils import logging

logger = logging.get_logger(__name__)
@@ -26,12 +33,37 @@ PRETRAINED_VOCAB_FILES_MAP = {
    'merges_file': {
        'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt',
    },
    # OFA models are implemented to be compatible with both huggingface
    # and modelscope frameworks. For all OFA models available on huggingface,
    # please refer to https://huggingface.co/models?filter=ofa
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'ofa-base': 1024,
}

VOCAB_FILES_NAMES_ZH = {'vocab_file': 'vocab.txt'}

PRETRAINED_VOCAB_FILES_MAP_ZH = {
    'vocab_file': {
        'bert-base-chinese':
        'https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt',
    }
    # OFA models are implemented to be compatible with both huggingface
    # and modelscope frameworks. For all OFA models available on huggingface,
    # please refer to https://huggingface.co/models?filter=ofa
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH = {
    'ofa-base': 1024,
}

PRETRAINED_INIT_CONFIGURATION_ZH = {
    'bert-base-chinese': {
        'do_lower_case': True
    },
}

class OFATokenizer(BartTokenizer):
    """
@@ -46,3 +78,293 @@ class OFATokenizer(BartTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, 'r', encoding='utf-8') as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip('\n')
        vocab[token] = index
    return vocab
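Reviewer note: `load_vocab` maps each line of a BERT-style `vocab.txt` to its line index. A toy illustration, reusing `load_vocab` as defined above (the file name is arbitrary):

```python
# Write a three-line toy vocabulary.
with open('toy_vocab.txt', 'w', encoding='utf-8') as f:
    f.write('[PAD]\n[UNK]\n你\n')

vocab = load_vocab('toy_vocab.txt')
print(vocab)  # OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('你', 2)])
```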
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens

class OFATokenizerZH(PreTrainedTokenizer):
    r"""
    Construct an OFA tokenizer for Chinese. Based on WordPiece.

    This tokenizer inherits from [`PreTrainedTokenizer`], which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of
            sequence. The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering. It is also used as the
            last token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES_ZH
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_ZH
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION_ZH
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH

    def __init__(self,
                 vocab_file,
                 do_lower_case=True,
                 do_basic_tokenize=True,
                 never_split=None,
                 bos_token='<s>',
                 eos_token='</s>',
                 sep_token='</s>',
                 cls_token='<s>',
                 unk_token='<unk>',
                 pad_token='<pad>',
                 mask_token='<mask>',
                 tokenize_chinese_chars=True,
                 strip_accents=None,
                 **kwargs):
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a pretrained "
                'model, use `tokenizer = OFATokenizerZH.from_pretrained(PRETRAINED_MODEL_NAME)`'
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([
            (ids, tok) for tok, ids in self.vocab.items()
        ])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(
            vocab=self.vocab, unk_token=self.unk_token)

    @property
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(
                    text, never_split=self.all_special_tokens):
                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings) to a single string."""
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        return out_string
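Reviewer note: detokenization just strips the WordPiece `##` continuation markers, e.g.:

```python
tokens = ['play', '##ing', '宝', '可', '梦']
print(' '.join(tokens).replace(' ##', '').strip())  # -> 'playing 宝 可 梦'
```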
    def build_inputs_with_special_tokens(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep
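Reviewer note: a standalone sketch of the pattern above; the ids 101/102 are illustrative stand-ins for `cls_token_id`/`sep_token_id`:

```python
cls_id, sep_id = 101, 102  # illustrative special-token ids

def with_special_tokens(ids_0, ids_1=None):
    # mirrors OFATokenizerZH.build_inputs_with_special_tokens
    if ids_1 is None:
        return [cls_id] + ids_0 + [sep_id]
    return [cls_id] + ids_0 + [sep_id] + ids_1 + [sep_id]

print(with_special_tokens([7, 8]))       # [101, 7, 8, 102]
print(with_special_tokens([7, 8], [9]))  # [101, 7, 8, 102, 9, 102]
```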
    def get_special_tokens_mask(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None,
            already_has_special_tokens: bool = False) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True)
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + (
                [0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory,
                (filename_prefix + '-' if filename_prefix else '')
                + VOCAB_FILES_NAMES_ZH['vocab_file'])
        else:
            vocab_file = (filename_prefix
                          + '-' if filename_prefix else '') + save_directory
        with open(vocab_file, 'w', encoding='utf-8') as writer:
            for token, token_index in sorted(
                    self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f'Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive.'
                        ' Please check that the vocabulary is not corrupted!')
                    index = token_index
                writer.write(token + '\n')
                index += 1
        return (vocab_file, )
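Reviewer note: an end-to-end sketch of the new slow tokenizer. The model directory is a placeholder and must contain a BERT-style `vocab.txt`; the printed pieces are indicative, not guaranteed:

```python
from modelscope.models.multi_modal.ofa import OFATokenizerZH

tokenizer = OFATokenizerZH.from_pretrained('/path/to/ofa_zh_model')  # placeholder

encoded = tokenizer('一个圆头的蓝色宝可梦')
print(encoded['input_ids'])          # [cls] + wordpiece ids + [sep]
print(tokenizer.tokenize('宝可梦'))  # e.g. ['宝', '可', '梦'] char-level pieces
```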
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OFA."""
import json
from typing import List, Optional, Tuple

from tokenizers import normalizers
from transformers import PreTrainedTokenizerFast
from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast
from transformers.utils import logging

from .tokenization_ofa import OFATokenizer
from .tokenization_ofa import OFATokenizer, OFATokenizerZH

logger = logging.get_logger(__name__)
@@ -36,12 +41,37 @@ PRETRAINED_VOCAB_FILES_MAP = {
        'ofa-base':
        'https://huggingface.co/ofa-base/resolve/main/tokenizer.json',
    },
    # OFA models are implemented to be compatible with both huggingface
    # and modelscope frameworks. For all OFA models available on huggingface,
    # please refer to https://huggingface.co/models?filter=ofa
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'ofa-base': 1024,
}

VOCAB_FILES_NAMES_ZH = {'vocab_file': 'vocab.txt'}

PRETRAINED_VOCAB_FILES_MAP_ZH = {
    'vocab_file': {
        'bert-base-chinese':
        'https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt',
    }
    # OFA models are implemented to be compatible with both huggingface
    # and modelscope frameworks. For all OFA models available on huggingface,
    # please refer to https://huggingface.co/models?filter=ofa
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH = {
    'ofa-base': 1024,
}

PRETRAINED_INIT_CONFIGURATION_ZH = {
    'bert-base-chinese': {
        'do_lower_case': True
    },
}

class OFATokenizerFast(BartTokenizerFast):
    r"""
@@ -57,3 +87,128 @@ class OFATokenizerFast(BartTokenizerFast):
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    slow_tokenizer_class = OFATokenizer

class OFATokenizerZHFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" OFA tokenizer for Chinese (backed by HuggingFace's *tokenizers* library), based on WordPiece.

    [`OFATokenizerZHFast`] runs end-to-end tokenization: punctuation splitting and wordpiece.

    Refer to the superclass [`PreTrainedTokenizerFast`] for usage examples and documentation concerning parameters.
    """

    vocab_files_names = VOCAB_FILES_NAMES_ZH
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_ZH
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION_ZH
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH
    slow_tokenizer_class = OFATokenizerZH

    def __init__(self,
                 vocab_file=None,
                 tokenizer_file=None,
                 do_lower_case=True,
                 bos_token='<s>',
                 eos_token='</s>',
                 sep_token='</s>',
                 cls_token='<s>',
                 unk_token='<unk>',
                 pad_token='<pad>',
                 mask_token='<mask>',
                 tokenize_chinese_chars=True,
                 strip_accents=None,
                 **kwargs):
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        normalizer_state = json.loads(
            self.backend_tokenizer.normalizer.__getstate__())
        if (normalizer_state.get('lowercase', do_lower_case) != do_lower_case
                or normalizer_state.get('strip_accents', strip_accents)
                != strip_accents or normalizer_state.get(
                    'handle_chinese_chars',
                    tokenize_chinese_chars) != tokenize_chinese_chars):
            normalizer_class = getattr(normalizers,
                                       normalizer_state.pop('type'))
            normalizer_state['lowercase'] = do_lower_case
            normalizer_state['strip_accents'] = strip_accents
            normalizer_state['handle_chinese_chars'] = tokenize_chinese_chars
            self.backend_tokenizer.normalizer = normalizer_class(
                **normalizer_state)

        self.do_lower_case = do_lower_case
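Reviewer note: this mirrors `BertTokenizerFast`: if the normalizer serialized in a checkpoint's `tokenizer.json` disagrees with the `__init__` arguments, the backend normalizer is rebuilt. A sketch of the mechanism against a toy backend tokenizer; treat the exact serialized key names as an assumption, since they depend on the *tokenizers* version:

```python
import json

from tokenizers import Tokenizer, normalizers
from tokenizers.models import WordPiece

# A toy backend tokenizer standing in for the one loaded from tokenizer.json.
tok = Tokenizer(WordPiece({'[UNK]': 0}, unk_token='[UNK]'))
tok.normalizer = normalizers.BertNormalizer(lowercase=False)

state = json.loads(tok.normalizer.__getstate__())
print(state['type'], state['lowercase'])  # BertNormalizer False

# Rebuild with lowercase=True, the same way OFATokenizerZHFast.__init__ does.
normalizer_cls = getattr(normalizers, state.pop('type'))
state['lowercase'] = True
tok.normalizer = normalizer_cls(**state)
```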
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        if token_ids_1:
            output += token_ids_1 + [self.sep_token_id]
        return output

    def create_token_type_ids_from_sequences(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        files = self._tokenizer.model.save(
            save_directory, name=filename_prefix)
        return tuple(files)
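Reviewer note: a minimal sketch of loading the fast variant. The path is a placeholder; loading from a directory is expected to require a serialized `tokenizer.json`, since no slow-to-fast converter is registered for `OFATokenizerZH`:

```python
from modelscope.models.multi_modal.ofa import OFATokenizerZHFast

fast_tok = OFATokenizerZHFast.from_pretrained('/path/to/ofa_zh_model')  # placeholder
print(fast_tok('你好')['input_ids'])  # [cls] + ids + [sep]
```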
@@ -16,7 +16,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_tokens
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.trie import Trie
from .ofa import OFAModel, OFATokenizer
from .ofa import OFAModel, OFATokenizer, OFATokenizerZH
from .ofa.generate import sequence_generator as sg
from .ofa.generate.utils import move_to_device
from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks
@@ -41,11 +41,21 @@ class OfaForAllTasks(TorchModel):
        self.cfg = Config.from_file(
            osp.join(model_dir, ModelFile.CONFIGURATION))
        self.model = model.module if hasattr(model, 'module') else model
        self.tokenizer = OFATokenizer.from_pretrained(model_dir)
        self.language = self.cfg.model.get('language', 'en')
        if self.language == 'en':
            self.tokenizer = OFATokenizer.from_pretrained(model_dir)
        elif self.language in ['zh', 'cn']:
            self.tokenizer = OFATokenizerZH.from_pretrained(model_dir)
        else:
            raise NotImplementedError
        # Unlike our original OFA code, there is no need for the
        # `use_bpe` parameter here.
        self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
        self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
        self.cfg.update({'num_bins': 1000, 'num_codes': 8192})
        self.batch_size = self.cfg.model.get('batch_size', 1)
        self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
        self.max_image_size = self.cfg.model.get('max_image_size', 512)
        self.val_batch_size = self.cfg.model.get('valid_batch_size',
                                                 self.batch_size)
        self.gen_type = self.cfg.model.get('gen_type', 'generation')
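Reviewer note: tokenizer selection is driven entirely by the `model.language` field in the checkpoint's configuration, so callers need no switch of their own; the new unit test at the end of this review exercises exactly this path:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# The checkpoint's configuration carries model.language = 'zh', so
# OfaForAllTasks picks OFATokenizerZH automatically.
ofa_pipe = pipeline(Tasks.visual_grounding,
                    model='damo/ofa_visual-grounding_refcoco_large_zh')
print(ofa_pipe({'image': 'data/test/images/visual_grounding.png',
                'text': '一个圆头的蓝色宝可梦'}))
```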
@@ -129,8 +139,8 @@ class OfaForAllTasks(TorchModel):
            - len(self.tokenizer.get_vocab().items())
            + self.cfg.num_bins)
        region_tensor = torch.stack(region_coord_l, dim=0)
        region_tensor = region_tensor / (
            self.cfg.num_bins - 1) * self.cfg.model.get('max_image_size', 512)
        region_tensor = region_tensor / (self.cfg.num_bins
                                         - 1) * self.max_image_size
        region_tensor[:, ::2] /= input['w_resize_ratios']
        region_tensor[:, 1::2] /= input['h_resize_ratios']
        return {
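Reviewer note: the grounding head decodes `<bin_i>` tokens; the lines above map a bin index back to pixels via `x = i / (num_bins - 1) * max_image_size`, then undo the resize applied during preprocessing. A worked example with the defaults from this file (the resize ratio is illustrative):

```python
num_bins, max_image_size = 1000, 512

bin_idx = 250                       # a decoded <bin_250> token
x_resized = bin_idx / (num_bins - 1) * max_image_size  # ~128.13 px, resized image

w_resize_ratio = 512 / 1024         # illustrative: original image was 1024 px wide
x_original = x_resized / w_resize_ratio                # ~256.26 px, original image
print(x_resized, x_original)
```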
@@ -6,7 +6,7 @@ import json
import numpy as np
import torch

from modelscope.models.multi_modal.ofa import OFATokenizer
from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH
from modelscope.utils.trie import Trie
from .utils.random_help import set_torch_seed
@@ -21,7 +21,15 @@ class OfaBasePreprocessor:
            model_dir (str): model path
        """
        self.cfg = cfg
        tokenizer = OFATokenizer.from_pretrained(model_dir)
        self.language = self.cfg.model.get('language', 'en')
        if self.language == 'en':
            tokenizer = OFATokenizer.from_pretrained(model_dir)
        elif self.language in ['zh', 'cn']:
            tokenizer = OFATokenizerZH.from_pretrained(model_dir)
        else:
            raise NotImplementedError
        # Unlike our original OFA code, there is no need for the
        # `use_bpe` parameter here.
        tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
        tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
        self.tokenizer = tokenizer
@@ -1,17 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from os import path as osp

import cv2
import numpy as np
from PIL import Image

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

class OfaTasksTest(unittest.TestCase):

    def setUp(self) -> None:
        self.output_dir = 'unittest_output'
        os.makedirs(self.output_dir, exist_ok=True)

    def save_img(self, image_in, box, image_out):
        image = load_image(image_in)
        img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        cv2.rectangle(img, (int(box[0]), int(box[1])),
                      (int(box[2]), int(box[3])), (0, 255, 0), 3)
        cv2.imwrite(osp.join(self.output_dir, image_out), img)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_image_captioning_with_model(self):
        model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en')
@@ -132,6 +148,9 @@ class OfaTasksTest(unittest.TestCase):
        input = {'image': image, 'text': text}
        result = ofa_pipe(input)
        print(result)
        image_name = image.split('/')[-2]
        self.save_img(image, result[OutputKeys.BOXES],
                      osp.join('large_en_model_' + image_name + '.png'))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_visual_grounding_with_name(self):
@@ -143,6 +162,22 @@ class OfaTasksTest(unittest.TestCase):
        input = {'image': image, 'text': text}
        result = ofa_pipe(input)
        print(result)
        image_name = image.split('/')[-2]
        self.save_img(image, result[OutputKeys.BOXES],
                      osp.join('large_en_name_' + image_name + '.png'))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_visual_grounding_zh_with_name(self):
        model = 'damo/ofa_visual-grounding_refcoco_large_zh'
        ofa_pipe = pipeline(Tasks.visual_grounding, model=model)
        image = 'data/test/images/visual_grounding.png'
        text = '一个圆头的蓝色宝可梦'  # "a round-headed blue Pokémon"
        input = {'image': image, 'text': text}
        result = ofa_pipe(input)
        print(result)
        image_name = image.split('/')[-1]
        self.save_img(image, result[OutputKeys.BOXES],
                      osp.join('large_zh_name_' + image_name))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_visual_question_answering_with_model(self):