| @@ -37,6 +37,7 @@ do | |||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e MODELSCOPE_ENVIRONMENT='ci' \ | |||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||
| @@ -59,6 +60,7 @@ do | |||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e MODELSCOPE_ENVIRONMENT='ci' \ | |||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||
| @@ -1,16 +1,26 @@ | |||
| # Introduction | |||
| ModelScope library is targeted to support training, evaluation and inference for the state of the art models provided by Mind and further support third-party models provided by users outside alibaba. | |||
| [ModelScope]( https://www.modelscope.cn) is a “Model-as-a-Service” (MaaS) platform that seeks to bringing together most advanced machine learning models from the AI community, and to streamlining the process of leveraging and applying AI models . The core ModelScope library enables developers to perform model inference, training and evaluation, through rich layers of API designs that facilitate a unified experience across state-of-the-art models from different AI domains. | |||
| # Design doc | |||
| The Python library offers the layered-APIs necessary for model contributors to integrate models from CV, NLP, Speech, Multi-Modality, as well as Scientific-computation, into the ModelScope ecosystem. Implementations for all these different models are encapsulated within the library in a way that allows easy and unified access. With such integration, model inference, finetuning, and evaluations can be done within only a few lines of codes. In the meantime, flexibilities are provided so that different components in the model applications can be customized as well, where necessary. | |||
| Please refer to alidoc [link](https://alidocs.dingtalk.com/i/nodes/OBldywvrKxo89xmAO05yJQk2ngpNbLz4?nav=spaces&navQuery=spaceId%3Dnb9XJNlZxbgrOXyA&iframeQuery=utm_source%3Dportal%26utm_medium%3Dportal_space_file_tree) | |||
| Apart from harboring implementations of various models, ModelScope library also enables the necessary interactions with the backend services of ModelScope, particularly with the Model-Hub and Dataset-Hub. Such interactions facilitate various entity (models and datasets) management to be performed seamlessly under-the-hood, such as entity lookup, version control, and cache management. | |||
| # Development doc | |||
| # Installation | |||
| Please refer to [develop.md](docs/source/develop.md) | |||
| Please refer to [installation](https://modelscope.cn/docs/%E7%8E%AF%E5%A2%83%E5%AE%89%E8%A3%85). | |||
| # ChangeLog | |||
| * 20/05/2022 First release version | |||
| # Get Started | |||
| Refer to [change_log.md](docs/source/change_log.md) for more details | |||
| You can refer to [quick_start](https://modelscope.cn/docs/%E5%BF%AB%E9%80%9F%E5%BC%80%E5%A7%8B) for quick start. | |||
| We also provide other documentations including: | |||
| * [Introduction to tasks](https://modelscope.cn/docs/%E4%BB%BB%E5%8A%A1%E7%9A%84%E4%BB%8B%E7%BB%8D) | |||
| * [Use pipeline for model inference](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E6%8E%A8%E7%90%86Pipeline) | |||
| * [Finetune example](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AE%AD%E7%BB%83Train) | |||
| * [Preprocessing of data](https://modelscope.cn/docs/%E6%95%B0%E6%8D%AE%E7%9A%84%E9%A2%84%E5%A4%84%E7%90%86) | |||
| * [Evaluation metrics](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AF%84%E4%BC%B0) | |||
| # License | |||
| This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). | |||
| @@ -128,7 +128,7 @@ class TorchModelExporter(Exporter): | |||
| args_list = list(args) | |||
| else: | |||
| args_list = [args] | |||
| if isinstance(args_list[-1], dict): | |||
| if isinstance(args_list[-1], Mapping): | |||
| args_dict = args_list[-1] | |||
| args_list = args_list[:-1] | |||
| n_nonkeyword = len(args_list) | |||
| @@ -284,9 +284,8 @@ class TorchModelExporter(Exporter): | |||
| 'Model property dummy_inputs must be set.') | |||
| dummy_inputs = collate_fn(dummy_inputs, device) | |||
| if isinstance(dummy_inputs, Mapping): | |||
| dummy_inputs = self._decide_input_format(model, dummy_inputs) | |||
| dummy_inputs_filter = [] | |||
| for _input in dummy_inputs: | |||
| for _input in self._decide_input_format(model, dummy_inputs): | |||
| if _input is not None: | |||
| dummy_inputs_filter.append(_input) | |||
| else: | |||
| @@ -23,7 +23,8 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| API_RESPONSE_FIELD_MESSAGE, | |||
| API_RESPONSE_FIELD_USERNAME, | |||
| DEFAULT_CREDENTIALS_PATH, | |||
| MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS, | |||
| MODELSCOPE_ENVIRONMENT, | |||
| MODELSCOPE_USERNAME, ONE_YEAR_SECONDS, | |||
| Licenses, ModelVisibility) | |||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||
| NotLoginException, NoValidRevisionError, | |||
| @@ -38,8 +39,8 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
| DEFAULT_MODEL_REVISION, | |||
| DEFAULT_REPOSITORY_REVISION, | |||
| MASTER_MODEL_BRANCH, DatasetFormations, | |||
| DatasetMetaFormats, DownloadMode, | |||
| ModelFile) | |||
| DatasetMetaFormats, DownloadChannel, | |||
| DownloadMode, ModelFile) | |||
| from modelscope.utils.logger import get_logger | |||
| from .utils.utils import (get_endpoint, get_release_datetime, | |||
| model_id_to_group_owner_name) | |||
| @@ -645,6 +646,25 @@ class HubApi: | |||
| def check_local_cookies(self, use_cookies) -> CookieJar: | |||
| return self._check_cookie(use_cookies=use_cookies) | |||
| def dataset_download_uv(self, dataset_name: str, namespace: str): | |||
| if not dataset_name or not namespace: | |||
| raise ValueError('dataset_name or namespace cannot be empty!') | |||
| # get channel and user_name | |||
| channel = DownloadChannel.LOCAL.value | |||
| user_name = '' | |||
| if MODELSCOPE_ENVIRONMENT in os.environ: | |||
| channel = os.environ[MODELSCOPE_ENVIRONMENT] | |||
| if MODELSCOPE_USERNAME in os.environ: | |||
| user_name = os.environ[MODELSCOPE_USERNAME] | |||
| url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/{channel}?user={user_name}' | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| r = requests.post(url, cookies=cookies, headers=self.headers) | |||
| resp = r.json() | |||
| raise_on_error(resp) | |||
| return resp['Message'] | |||
| class ModelScopeConfig: | |||
| path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) | |||
| @@ -760,14 +780,18 @@ class ModelScopeConfig: | |||
| env = 'custom' | |||
| if MODELSCOPE_ENVIRONMENT in os.environ: | |||
| env = os.environ[MODELSCOPE_ENVIRONMENT] | |||
| user_name = 'unknown' | |||
| if MODELSCOPE_USERNAME in os.environ: | |||
| user_name = os.environ[MODELSCOPE_USERNAME] | |||
| ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % ( | |||
| ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % ( | |||
| __version__, | |||
| platform.python_version(), | |||
| ModelScopeConfig.get_user_session_id(), | |||
| platform.platform(), | |||
| platform.processor(), | |||
| env, | |||
| user_name, | |||
| ) | |||
| if isinstance(user_agent, dict): | |||
| ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||
| @@ -18,6 +18,7 @@ API_RESPONSE_FIELD_EMAIL = 'Email' | |||
| API_RESPONSE_FIELD_MESSAGE = 'Message' | |||
| MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' | |||
| MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG' | |||
| MODELSCOPE_USERNAME = 'MODELSCOPE_USERNAME' | |||
| ONE_YEAR_SECONDS = 24 * 365 * 60 * 60 | |||
| @@ -5,6 +5,8 @@ import os | |||
| from datetime import datetime | |||
| from typing import Optional | |||
| import requests | |||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | |||
| DEFAULT_MODELSCOPE_GROUP, | |||
| MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, | |||
| @@ -85,3 +87,16 @@ def file_integrity_validation(file_path, expected_sha256): | |||
| msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path | |||
| logger.error(msg) | |||
| raise FileIntegrityError(msg) | |||
| def create_library_statistics(method: str, name: str, cn_name: Optional[str]): | |||
| try: | |||
| from modelscope.hub.api import ModelScopeConfig | |||
| path = f'{get_endpoint()}/api/v1/statistics/library' | |||
| headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||
| params = {'Method': method, 'Name': name, 'CnName': cn_name} | |||
| r = requests.post(path, params=params, headers=headers) | |||
| r.raise_for_status() | |||
| except Exception: | |||
| pass | |||
| return | |||
| @@ -1,6 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import tempfile | |||
| from typing import Dict, Optional | |||
| from modelscope.metainfo import Models | |||
| @@ -36,12 +37,15 @@ class FSMNSeleNetV2Decorator(TorchModel): | |||
| else: | |||
| sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||
| model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||
| self.tmp_dir = tempfile.TemporaryDirectory() | |||
| new_config_file = os.path.join(self.tmp_dir.name, self.SC_CONFIG) | |||
| self._sc = None | |||
| if os.path.exists(model_txt_file): | |||
| conf_dict = dict(mode=56542, kws_model=model_txt_file) | |||
| update_conf(sc_config_file, sc_config_file, conf_dict) | |||
| update_conf(sc_config_file, new_config_file, conf_dict) | |||
| import py_sound_connect | |||
| self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||
| self._sc = py_sound_connect.SoundConnect(new_config_file) | |||
| self.size_in = self._sc.bytesPerBlockIn() | |||
| self.size_out = self._sc.bytesPerBlockOut() | |||
| else: | |||
| @@ -49,6 +53,9 @@ class FSMNSeleNetV2Decorator(TorchModel): | |||
| f'Invalid model directory! Failed to load model file: {model_txt_file}.' | |||
| ) | |||
| def __del__(self): | |||
| self.tmp_dir.cleanup() | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| return self.model.forward(input) | |||
| @@ -131,6 +131,8 @@ class Model(ABC): | |||
| if not hasattr(model, 'cfg'): | |||
| model.cfg = cfg | |||
| model.name = model_name_or_path | |||
| return model | |||
| def save_pretrained(self, | |||
| @@ -349,11 +349,13 @@ class CLIP(nn.Module): | |||
| text_num_hidden_layers: int, | |||
| text_type_vocab_size: int, | |||
| tokenizer: FullTokenizer, | |||
| # vision_head_width, added this param for ViT-H | |||
| vision_head_width: int = 64, | |||
| ): | |||
| super().__init__() | |||
| if isinstance(vision_layers, (tuple, list)): | |||
| vision_heads = vision_width * 32 // 64 | |||
| vision_heads = vision_width * 32 // vision_head_width | |||
| self.visual = ModifiedResNet( | |||
| layers=vision_layers, | |||
| output_dim=embed_dim, | |||
| @@ -361,7 +363,7 @@ class CLIP(nn.Module): | |||
| input_resolution=image_resolution, | |||
| width=vision_width) | |||
| else: | |||
| vision_heads = vision_width // 64 | |||
| vision_heads = vision_width // vision_head_width | |||
| self.visual = VisualTransformer( | |||
| input_resolution=image_resolution, | |||
| patch_size=vision_patch_size, | |||
| @@ -136,6 +136,12 @@ class OFAConfig(PretrainedConfig): | |||
| entangle_position_embedding=False, | |||
| interpolate_position=False, | |||
| orig_patch_image_size=224, | |||
| share_attn_bias=False, | |||
| use_image_feature=True, | |||
| disable_entangle=False, | |||
| use_ofasys=False, | |||
| vit_type='vit_base', | |||
| vit_drop_path_rate=0.0, | |||
| **kwargs): | |||
| self.vocab_size = vocab_size | |||
| self.max_position_embeddings = max_position_embeddings | |||
| @@ -178,6 +184,13 @@ class OFAConfig(PretrainedConfig): | |||
| self.interpolate_position = interpolate_position | |||
| self.orig_patch_image_size = orig_patch_image_size | |||
| self.share_attn_bias = share_attn_bias | |||
| self.use_image_feature = use_image_feature | |||
| self.disable_entangle = disable_entangle | |||
| self.use_ofasys = use_ofasys | |||
| self.vit_type = vit_type | |||
| self.vit_drop_path_rate = vit_drop_path_rate | |||
| super().__init__( | |||
| pad_token_id=pad_token_id, | |||
| bos_token_id=bos_token_id, | |||
| @@ -35,6 +35,8 @@ from transformers.utils import logging | |||
| from .configuration_ofa import OFAConfig | |||
| from .generate import utils | |||
| from .resnet import ResNet | |||
| from .utils.utils import DropPath | |||
| from .vit import vit_base, vit_huge, vit_large, vit_large_336 | |||
| logger = logging.get_logger(__name__) | |||
| @@ -249,45 +251,6 @@ class LayerDropModuleList(nn.ModuleList): | |||
| yield m | |||
| def drop_path(x, drop_prob: float = 0.0, training: bool = False): | |||
| r""" | |||
| Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
| Args: | |||
| x (`nn.Modules`): input nn layers. | |||
| drop_prob (`float`): drop path ratio. | |||
| training (`bool`): whether is training or inference. | |||
| """ | |||
| if drop_prob == 0.0 or not training: | |||
| return x | |||
| keep_prob = 1 - drop_prob | |||
| shape = (1, x.shape[1], 1) | |||
| random_tensor = keep_prob + torch.rand( | |||
| shape, dtype=x.dtype, device=x.device) | |||
| random_tensor.floor_() # binarize | |||
| output = x.div(keep_prob) * random_tensor | |||
| return output | |||
| class DropPath(nn.Module): | |||
| r""" | |||
| Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
| Args: | |||
| drop_prob: drop path ratio. | |||
| """ | |||
| def __init__(self, drop_prob=None): | |||
| super().__init__() | |||
| self.drop_prob = drop_prob | |||
| def forward(self, x): | |||
| return drop_path(x, self.drop_prob, self.training) | |||
| def extra_repr(self) -> str: | |||
| return 'p={}'.format(self.drop_prob) | |||
| class OFAAttention(nn.Module): | |||
| r""" | |||
| Multi-headed attention, with additional implementation for NormFormer. | |||
| @@ -898,31 +861,49 @@ class OFAEncoder(OFAPreTrainedModel): | |||
| self.padding_idx) | |||
| if config.add_type_embedding: | |||
| self.type_embedding = Embedding(2, embed_dim, padding_idx=None) | |||
| if config.use_image_feature: | |||
| self.type_embedding = Embedding(2, embed_dim, padding_idx=None) | |||
| else: | |||
| self.type_embedding = Embedding(1, embed_dim, padding_idx=None) | |||
| else: | |||
| self.type_embedding = None | |||
| if config.resnet_type == 'resnet18': | |||
| self.embed_images = ResNet( | |||
| [2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) | |||
| elif config.resnet_type == 'resnet34': | |||
| self.embed_images = ResNet( | |||
| [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) | |||
| elif config.resnet_type == 'resnet50': | |||
| self.embed_images = ResNet( | |||
| [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) | |||
| elif config.resnet_type == 'resnet101': | |||
| self.embed_images = ResNet( | |||
| [3, 4, 23], drop_path_rate=config.resnet_drop_path_rate) | |||
| elif config.resnet_type == 'resnet152': | |||
| self.embed_images = ResNet( | |||
| [3, 8, 36], drop_path_rate=config.resnet_drop_path_rate) | |||
| else: | |||
| raise NotImplementedError | |||
| if config.use_image_feature: | |||
| if config.use_ofasys: | |||
| vit_backbone = { | |||
| 'vit_base': vit_base, | |||
| 'vit_large': vit_large, | |||
| 'vit_large_336': vit_large_336, | |||
| 'vit_huge': vit_huge, | |||
| }[config.vit_type] | |||
| self.embed_images = vit_backbone(config.vit_drop_path_rate) | |||
| self.image_proj = Linear(1024, embed_dim) | |||
| self.image_proj = Linear(self.embed_images.width, embed_dim) | |||
| if config.resnet_model_path: | |||
| else: | |||
| if config.resnet_type == 'resnet18': | |||
| self.embed_images = ResNet( | |||
| [2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) | |||
| elif config.resnet_type == 'resnet34': | |||
| self.embed_images = ResNet( | |||
| [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) | |||
| elif config.resnet_type == 'resnet50': | |||
| self.embed_images = ResNet( | |||
| [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) | |||
| elif config.resnet_type == 'resnet101': | |||
| self.embed_images = ResNet( | |||
| [3, 4, 23], | |||
| drop_path_rate=config.resnet_drop_path_rate) | |||
| elif config.resnet_type == 'resnet152': | |||
| self.embed_images = ResNet( | |||
| [3, 8, 36], | |||
| drop_path_rate=config.resnet_drop_path_rate) | |||
| else: | |||
| raise NotImplementedError | |||
| self.image_proj = Linear(1024, embed_dim) | |||
| if not config.use_ofasys and config.resnet_model_path: | |||
| print('load resnet {}'.format(config.resnet_model_path)) | |||
| resnet_state_dict = torch.load(config.resnet_model_path) | |||
| self.embed_images.load_state_dict(resnet_state_dict) | |||
| @@ -933,14 +914,21 @@ class OFAEncoder(OFAPreTrainedModel): | |||
| self.embed_positions = Embedding(self.max_source_positions + 2, | |||
| embed_dim) | |||
| self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, | |||
| embed_dim) | |||
| self.pos_ln = LayerNorm(embed_dim) | |||
| self.image_pos_ln = LayerNorm(embed_dim) | |||
| if config.use_image_feature: | |||
| self.embed_image_positions = Embedding( | |||
| config.image_bucket_size**2 + 1, embed_dim) | |||
| if not config.use_ofasys: | |||
| self.pos_ln = LayerNorm(embed_dim) | |||
| if config.use_image_feature: | |||
| self.image_pos_ln = LayerNorm(embed_dim) | |||
| self.pos_scaling = float(embed_dim / self.num_attention_heads | |||
| * config.attn_scale_factor)**-0.5 | |||
| self.pos_q_linear = nn.Linear(embed_dim, embed_dim) | |||
| self.pos_k_linear = nn.Linear(embed_dim, embed_dim) | |||
| if not (config.use_ofasys and config.entangle_position_embedding): | |||
| self.pos_q_linear = nn.Linear(embed_dim, embed_dim) | |||
| self.pos_k_linear = nn.Linear(embed_dim, embed_dim) | |||
| if self.encoder_layerdrop > 0.0: | |||
| self.layers = LayerDropModuleList(p=self.encoder_layerdrop) | |||
| @@ -965,22 +953,28 @@ class OFAEncoder(OFAPreTrainedModel): | |||
| self.token_bucket_size = config.token_bucket_size | |||
| token_num_rel_dis = 2 * config.token_bucket_size - 1 | |||
| token_rp_bucket = make_token_bucket_position(config.token_bucket_size) | |||
| self.share_attn_bias = config.share_attn_bias | |||
| num_rel_pos_tables = 1 if config.share_attn_bias else config.encoder_layers | |||
| self.token_rel_pos_table_list = nn.ModuleList([ | |||
| Embedding( | |||
| token_num_rel_dis, self.num_attention_heads, zero_init=True) | |||
| for _ in range(config.encoder_layers) | |||
| for _ in range(num_rel_pos_tables) | |||
| ]) | |||
| self.image_bucket_size = config.image_bucket_size | |||
| image_num_rel_dis = (2 * config.image_bucket_size | |||
| - 1) * (2 * config.image_bucket_size - 1) + 3 | |||
| image_rp_bucket = make_image_bucket_position(config.image_bucket_size, | |||
| image_num_rel_dis) | |||
| self.image_rel_pos_table_list = nn.ModuleList([ | |||
| Embedding( | |||
| image_num_rel_dis, self.num_attention_heads, zero_init=True) | |||
| for _ in range(config.encoder_layers) | |||
| ]) | |||
| if config.use_image_feature: | |||
| self.image_bucket_size = config.image_bucket_size | |||
| image_num_rel_dis = (2 * config.image_bucket_size | |||
| - 1) * (2 * config.image_bucket_size - 1) + 3 | |||
| image_rp_bucket = make_image_bucket_position( | |||
| config.image_bucket_size, image_num_rel_dis) | |||
| self.image_rel_pos_table_list = nn.ModuleList([ | |||
| Embedding( | |||
| image_num_rel_dis, | |||
| self.num_attention_heads, | |||
| zero_init=True) for _ in range(num_rel_pos_tables) | |||
| ]) | |||
| self.register_buffer('image_rp_bucket', image_rp_bucket) | |||
| if config.layernorm_embedding: | |||
| self.layernorm_embedding = LayerNorm(embed_dim) | |||
| @@ -988,12 +982,12 @@ class OFAEncoder(OFAPreTrainedModel): | |||
| self.layernorm_embedding = None | |||
| self.register_buffer('token_rp_bucket', token_rp_bucket) | |||
| self.register_buffer('image_rp_bucket', image_rp_bucket) | |||
| self.entangle_position_embedding = config.entangle_position_embedding | |||
| self.gradient_checkpointing = False | |||
| # Initialize weights and apply final processing | |||
| self.post_init() | |||
| self.use_ofasys = config.use_ofasys | |||
| def get_input_embeddings(self): | |||
| r""" | |||
| @@ -1305,21 +1299,41 @@ class OFAEncoder(OFAPreTrainedModel): | |||
| if has_pads: | |||
| x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) | |||
| pos_embed = self.pos_ln(pos_embed) | |||
| if patch_images is not None: | |||
| image_pos_embed = self.image_pos_ln(image_pos_embed) | |||
| pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) | |||
| if patch_images_2 is not None: | |||
| image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) | |||
| pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) | |||
| if self.use_ofasys: | |||
| if patch_images is not None: | |||
| pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) | |||
| if patch_images_2 is not None: | |||
| pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) | |||
| else: | |||
| pos_embed = self.pos_ln(pos_embed) | |||
| if patch_images is not None: | |||
| image_pos_embed = self.image_pos_ln(image_pos_embed) | |||
| pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) | |||
| if patch_images_2 is not None: | |||
| image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) | |||
| pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) | |||
| def build_abs_pos_bias(pos_embed): | |||
| batch_size, seq_length = pos_embed.size(0), pos_embed.size(1) | |||
| if not (self.use_ofasys and self.entangle_position_embedding): | |||
| pos_q = self.pos_q_linear(pos_embed).view( | |||
| batch_size, seq_length, self.num_attention_heads, | |||
| -1).transpose(1, 2) * self.pos_scaling | |||
| pos_k = self.pos_k_linear(pos_embed).view( | |||
| batch_size, seq_length, self.num_attention_heads, | |||
| -1).transpose(1, 2) | |||
| abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) | |||
| else: | |||
| abs_pos_bias = torch.zeros( | |||
| batch_size, | |||
| self.num_attention_heads, | |||
| seq_length, | |||
| seq_length, | |||
| dtype=pos_embed.dtype, | |||
| device=pos_embed.device) | |||
| return abs_pos_bias | |||
| pos_q = self.pos_q_linear(pos_embed).view( | |||
| x.size(0), x.size(1), self.num_attention_heads, -1).transpose( | |||
| 1, 2) * self.pos_scaling | |||
| pos_k = self.pos_k_linear(pos_embed).view( | |||
| x.size(0), x.size(1), self.num_attention_heads, | |||
| -1).transpose(1, 2) | |||
| abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) | |||
| abs_pos_bias = build_abs_pos_bias(pos_embed) | |||
| # expand attention_mask | |||
| if has_pads: | |||
| @@ -1334,19 +1348,22 @@ class OFAEncoder(OFAPreTrainedModel): | |||
| if output_hidden_states: | |||
| encoder_states += (x, ) | |||
| self_attn_bias = abs_pos_bias.clone() | |||
| real_idx = 0 if self.share_attn_bias else idx | |||
| self_attn_bias[:, :, -input_ids.size(1):, | |||
| -input_ids.size(1):] += self.get_rel_pos_bias( | |||
| input_ids, idx) | |||
| input_ids, real_idx) | |||
| if patch_images_2 is not None: | |||
| self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ | |||
| self.get_image_rel_pos_bias(image_position_ids_2, idx) | |||
| self.get_image_rel_pos_bias(image_position_ids_2, real_idx) | |||
| self_attn_bias[:, :, | |||
| image_num_patches_2:image_num_patches_2 + image_num_patches, # noqa | |||
| image_num_patches_2:image_num_patches_2 + image_num_patches] += \ | |||
| self.get_image_rel_pos_bias(image_position_ids, idx) # noqa | |||
| self.get_image_rel_pos_bias(image_position_ids, real_idx) # noqa | |||
| elif patch_images is not None: | |||
| self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \ | |||
| self.get_image_rel_pos_bias(image_position_ids, idx) | |||
| self.get_image_rel_pos_bias(image_position_ids, real_idx) | |||
| self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1)) | |||
| hidden_outputs = layer( | |||
| @@ -1398,6 +1415,8 @@ class OFADecoder(OFAPreTrainedModel): | |||
| self._future_mask = torch.empty(0) | |||
| self.share_input_output_embed = config.share_decoder_input_output_embed | |||
| self.num_attention_heads = config.decoder_attention_heads | |||
| self.use_ofasys = config.use_ofasys | |||
| self.disable_entangle = config.disable_entangle | |||
| if embed_tokens is not None: | |||
| self.embed_tokens = embed_tokens | |||
| @@ -1415,18 +1434,31 @@ class OFADecoder(OFAPreTrainedModel): | |||
| else: | |||
| self.layernorm_embedding = None | |||
| if config.use_ofasys: | |||
| if config.add_type_embedding: | |||
| self.type_embedding = Embedding( | |||
| 1, self.embed_dim, padding_idx=None) | |||
| else: | |||
| self.type_embedding = None | |||
| self.window_size = config.code_image_size // 8 | |||
| self.embed_positions = Embedding(self.max_target_positions + 2, | |||
| self.embed_dim) | |||
| self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, | |||
| self.embed_dim) | |||
| self.pos_ln = LayerNorm(self.embed_dim) | |||
| self.image_pos_ln = LayerNorm(self.embed_dim) | |||
| if not config.use_ofasys: | |||
| self.embed_image_positions = Embedding( | |||
| config.image_bucket_size**2 + 1, self.embed_dim) | |||
| if not config.use_ofasys: | |||
| self.pos_ln = LayerNorm(self.embed_dim) | |||
| self.image_pos_ln = LayerNorm(self.embed_dim) | |||
| self.pos_scaling = float(self.embed_dim / self.num_attention_heads | |||
| * config.attn_scale_factor)**-0.5 | |||
| self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
| self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
| if not (config.use_ofasys and config.entangle_position_embedding): | |||
| self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
| self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
| self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
| self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
| @@ -1463,33 +1495,41 @@ class OFADecoder(OFAPreTrainedModel): | |||
| self.token_bucket_size = config.token_bucket_size | |||
| token_num_rel_dis = 2 * config.token_bucket_size - 1 | |||
| token_rp_bucket = make_token_bucket_position(config.token_bucket_size) | |||
| self.share_attn_bias = config.share_attn_bias | |||
| num_rel_pos_tables = 1 if config.share_attn_bias else config.decoder_layers | |||
| self.token_rel_pos_table_list = nn.ModuleList([ | |||
| Embedding( | |||
| token_num_rel_dis, self.num_attention_heads, zero_init=True) | |||
| for _ in range(config.decoder_layers) | |||
| for _ in range(num_rel_pos_tables) | |||
| ]) | |||
| self.image_bucket_size = config.image_bucket_size | |||
| image_num_rel_dis = (2 * config.image_bucket_size | |||
| - 1) * (2 * config.image_bucket_size - 1) + 3 | |||
| image_rp_bucket = make_image_bucket_position(config.image_bucket_size, | |||
| image_num_rel_dis) | |||
| image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ | |||
| torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa | |||
| image_position_idx = torch.cat( | |||
| [torch.tensor([0]), image_position_idx.view(-1)]) | |||
| image_position_idx = torch.cat( | |||
| [image_position_idx, | |||
| torch.tensor([1024] * 768)]) | |||
| self.image_rel_pos_table_list = nn.ModuleList([ | |||
| Embedding( | |||
| image_num_rel_dis, self.num_attention_heads, zero_init=True) | |||
| for _ in range(config.decoder_layers) | |||
| ]) | |||
| if config.use_image_feature: | |||
| if not config.use_ofasys: | |||
| self.image_bucket_size = config.image_bucket_size | |||
| image_num_rel_dis = (2 * config.image_bucket_size - 1) * ( | |||
| 2 * config.image_bucket_size - 1) + 3 | |||
| image_rp_bucket = make_image_bucket_position( | |||
| config.image_bucket_size, image_num_rel_dis) | |||
| image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ | |||
| torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa | |||
| image_position_idx = torch.cat( | |||
| [torch.tensor([0]), | |||
| image_position_idx.view(-1)]) | |||
| image_position_idx = torch.cat( | |||
| [image_position_idx, | |||
| torch.tensor([1024] * 768)]) | |||
| self.register_buffer('image_position_idx', image_position_idx) | |||
| self.image_rel_pos_table_list = nn.ModuleList([ | |||
| Embedding( | |||
| image_num_rel_dis, | |||
| self.num_attention_heads, | |||
| zero_init=True) for _ in range(num_rel_pos_tables) | |||
| ]) | |||
| self.register_buffer('image_rp_bucket', image_rp_bucket) | |||
| self.register_buffer('token_rp_bucket', token_rp_bucket) | |||
| self.register_buffer('image_rp_bucket', image_rp_bucket) | |||
| self.register_buffer('image_position_idx', image_position_idx) | |||
| self.entangle_position_embedding = config.entangle_position_embedding | |||
| self.gradient_checkpointing = False | |||
| @@ -1556,26 +1596,46 @@ class OFADecoder(OFAPreTrainedModel): | |||
| batch_size = tgt_pos_embed.size(0) | |||
| tgt_len = tgt_pos_embed.size(1) | |||
| tgt_pos_embed = self.image_pos_ln( | |||
| tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) | |||
| if not self.use_ofasys: | |||
| tgt_pos_embed = self.image_pos_ln( | |||
| tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) | |||
| if src_pos_embed is not None: | |||
| src_len = src_pos_embed.size(1) | |||
| pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( | |||
| batch_size, tgt_len, self.num_attention_heads, -1).transpose( | |||
| 1, 2) * self.pos_scaling | |||
| pos_k = self.cross_pos_k_linear(src_pos_embed).view( | |||
| batch_size, src_len, self.num_attention_heads, | |||
| -1).transpose(1, 2) | |||
| if not (self.entangle_position_embedding and self.use_ofasys): | |||
| pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( | |||
| batch_size, tgt_len, self.num_attention_heads, | |||
| -1).transpose(1, 2) * self.pos_scaling | |||
| pos_k = self.cross_pos_k_linear(src_pos_embed).view( | |||
| batch_size, src_len, self.num_attention_heads, | |||
| -1).transpose(1, 2) | |||
| abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) | |||
| else: | |||
| abs_pos_bias = torch.zeros( | |||
| batch_size, | |||
| self.num_attention_heads, | |||
| tgt_len, | |||
| src_len, | |||
| dtype=tgt_pos_embed.dtype, | |||
| device=tgt_pos_embed.device) | |||
| else: | |||
| src_len = tgt_pos_embed.size(1) | |||
| pos_q = self.self_pos_q_linear(tgt_pos_embed).view( | |||
| batch_size, tgt_len, self.num_attention_heads, -1).transpose( | |||
| 1, 2) * self.pos_scaling | |||
| pos_k = self.self_pos_k_linear(tgt_pos_embed).view( | |||
| batch_size, src_len, self.num_attention_heads, | |||
| -1).transpose(1, 2) | |||
| abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) | |||
| # batch_size, seq_length = tgt_pos_embed.size(0), tgt_pos_embed.size(1) | |||
| if not (self.entangle_position_embedding and self.use_ofasys): | |||
| pos_q = self.self_pos_q_linear(tgt_pos_embed).view( | |||
| batch_size, tgt_len, self.num_attention_heads, | |||
| -1).transpose(1, 2) * self.pos_scaling | |||
| pos_k = self.self_pos_k_linear(tgt_pos_embed).view( | |||
| batch_size, tgt_len, self.num_attention_heads, | |||
| -1).transpose(1, 2) | |||
| abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) | |||
| else: | |||
| abs_pos_bias = torch.zeros( | |||
| batch_size, | |||
| self.num_attention_heads, | |||
| tgt_len, | |||
| tgt_len, | |||
| dtype=tgt_pos_embed.dtype, | |||
| device=tgt_pos_embed.device) | |||
| return abs_pos_bias | |||
| @@ -1809,17 +1869,18 @@ class OFADecoder(OFAPreTrainedModel): | |||
| past_key_values) > 0 else None | |||
| self_attn_bias = self_abs_pos_bias.clone() | |||
| real_idx = 0 if self.share_attn_bias else idx | |||
| if code_masks is None or not code_masks.any(): | |||
| self_attn_bias += self.get_rel_pos_bias( | |||
| all_prev_output_tokens, idx).unsqueeze(0) | |||
| all_prev_output_tokens, real_idx).unsqueeze(0) | |||
| elif code_masks is not None and code_masks.all(): | |||
| self_attn_bias += self.get_image_rel_pos_bias( | |||
| all_prev_output_tokens, idx).unsqueeze(0) | |||
| all_prev_output_tokens, real_idx).unsqueeze(0) | |||
| else: | |||
| self_attn_bias[~code_masks] += self.get_rel_pos_bias( | |||
| all_prev_output_tokens, idx).unsqueeze(0) | |||
| all_prev_output_tokens, real_idx).unsqueeze(0) | |||
| self_attn_bias[code_masks] += self.get_image_rel_pos_bias( | |||
| all_prev_output_tokens, idx).unsqueeze(0) | |||
| all_prev_output_tokens, real_idx).unsqueeze(0) | |||
| self_attn_bias = self_attn_bias.reshape( | |||
| -1, | |||
| *self_attn_bias.size()[-2:]) | |||
| @@ -1892,6 +1953,7 @@ class OFAModel(OFAPreTrainedModel): | |||
| self.encoder = OFAEncoder(config, shared) | |||
| self.decoder = OFADecoder(config, shared) | |||
| self.use_ofasys = config.use_ofasys | |||
| # Initialize weights and apply final processing | |||
| self.post_init() | |||
| @@ -2,6 +2,7 @@ | |||
| from typing import Optional | |||
| import torch | |||
| import torch.nn as nn | |||
| def expand_mask(mask: torch.Tensor, | |||
| @@ -17,3 +18,42 @@ def expand_mask(mask: torch.Tensor, | |||
| src_len).to(dtype) | |||
| return expanded_mask.masked_fill(expanded_mask.bool(), | |||
| torch.finfo(dtype).min) | |||
| def drop_path(x, drop_prob: float = 0.0, training: bool = False): | |||
| r""" | |||
| Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
| Args: | |||
| x (`nn.Modules`): input nn layers. | |||
| drop_prob (`float`): drop path ratio. | |||
| training (`bool`): whether is training or inference. | |||
| """ | |||
| if drop_prob == 0.0 or not training: | |||
| return x | |||
| keep_prob = 1 - drop_prob | |||
| shape = (1, x.shape[1], 1) | |||
| random_tensor = keep_prob + torch.rand( | |||
| shape, dtype=x.dtype, device=x.device) | |||
| random_tensor.floor_() # binarize | |||
| output = x.div(keep_prob) * random_tensor | |||
| return output | |||
| class DropPath(nn.Module): | |||
| r""" | |||
| Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
| Args: | |||
| drop_prob: drop path ratio. | |||
| """ | |||
| def __init__(self, drop_prob=None): | |||
| super().__init__() | |||
| self.drop_prob = drop_prob | |||
| def forward(self, x): | |||
| return drop_path(x, self.drop_prob, self.training) | |||
| def extra_repr(self) -> str: | |||
| return 'p={}'.format(self.drop_prob) | |||
| @@ -0,0 +1,155 @@ | |||
| from collections import OrderedDict | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from fairseq.modules import LayerNorm | |||
| from torch import nn | |||
| from .utils.utils import DropPath | |||
| __all__ = [ | |||
| 'vit_base', | |||
| 'vit_large', | |||
| 'vit_large_336', | |||
| 'vit_huge', | |||
| ] | |||
| class QuickGELU(nn.Module): | |||
| def forward(self, x: torch.Tensor): | |||
| return x * torch.sigmoid(1.702 * x) | |||
| class ResidualAttentionBlock(nn.Module): | |||
| def __init__(self, | |||
| d_model: int, | |||
| n_head: int, | |||
| attn_mask: torch.Tensor = None, | |||
| drop_path_rate=0.0): | |||
| super().__init__() | |||
| self.attn = nn.MultiheadAttention(d_model, n_head) | |||
| self.ln_1 = LayerNorm(d_model) | |||
| self.mlp = nn.Sequential( | |||
| OrderedDict([ | |||
| ('c_fc', nn.Linear(d_model, d_model * 4)), | |||
| ('gelu', QuickGELU()), | |||
| ('c_proj', nn.Linear(d_model * 4, d_model)), | |||
| ])) | |||
| self.ln_2 = LayerNorm(d_model) | |||
| self.attn_mask = attn_mask | |||
| self.drop_path = DropPath(drop_path_rate) | |||
| def attention(self, x: torch.Tensor): | |||
| self.attn_mask = ( | |||
| self.attn_mask.to(dtype=x.dtype, device=x.device) | |||
| if self.attn_mask is not None else None) | |||
| return self.attn( | |||
| x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] | |||
| def forward(self, x: torch.Tensor): | |||
| x = x + self.drop_path(self.attention(self.ln_1(x))) | |||
| x = x + self.drop_path(self.mlp(self.ln_2(x))) | |||
| return x | |||
| class Transformer(nn.Module): | |||
| def __init__( | |||
| self, | |||
| width: int, | |||
| layers: int, | |||
| heads: int, | |||
| attn_mask: torch.Tensor = None, | |||
| drop_path_rate: float = 0.0, | |||
| ): | |||
| super().__init__() | |||
| self.width = width | |||
| self.layers = layers | |||
| self.resblocks = nn.Sequential(*[ | |||
| ResidualAttentionBlock(width, heads, attn_mask, drop_path_rate) | |||
| for _ in range(layers) | |||
| ]) | |||
| def forward(self, x: torch.Tensor): | |||
| return self.resblocks(x) | |||
| class VisionTransformer(nn.Module): | |||
| def __init__( | |||
| self, | |||
| input_resolution: int, | |||
| patch_size: int, | |||
| width: int, | |||
| layers: int, | |||
| heads: int, | |||
| drop_path_rate: float = 0.0, | |||
| ): | |||
| super().__init__() | |||
| self.input_resolution = input_resolution | |||
| self.patch_size = patch_size | |||
| self.conv1 = nn.Conv2d( | |||
| in_channels=3, | |||
| out_channels=width, | |||
| kernel_size=patch_size, | |||
| stride=patch_size, | |||
| bias=False, | |||
| ) | |||
| scale = width**-0.5 | |||
| self.width = width | |||
| self.positional_embedding = nn.Parameter(scale * torch.randn( | |||
| (input_resolution // patch_size)**2 + 1, width)) | |||
| self.ln_pre = LayerNorm(width) | |||
| self.transformer = Transformer( | |||
| width, layers, heads, drop_path_rate=drop_path_rate) | |||
| def forward(self, x: torch.Tensor): | |||
| resolution = x.shape[-2] | |||
| height, width = x.shape[-2] // self.patch_size, x.shape[ | |||
| -1] // self.patch_size | |||
| x = self.conv1(x) # shape = [*, width, grid, grid] | |||
| x = x.reshape(x.shape[0], x.shape[1], | |||
| -1) # shape = [*, width, grid ** 2] | |||
| x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] | |||
| if resolution != self.input_resolution: | |||
| old_pe = self.positional_embedding[1:] | |||
| patch_num = self.input_resolution // self.patch_size | |||
| old_pe = old_pe.reshape(1, patch_num, patch_num, | |||
| -1).permute(0, 3, 1, 2) | |||
| new_pe = F.interpolate( | |||
| old_pe, size=(height, width), mode='bilinear') | |||
| new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1) | |||
| x = x + new_pe.to(x.dtype) | |||
| else: | |||
| x = x + self.positional_embedding[1:].to(x.dtype) | |||
| x = self.ln_pre(x) | |||
| x = x.permute(1, 0, 2) # NLD -> LND | |||
| x = self.transformer(x) | |||
| x = x.permute(1, 0, 2) # LND -> NLD | |||
| bz, seq, hidden = x.shape | |||
| x = x.transpose(1, 2).reshape(bz, hidden, height, width) | |||
| return x | |||
| def vit_base(drop_path_rate: float = 0.0): | |||
| return VisionTransformer(224, 16, 768, 9, 12, drop_path_rate) | |||
| def vit_large(drop_path_rate: float = 0.0): | |||
| return VisionTransformer(224, 14, 1024, 18, 16, drop_path_rate) | |||
| def vit_large_336(drop_path_rate: float = 0.0): | |||
| return VisionTransformer(336, 14, 1024, 18, 16, drop_path_rate) | |||
| def vit_huge(drop_path_rate: float = 0.0): | |||
| return VisionTransformer(224, 14, 1280, 24, 16, drop_path_rate) | |||
| @@ -1,6 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import math | |||
| import os | |||
| import re | |||
| import string | |||
| from functools import partial | |||
| from os import path as osp | |||
| @@ -53,8 +54,11 @@ class OfaForAllTasks(TorchModel): | |||
| raise NotImplementedError | |||
| # there is some diff between here and our ofa code, | |||
| # there will be no need to use param: use_bpe | |||
| self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) | |||
| self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)]) | |||
| if not model.use_ofasys: | |||
| self.tokenizer.add_tokens( | |||
| ['<code_{}>'.format(i) for i in range(8192)]) | |||
| self.tokenizer.add_tokens( | |||
| ['<bin_{}>'.format(i) for i in range(1000)]) | |||
| self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) | |||
| self.batch_size = self.cfg.model.get('batch_size', 1) | |||
| self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | |||
| @@ -107,6 +111,8 @@ class OfaForAllTasks(TorchModel): | |||
| Tasks.text_classification: inference_d[self.gen_type], | |||
| Tasks.image_classification: inference_d[self.gen_type], | |||
| } | |||
| pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' | |||
| self.pattern = re.compile(pattern_str) | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| input = move_to_device(input, self.model.device) | |||
| @@ -132,8 +138,18 @@ class OfaForAllTasks(TorchModel): | |||
| caption = input[OutputKeys.CAPTION] | |||
| result_l = list() | |||
| for cap in caption: | |||
| result_l.append(cap.translate(self.transtab).strip()) | |||
| if self.language == 'en': | |||
| result_l.append(cap.translate(self.transtab).strip()) | |||
| else: | |||
| result_l.append(cap) | |||
| input[OutputKeys.CAPTION] = result_l | |||
| if self.gen_type == 'generation' and self.language in [ | |||
| 'zh', 'cn' | |||
| ] and self.cfg.task != Tasks.visual_grounding: | |||
| ret_l = list() | |||
| for text in input[OFA_TASK_KEY_MAPPING[self.cfg.task]]: | |||
| ret_l.append(self.detokenizer(text)) | |||
| input[OFA_TASK_KEY_MAPPING[self.cfg.task]] = ret_l | |||
| return input | |||
| def _text_gen_inference(self, input): | |||
| @@ -311,3 +327,6 @@ class OfaForAllTasks(TorchModel): | |||
| save_function=partial(save_function, with_meta=False), | |||
| config=config, | |||
| **kwargs) | |||
| def detokenizer(self, text): | |||
| return self.pattern.sub('', text) | |||
| @@ -1,3 +1,6 @@ | |||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
| import copy | |||
| import logging | |||
| import os | |||
| @@ -1,3 +1,6 @@ | |||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
| import argparse | |||
| import os | |||
| from typing import Any | |||
| @@ -0,0 +1,3 @@ | |||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
| """Unifold Modules.""" | |||
| @@ -274,6 +274,8 @@ class MsDataset: | |||
| try: | |||
| api.on_dataset_download( | |||
| dataset_name=download_dataset, namespace=namespace) | |||
| api.dataset_download_uv( | |||
| dataset_name=download_dataset, namespace=namespace) | |||
| except Exception as e: | |||
| logger.error(e) | |||
| @@ -69,11 +69,23 @@ TASK_OUTPUTS = { | |||
| # face 2d keypoint result for single sample | |||
| # { | |||
| # "keypoints": [ | |||
| # [x1, y1]*106 | |||
| # [[x, y]*106], | |||
| # [[x, y]*106], | |||
| # [[x, y]*106], | |||
| # ], | |||
| # "poses": [pitch, roll, yaw] | |||
| # "poses": [ | |||
| # [pitch, roll, yaw], | |||
| # [pitch, roll, yaw], | |||
| # [pitch, roll, yaw], | |||
| # ], | |||
| # "boxes": [ | |||
| # [x1, y1, x2, y2], | |||
| # [x1, y1, x2, y2], | |||
| # [x1, y1, x2, y2], | |||
| # ] | |||
| # } | |||
| Tasks.face_2d_keypoints: [OutputKeys.KEYPOINTS, OutputKeys.POSES], | |||
| Tasks.face_2d_keypoints: | |||
| [OutputKeys.KEYPOINTS, OutputKeys.POSES, OutputKeys.BOXES], | |||
| # face detection result for single sample | |||
| # { | |||
| @@ -479,17 +491,8 @@ TASK_OUTPUTS = { | |||
| # word segmentation result for single sample | |||
| # { | |||
| # "output": "今天 天气 不错 , 适合 出去 游玩" | |||
| # "labels": [ | |||
| # {'word': '今天', 'label': 'PROPN'}, | |||
| # {'word': '天气', 'label': 'PROPN'}, | |||
| # {'word': '不错', 'label': 'VERB'}, | |||
| # {'word': ',', 'label': 'NUM'}, | |||
| # {'word': '适合', 'label': 'NOUN'}, | |||
| # {'word': '出去', 'label': 'PART'}, | |||
| # {'word': '游玩', 'label': 'ADV'}, | |||
| # ] | |||
| # } | |||
| Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS], | |||
| Tasks.word_segmentation: [OutputKeys.OUTPUT], | |||
| # TODO @wenmeng.zwm support list of result check | |||
| # named entity recognition result for single sample | |||
| @@ -699,8 +702,9 @@ TASK_OUTPUTS = { | |||
| # "text_embedding": np.array with shape [1, D], | |||
| # "caption": "this is an image caption text." | |||
| # } | |||
| Tasks.generative_multi_modal_embedding: | |||
| [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION], | |||
| Tasks.generative_multi_modal_embedding: [ | |||
| OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION | |||
| ], | |||
| # multi-modal similarity result for single sample | |||
| # { | |||
| @@ -10,6 +10,7 @@ from typing import Any, Dict, Generator, List, Mapping, Union | |||
| import numpy as np | |||
| from modelscope.hub.utils.utils import create_library_statistics | |||
| from modelscope.models.base import Model | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.outputs import TASK_OUTPUTS | |||
| @@ -151,7 +152,9 @@ class Pipeline(ABC): | |||
| **kwargs) -> Union[Dict[str, Any], Generator]: | |||
| # model provider should leave it as it is | |||
| # modelscope library developer will handle this function | |||
| for single_model in self.models: | |||
| if hasattr(single_model, 'name'): | |||
| create_library_statistics('pipeline', single_model.name, None) | |||
| # place model to cpu or gpu | |||
| if (self.model or (self.has_multiple_models and self.models[0])): | |||
| if not self._model_prepare: | |||
| @@ -93,9 +93,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/cv_resnet50_live-category'), | |||
| Tasks.video_category: (Pipelines.video_category, | |||
| 'damo/cv_resnet50_video-category'), | |||
| Tasks.multi_modal_embedding: | |||
| (Pipelines.multi_modal_embedding, | |||
| 'damo/multi-modal_clip-vit-large-patch14_zh'), | |||
| Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding, | |||
| 'damo/multi-modal_clip-vit-base-patch16_zh'), | |||
| Tasks.generative_multi_modal_embedding: | |||
| (Pipelines.generative_multi_modal_embedding, | |||
| 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding' | |||
| @@ -132,8 +132,8 @@ class Body3DKeypointsPipeline(Pipeline): | |||
| device='gpu' if torch.cuda.is_available() else 'cpu') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| video_url = input | |||
| video_frames = self.read_video_frames(video_url) | |||
| self.video_url = input | |||
| video_frames = self.read_video_frames(self.video_url) | |||
| if 0 == len(video_frames): | |||
| res = {'success': False, 'msg': 'get video frame failed.'} | |||
| return res | |||
| @@ -198,7 +198,7 @@ class Body3DKeypointsPipeline(Pipeline): | |||
| } | |||
| if not input['success']: | |||
| pass | |||
| res[OutputKeys.OUTPUT_VIDEO] = self.video_url | |||
| else: | |||
| poses = input[KeypointsTypes.POSES_CAMERA] | |||
| pred_3d_pose = poses.data.cpu().numpy()[ | |||
| @@ -1,12 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import copy | |||
| import math | |||
| from typing import Any | |||
| import cv2 | |||
| import numpy as np | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from .base import EasyCVPipeline | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.face_2d_keypoints, module_name=Pipelines.face_2d_keypoints) | |||
| @@ -29,18 +39,206 @@ class Face2DKeypointsPipeline(EasyCVPipeline): | |||
| *args, | |||
| **kwargs) | |||
| # face detect pipeline | |||
| det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||
| self.face_detection = pipeline( | |||
| Tasks.face_detection, model=det_model_id) | |||
| def show_result(self, img, points, scale=2, save_path=None): | |||
| return self.predict_op.show_result(img, points, scale, save_path) | |||
| def _choose_face(self, det_result, min_face=10): | |||
| """ | |||
| choose face with maximum area | |||
| Args: | |||
| det_result: output of face detection pipeline | |||
| min_face: minimum size of valid face w/h | |||
| """ | |||
| bboxes = np.array(det_result[OutputKeys.BOXES]) | |||
| landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) | |||
| if bboxes.shape[0] == 0: | |||
| logger.warn('No face detected!') | |||
| return None | |||
| # face idx with enough size | |||
| face_idx = [] | |||
| for i in range(bboxes.shape[0]): | |||
| box = bboxes[i] | |||
| if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: | |||
| face_idx += [i] | |||
| if len(face_idx) == 0: | |||
| logger.warn( | |||
| f'Face size not enough, less than {min_face}x{min_face}!') | |||
| return None | |||
| bboxes = bboxes[face_idx] | |||
| landmarks = landmarks[face_idx] | |||
| return bboxes, landmarks | |||
| def expend_box(self, box, w, h, scalex=0.3, scaley=0.5): | |||
| x1 = box[0] | |||
| y1 = box[1] | |||
| wb = box[2] - x1 | |||
| hb = box[3] - y1 | |||
| deltax = int(wb * scalex) | |||
| deltay1 = int(hb * scaley) | |||
| deltay2 = int(hb * scalex) | |||
| x1 = x1 - deltax | |||
| y1 = y1 - deltay1 | |||
| if x1 < 0: | |||
| deltax = deltax + x1 | |||
| x1 = 0 | |||
| if y1 < 0: | |||
| deltay1 = deltay1 + y1 | |||
| y1 = 0 | |||
| x2 = x1 + wb + 2 * deltax | |||
| y2 = y1 + hb + deltay1 + deltay2 | |||
| x2 = np.clip(x2, 0, w - 1) | |||
| y2 = np.clip(y2, 0, h - 1) | |||
| return [x1, y1, x2, y2] | |||
| def rotate_point(self, angle, center, landmark): | |||
| rad = angle * np.pi / 180.0 | |||
| alpha = np.cos(rad) | |||
| beta = np.sin(rad) | |||
| M = np.zeros((2, 3), dtype=np.float32) | |||
| M[0, 0] = alpha | |||
| M[0, 1] = beta | |||
| M[0, 2] = (1 - alpha) * center[0] - beta * center[1] | |||
| M[1, 0] = -beta | |||
| M[1, 1] = alpha | |||
| M[1, 2] = beta * center[0] + (1 - alpha) * center[1] | |||
| landmark_ = np.asarray([(M[0, 0] * x + M[0, 1] * y + M[0, 2], | |||
| M[1, 0] * x + M[1, 1] * y + M[1, 2]) | |||
| for (x, y) in landmark]) | |||
| return M, landmark_ | |||
| def rotate_crop_img(self, img, pts, M): | |||
| imgT = cv2.warpAffine(img, M, (int(img.shape[1]), int(img.shape[0]))) | |||
| x1 = pts[5][0] | |||
| x2 = pts[5][0] | |||
| y1 = pts[5][1] | |||
| y2 = pts[5][1] | |||
| for i in range(0, 9): | |||
| x1 = min(x1, pts[i][0]) | |||
| x2 = max(x2, pts[i][0]) | |||
| y1 = min(y1, pts[i][1]) | |||
| y2 = max(y2, pts[i][1]) | |||
| height, width, _ = imgT.shape | |||
| x1 = min(max(0, int(x1)), width) | |||
| y1 = min(max(0, int(y1)), height) | |||
| x2 = min(max(0, int(x2)), width) | |||
| y2 = min(max(0, int(y2)), height) | |||
| sub_imgT = imgT[y1:y2, x1:x2] | |||
| return sub_imgT, imgT, [x1, y1, x2, y2] | |||
| def crop_img(self, imgT, pts): | |||
| enlarge_ratio = 1.1 | |||
| x1 = np.min(pts[:, 0]) | |||
| x2 = np.max(pts[:, 0]) | |||
| y1 = np.min(pts[:, 1]) | |||
| y2 = np.max(pts[:, 1]) | |||
| w = x2 - x1 + 1 | |||
| h = y2 - y1 + 1 | |||
| x1 = int(x1 - (enlarge_ratio - 1.0) / 2.0 * w) | |||
| y1 = int(y1 - (enlarge_ratio - 1.0) / 2.0 * h) | |||
| x1 = max(0, x1) | |||
| y1 = max(0, y1) | |||
| new_w = int(enlarge_ratio * w) | |||
| new_h = int(enlarge_ratio * h) | |||
| new_x1 = x1 | |||
| new_y1 = y1 | |||
| new_x2 = new_x1 + new_w | |||
| new_y2 = new_y1 + new_h | |||
| height, width, _ = imgT.shape | |||
| new_x1 = min(max(0, new_x1), width) | |||
| new_y1 = min(max(0, new_y1), height) | |||
| new_x2 = max(min(width, new_x2), 0) | |||
| new_y2 = max(min(height, new_y2), 0) | |||
| sub_imgT = imgT[new_y1:new_y2, new_x1:new_x2] | |||
| return sub_imgT, [new_x1, new_y1, new_x2, new_y2] | |||
| def __call__(self, inputs) -> Any: | |||
| outputs = self.predict_op(inputs) | |||
| img = LoadImage.convert_to_ndarray(inputs) | |||
| h, w, c = img.shape | |||
| img_rgb = copy.deepcopy(img) | |||
| img_rgb = img_rgb[:, :, ::-1] | |||
| det_result = self.face_detection(img_rgb) | |||
| bboxes = np.array(det_result[OutputKeys.BOXES]) | |||
| if bboxes.shape[0] == 0: | |||
| logger.warn('No face detected!') | |||
| results = { | |||
| OutputKeys.KEYPOINTS: [], | |||
| OutputKeys.POSES: [], | |||
| OutputKeys.BOXES: [] | |||
| } | |||
| return results | |||
| boxes, keypoints = self._choose_face(det_result) | |||
| output_boxes = [] | |||
| output_keypoints = [] | |||
| output_poses = [] | |||
| for index, box_ori in enumerate(boxes): | |||
| box = self.expend_box(box_ori, w, h, scalex=0.1, scaley=0.1) | |||
| y0 = int(box[1]) | |||
| y1 = int(box[3]) | |||
| x0 = int(box[0]) | |||
| x1 = int(box[2]) | |||
| sub_img = img[y0:y1, x0:x1] | |||
| keypoint = keypoints[index] | |||
| pts = [[keypoint[0], keypoint[1]], [keypoint[2], keypoint[3]], | |||
| [keypoint[4], keypoint[5]], [keypoint[6], keypoint[7]], | |||
| [keypoint[8], keypoint[9]], [box[0], box[1]], | |||
| [box[2], box[1]], [box[0], box[3]], [box[2], box[3]]] | |||
| # radian | |||
| angle = math.atan2((pts[1][1] - pts[0][1]), | |||
| (pts[1][0] - pts[0][0])) | |||
| # angle | |||
| theta = angle * (180 / np.pi) | |||
| center = [w // 2, h // 2] | |||
| cx, cy = center | |||
| M, landmark_ = self.rotate_point(theta, (cx, cy), pts) | |||
| sub_imgT, imgT, bbox = self.rotate_crop_img(img, landmark_, M) | |||
| outputs = self.predict_op([sub_imgT])[0] | |||
| tmp_keypoints = outputs['point'] | |||
| for idx in range(0, len(tmp_keypoints)): | |||
| tmp_keypoints[idx][0] += bbox[0] | |||
| tmp_keypoints[idx][1] += bbox[1] | |||
| for idx in range(0, 6): | |||
| sub_img, bbox = self.crop_img(imgT, tmp_keypoints) | |||
| outputs = self.predict_op([sub_img])[0] | |||
| tmp_keypoints = outputs['point'] | |||
| for idx in range(0, len(tmp_keypoints)): | |||
| tmp_keypoints[idx][0] += bbox[0] | |||
| tmp_keypoints[idx][1] += bbox[1] | |||
| M2, tmp_keypoints = self.rotate_point(-theta, (cx, cy), | |||
| tmp_keypoints) | |||
| results = [{ | |||
| OutputKeys.KEYPOINTS: output['point'], | |||
| OutputKeys.POSES: output['pose'] | |||
| } for output in outputs] | |||
| output_keypoints.append(np.array(tmp_keypoints)) | |||
| output_poses.append(np.array(outputs['pose'])) | |||
| output_boxes.append(np.array(box_ori)) | |||
| if self._is_single_inputs(inputs): | |||
| results = results[0] | |||
| results = { | |||
| OutputKeys.KEYPOINTS: output_keypoints, | |||
| OutputKeys.POSES: output_poses, | |||
| OutputKeys.BOXES: output_boxes | |||
| } | |||
| return results | |||
| @@ -109,13 +109,13 @@ class TokenClassificationPipeline(Pipeline): | |||
| chunk['span'] = text[chunk['start']:chunk['end']] | |||
| chunks.append(chunk) | |||
| # for cws output | |||
| # for cws outputs | |||
| if len(chunks) > 0 and chunks[0]['type'] == 'cws': | |||
| spans = [ | |||
| chunk['span'] for chunk in chunks if chunk['span'].strip() | |||
| ] | |||
| seg_result = ' '.join(spans) | |||
| outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} | |||
| outputs = {OutputKeys.OUTPUT: seg_result} | |||
| # for ner outputs | |||
| else: | |||
| @@ -115,15 +115,15 @@ class WordSegmentationPipeline(Pipeline): | |||
| chunk['span'] = text[chunk['start']:chunk['end']] | |||
| chunks.append(chunk) | |||
| # for cws output | |||
| # for cws outputs | |||
| if len(chunks) > 0 and chunks[0]['type'] == 'cws': | |||
| spans = [ | |||
| chunk['span'] for chunk in chunks if chunk['span'].strip() | |||
| ] | |||
| seg_result = ' '.join(spans) | |||
| outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} | |||
| outputs = {OutputKeys.OUTPUT: seg_result} | |||
| # for ner outpus | |||
| # for ner output | |||
| else: | |||
| outputs = {OutputKeys.OUTPUT: chunks} | |||
| return outputs | |||
| @@ -77,7 +77,7 @@ class OfaPreprocessor(Preprocessor): | |||
| data[key] = item | |||
| return data | |||
| def _ofa_input_compatibility_conversion(self, data): | |||
| def _ofa_input_compatibility_conversion(self, data): # fake | |||
| if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | |||
| if isinstance(data['image'], str): | |||
| image = load_image(data['image']) | |||
| @@ -34,6 +34,7 @@ class NLPBasePreprocessor(Preprocessor, ABC): | |||
| label=None, | |||
| label2id=None, | |||
| mode=ModeKeys.INFERENCE, | |||
| use_fast=None, | |||
| **kwargs): | |||
| """The NLP preprocessor base class. | |||
| @@ -45,14 +46,18 @@ class NLPBasePreprocessor(Preprocessor, ABC): | |||
| label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping | |||
| if this mapping is not supplied. | |||
| mode: Run this preprocessor in either 'train'/'eval'/'inference' mode | |||
| use_fast: use the fast version of tokenizer | |||
| """ | |||
| self.model_dir = model_dir | |||
| self.first_sequence = first_sequence | |||
| self.second_sequence = second_sequence | |||
| self.label = label | |||
| self.use_fast = kwargs.pop('use_fast', None) | |||
| if self.use_fast is None and os.path.isfile( | |||
| self.use_fast = use_fast | |||
| if self.use_fast is None and model_dir is None: | |||
| self.use_fast = False | |||
| elif self.use_fast is None and os.path.isfile( | |||
| os.path.join(model_dir, 'tokenizer_config.json')): | |||
| with open(os.path.join(model_dir, 'tokenizer_config.json'), | |||
| 'r') as f: | |||
| @@ -61,8 +66,8 @@ class NLPBasePreprocessor(Preprocessor, ABC): | |||
| self.use_fast = False if self.use_fast is None else self.use_fast | |||
| self.label2id = label2id | |||
| if self.label2id is None: | |||
| self.label2id = parse_label_mapping(self.model_dir) | |||
| if self.label2id is None and model_dir is not None: | |||
| self.label2id = parse_label_mapping(model_dir) | |||
| super().__init__(mode, **kwargs) | |||
| @property | |||
| @@ -106,6 +111,7 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor): | |||
| label: str = 'label', | |||
| label2id: dict = None, | |||
| mode: str = ModeKeys.INFERENCE, | |||
| use_fast: bool = None, | |||
| **kwargs): | |||
| """The NLP tokenizer preprocessor base class. | |||
| @@ -122,11 +128,12 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor): | |||
| - config.json label2id/id2label | |||
| - label_mapping.json | |||
| mode: Run this preprocessor in either 'train'/'eval'/'inference' mode, the behavior may be different. | |||
| use_fast: use the fast version of tokenizer | |||
| kwargs: These kwargs will be directly fed into the tokenizer. | |||
| """ | |||
| super().__init__(model_dir, first_sequence, second_sequence, label, | |||
| label2id, mode) | |||
| label2id, mode, use_fast, **kwargs) | |||
| self.model_dir = model_dir | |||
| self.tokenize_kwargs = kwargs | |||
| self.tokenizer = self.build_tokenizer(model_dir) | |||
| @@ -2,6 +2,7 @@ | |||
| from typing import Any, Dict, Tuple, Union | |||
| import numpy as np | |||
| import torch | |||
| from modelscope.metainfo import Preprocessors | |||
| @@ -20,9 +21,7 @@ class WordSegmentationBlankSetToLabelPreprocessor(NLPBasePreprocessor): | |||
| """ | |||
| def __init__(self, **kwargs): | |||
| super().__init__(**kwargs) | |||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||
| 'first_sequence') | |||
| self.first_sequence: str = kwargs.pop('first_sequence', 'tokens') | |||
| self.label = kwargs.pop('label', OutputKeys.LABELS) | |||
| def __call__(self, data: str) -> Union[Dict[str, Any], Tuple]: | |||
| @@ -80,10 +79,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| 'is_split_into_words', False) | |||
| if 'label2id' in kwargs: | |||
| kwargs.pop('label2id') | |||
| self.tokenize_kwargs = kwargs | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| @type_assert(object, (str, dict)) | |||
| def __call__(self, data: Union[dict, str]) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| @@ -99,18 +97,24 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| text = None | |||
| labels_list = None | |||
| if isinstance(data, str): | |||
| # for inference inputs without label | |||
| text = data | |||
| self.tokenize_kwargs['add_special_tokens'] = False | |||
| elif isinstance(data, dict): | |||
| # for finetune inputs with label | |||
| text = data.get(self.first_sequence) | |||
| labels_list = data.get(self.label) | |||
| if isinstance(text, list): | |||
| self.tokenize_kwargs['is_split_into_words'] = True | |||
| input_ids = [] | |||
| label_mask = [] | |||
| offset_mapping = [] | |||
| if self.is_split_into_words: | |||
| for offset, token in enumerate(list(data)): | |||
| subtoken_ids = self.tokenizer.encode( | |||
| token, add_special_tokens=False) | |||
| token_type_ids = [] | |||
| if self.is_split_into_words and self._mode == ModeKeys.INFERENCE: | |||
| for offset, token in enumerate(list(text)): | |||
| subtoken_ids = self.tokenizer.encode(token, | |||
| **self.tokenize_kwargs) | |||
| if len(subtoken_ids) == 0: | |||
| subtoken_ids = [self.tokenizer.unk_token_id] | |||
| input_ids.extend(subtoken_ids) | |||
| @@ -119,10 +123,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| else: | |||
| if self.tokenizer.is_fast: | |||
| encodings = self.tokenizer( | |||
| text, | |||
| add_special_tokens=False, | |||
| return_offsets_mapping=True, | |||
| **self.tokenize_kwargs) | |||
| text, return_offsets_mapping=True, **self.tokenize_kwargs) | |||
| attention_mask = encodings['attention_mask'] | |||
| token_type_ids = encodings['token_type_ids'] | |||
| input_ids = encodings['input_ids'] | |||
| word_ids = encodings.word_ids() | |||
| for i in range(len(word_ids)): | |||
| @@ -137,75 +140,85 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| label_mask.append(1) | |||
| offset_mapping.append(encodings['offset_mapping'][i]) | |||
| else: | |||
| encodings = self.tokenizer( | |||
| text, add_special_tokens=False, **self.tokenize_kwargs) | |||
| encodings = self.tokenizer(text, **self.tokenize_kwargs) | |||
| input_ids = encodings['input_ids'] | |||
| label_mask, offset_mapping = self.get_label_mask_and_offset_mapping( | |||
| text) | |||
| if len(input_ids) >= self.sequence_length - 2: | |||
| input_ids = input_ids[:self.sequence_length - 2] | |||
| label_mask = label_mask[:self.sequence_length - 2] | |||
| input_ids = [self.tokenizer.cls_token_id | |||
| ] + input_ids + [self.tokenizer.sep_token_id] | |||
| label_mask = [0] + label_mask + [0] | |||
| attention_mask = [1] * len(input_ids) | |||
| offset_mapping = offset_mapping[:sum(label_mask)] | |||
| if self._mode == ModeKeys.INFERENCE: | |||
| if len(input_ids) >= self.sequence_length - 2: | |||
| input_ids = input_ids[:self.sequence_length - 2] | |||
| label_mask = label_mask[:self.sequence_length - 2] | |||
| input_ids = [self.tokenizer.cls_token_id | |||
| ] + input_ids + [self.tokenizer.sep_token_id] | |||
| label_mask = [0] + label_mask + [0] | |||
| attention_mask = [1] * len(input_ids) | |||
| offset_mapping = offset_mapping[:sum(label_mask)] | |||
| if not self.is_transformer_based_model: | |||
| input_ids = input_ids[1:-1] | |||
| attention_mask = attention_mask[1:-1] | |||
| label_mask = label_mask[1:-1] | |||
| if not self.is_transformer_based_model: | |||
| input_ids = input_ids[1:-1] | |||
| attention_mask = attention_mask[1:-1] | |||
| label_mask = label_mask[1:-1] | |||
| if self._mode == ModeKeys.INFERENCE: | |||
| input_ids = torch.tensor(input_ids).unsqueeze(0) | |||
| attention_mask = torch.tensor(attention_mask).unsqueeze(0) | |||
| label_mask = torch.tensor( | |||
| label_mask, dtype=torch.bool).unsqueeze(0) | |||
| # the token classification | |||
| output = { | |||
| 'text': text, | |||
| 'input_ids': input_ids, | |||
| 'attention_mask': attention_mask, | |||
| 'label_mask': label_mask, | |||
| 'offset_mapping': offset_mapping | |||
| } | |||
| # align the labels with tokenized text | |||
| if labels_list is not None: | |||
| assert self.label2id is not None | |||
| # Map that sends B-Xxx label to its I-Xxx counterpart | |||
| b_to_i_label = [] | |||
| label_enumerate_values = [ | |||
| k for k, v in sorted( | |||
| self.label2id.items(), key=lambda item: item[1]) | |||
| ] | |||
| for idx, label in enumerate(label_enumerate_values): | |||
| if label.startswith('B-') and label.replace( | |||
| 'B-', 'I-') in label_enumerate_values: | |||
| b_to_i_label.append( | |||
| label_enumerate_values.index( | |||
| label.replace('B-', 'I-'))) | |||
| else: | |||
| b_to_i_label.append(idx) | |||
| # the token classification | |||
| output = { | |||
| 'text': text, | |||
| 'input_ids': input_ids, | |||
| 'attention_mask': attention_mask, | |||
| 'label_mask': label_mask, | |||
| 'offset_mapping': offset_mapping | |||
| } | |||
| else: | |||
| output = { | |||
| 'input_ids': input_ids, | |||
| 'token_type_ids': token_type_ids, | |||
| 'attention_mask': attention_mask, | |||
| 'label_mask': label_mask, | |||
| } | |||
| label_row = [self.label2id[lb] for lb in labels_list] | |||
| previous_word_idx = None | |||
| label_ids = [] | |||
| for word_idx in word_ids: | |||
| if word_idx is None: | |||
| label_ids.append(-100) | |||
| elif word_idx != previous_word_idx: | |||
| label_ids.append(label_row[word_idx]) | |||
| else: | |||
| if self.label_all_tokens: | |||
| label_ids.append(b_to_i_label[label_row[word_idx]]) | |||
| # align the labels with tokenized text | |||
| if labels_list is not None: | |||
| assert self.label2id is not None | |||
| # Map that sends B-Xxx label to its I-Xxx counterpart | |||
| b_to_i_label = [] | |||
| label_enumerate_values = [ | |||
| k for k, v in sorted( | |||
| self.label2id.items(), key=lambda item: item[1]) | |||
| ] | |||
| for idx, label in enumerate(label_enumerate_values): | |||
| if label.startswith('B-') and label.replace( | |||
| 'B-', 'I-') in label_enumerate_values: | |||
| b_to_i_label.append( | |||
| label_enumerate_values.index( | |||
| label.replace('B-', 'I-'))) | |||
| else: | |||
| b_to_i_label.append(idx) | |||
| label_row = [self.label2id[lb] for lb in labels_list] | |||
| previous_word_idx = None | |||
| label_ids = [] | |||
| for word_idx in word_ids: | |||
| if word_idx is None: | |||
| label_ids.append(-100) | |||
| previous_word_idx = word_idx | |||
| labels = label_ids | |||
| output['labels'] = labels | |||
| elif word_idx != previous_word_idx: | |||
| label_ids.append(label_row[word_idx]) | |||
| else: | |||
| if self.label_all_tokens: | |||
| label_ids.append(b_to_i_label[label_row[word_idx]]) | |||
| else: | |||
| label_ids.append(-100) | |||
| previous_word_idx = word_idx | |||
| labels = label_ids | |||
| output['labels'] = labels | |||
| output = { | |||
| k: np.array(v) if isinstance(v, list) else v | |||
| for k, v in output.items() | |||
| } | |||
| return output | |||
| def get_tokenizer_class(self): | |||
| @@ -74,8 +74,8 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): | |||
| self.patch_resize_transform = transforms.Compose([ | |||
| lambda image: ocr_resize( | |||
| image, | |||
| self.cfg.model.patch_image_size, | |||
| is_document=self.cfg.model.is_document), | |||
| self.patch_image_size, | |||
| is_document=self.cfg.model.get('is_document', False)), | |||
| transforms.ToTensor(), | |||
| transforms.Normalize(mean=self.mean, std=self.std), | |||
| ]) | |||
| @@ -69,11 +69,14 @@ class KWSFarfieldTrainer(BaseTrainer): | |||
| super().__init__(cfg_file, arg_parse_fn) | |||
| self.model = self.build_model() | |||
| self.work_dir = work_dir | |||
| # the number of model output dimension | |||
| # should update config outside the trainer, if user need more wake word | |||
| num_syn = kwargs.get('num_syn', None) | |||
| if num_syn: | |||
| self.cfg.model.num_syn = num_syn | |||
| self._num_classes = self.cfg.model.num_syn | |||
| self.model = self.build_model() | |||
| self.work_dir = work_dir | |||
| if kwargs.get('launcher', None) is not None: | |||
| init_dist(kwargs['launcher']) | |||
| @@ -103,20 +103,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
| def __init__(self, args): | |||
| super().__init__() | |||
| self.sentence_avg = args.sentence_avg | |||
| self.eps = args.label_smoothing | |||
| self.ignore_prefix_size = args.ignore_prefix_size | |||
| self.ignore_eos = args.ignore_eos | |||
| self.report_accuracy = args.report_accuracy | |||
| self.drop_worst_ratio = args.drop_worst_ratio | |||
| self.drop_worst_after = args.drop_worst_after | |||
| self.use_rdrop = args.use_rdrop | |||
| self.reg_alpha = args.reg_alpha | |||
| self.sample_patch_num = args.sample_patch_num | |||
| self.sentence_avg = args.get('sentence_avg', False) | |||
| self.eps = args.get('label_smoothing', 0.1) | |||
| self.ignore_prefix_size = args.get('ignore_prefix_size', 0) | |||
| self.ignore_eos = args.get('ignore_eos', False) | |||
| self.report_accuracy = args.get('report_accuracy', False) | |||
| self.drop_worst_ratio = args.get('drop_worst_ratio', 0.0) | |||
| self.drop_worst_after = args.get('drop_worst_after', 0) | |||
| self.use_rdrop = args.get('use_rdrop', False) | |||
| self.reg_alpha = args.get('reg_alpha', 1.0) | |||
| self.sample_patch_num = args.get('sample_patch_num', 196) | |||
| self.constraint_start = None | |||
| self.constraint_end = None | |||
| if args.constraint_range: | |||
| if args.get('constraint_range', None): | |||
| constraint_start, constraint_end = args.constraint_range.split(',') | |||
| self.constraint_start = int(constraint_start) | |||
| self.constraint_end = int(constraint_end) | |||
| @@ -18,7 +18,7 @@ class TextGenerationTrainer(NlpEpochBasedTrainer): | |||
| return tokenizer.decode(tokens.tolist(), skip_special_tokens=True) | |||
| def evaluation_step(self, data): | |||
| model = self.model | |||
| model = self.model.module if self._dist else self.model | |||
| model.eval() | |||
| with torch.no_grad(): | |||
| @@ -586,14 +586,16 @@ class NlpEpochBasedTrainer(EpochBasedTrainer): | |||
| preprocessor_mode=ModeKeys.TRAIN, | |||
| **model_args, | |||
| **self.train_keys, | |||
| mode=ModeKeys.TRAIN) | |||
| mode=ModeKeys.TRAIN, | |||
| use_fast=True) | |||
| eval_preprocessor = Preprocessor.from_pretrained( | |||
| self.model_dir, | |||
| cfg_dict=self.cfg, | |||
| preprocessor_mode=ModeKeys.EVAL, | |||
| **model_args, | |||
| **self.eval_keys, | |||
| mode=ModeKeys.EVAL) | |||
| mode=ModeKeys.EVAL, | |||
| use_fast=True) | |||
| return train_preprocessor, eval_preprocessor | |||
| @@ -15,6 +15,7 @@ from torch.utils.data.dataloader import default_collate | |||
| from torch.utils.data.distributed import DistributedSampler | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.hub.utils.utils import create_library_statistics | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.metrics import build_metric, task_default_metrics | |||
| from modelscope.models.base import Model, TorchModel | |||
| @@ -436,6 +437,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
| def train(self, checkpoint_path=None, *args, **kwargs): | |||
| self._mode = ModeKeys.TRAIN | |||
| if hasattr(self.model, 'name'): | |||
| create_library_statistics('train', self.model.name, None) | |||
| if self.train_dataset is None: | |||
| self.train_dataloader = self.get_train_dataloader() | |||
| @@ -456,6 +459,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
| self.train_loop(self.train_dataloader) | |||
| def evaluate(self, checkpoint_path=None): | |||
| if hasattr(self.model, 'name'): | |||
| create_library_statistics('evaluate', self.model.name, None) | |||
| if checkpoint_path is not None and os.path.isfile(checkpoint_path): | |||
| from modelscope.trainers.hooks import CheckpointHook | |||
| CheckpointHook.load_checkpoint(checkpoint_path, self) | |||
| @@ -876,7 +881,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
| Subclass and override to inject custom behavior. | |||
| """ | |||
| model = self.model | |||
| model = self.model.module if self._dist else self.model | |||
| model.eval() | |||
| if is_parallel(model): | |||
| @@ -238,6 +238,14 @@ class DownloadMode(enum.Enum): | |||
| FORCE_REDOWNLOAD = 'force_redownload' | |||
| class DownloadChannel(enum.Enum): | |||
| """ Channels of datasets downloading for uv/pv counting. | |||
| """ | |||
| LOCAL = 'local' | |||
| DSW = 'dsw' | |||
| EAIS = 'eais' | |||
| class UploadMode(enum.Enum): | |||
| """ How to upload object to remote. | |||
| """ | |||
| @@ -91,6 +91,71 @@ def draw_keypoints(output, original_image): | |||
| return image | |||
| def draw_106face_keypoints(in_path, | |||
| keypoints, | |||
| boxes, | |||
| scale=4.0, | |||
| save_path=None): | |||
| face_contour_point_index = [ | |||
| 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, | |||
| 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 | |||
| ] | |||
| left_eye_brow_point_index = [33, 34, 35, 36, 37, 38, 39, 40, 41, 33] | |||
| right_eye_brow_point_index = [42, 43, 44, 45, 46, 47, 48, 49, 50, 42] | |||
| left_eye_point_index = [66, 67, 68, 69, 70, 71, 72, 73, 66] | |||
| right_eye_point_index = [75, 76, 77, 78, 79, 80, 81, 82, 75] | |||
| nose_bridge_point_index = [51, 52, 53, 54] | |||
| nose_contour_point_index = [55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65] | |||
| mouth_outer_point_index = [ | |||
| 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 84 | |||
| ] | |||
| mouth_inter_point_index = [96, 97, 98, 99, 100, 101, 102, 103, 96] | |||
| img = cv2.imread(in_path) | |||
| for i in range(len(boxes)): | |||
| draw_box(img, np.array(boxes[i])) | |||
| image = cv2.resize(img, dsize=None, fx=scale, fy=scale) | |||
| def draw_line(point_index, image, point): | |||
| for i in range(len(point_index) - 1): | |||
| cur_index = point_index[i] | |||
| next_index = point_index[i + 1] | |||
| cur_pt = (int(point[cur_index][0] * scale), | |||
| int(point[cur_index][1] * scale)) | |||
| next_pt = (int(point[next_index][0] * scale), | |||
| int(point[next_index][1] * scale)) | |||
| cv2.line(image, cur_pt, next_pt, (0, 0, 255), thickness=2) | |||
| for i in range(len(keypoints)): | |||
| points = keypoints[i] | |||
| draw_line(face_contour_point_index, image, points) | |||
| draw_line(left_eye_brow_point_index, image, points) | |||
| draw_line(right_eye_brow_point_index, image, points) | |||
| draw_line(left_eye_point_index, image, points) | |||
| draw_line(right_eye_point_index, image, points) | |||
| draw_line(nose_bridge_point_index, image, points) | |||
| draw_line(nose_contour_point_index, image, points) | |||
| draw_line(mouth_outer_point_index, image, points) | |||
| draw_line(mouth_inter_point_index, image, points) | |||
| size = len(points) | |||
| for i in range(size): | |||
| x = int(points[i][0]) | |||
| y = int(points[i][1]) | |||
| cv2.putText(image, str(i), (int(x * scale), int(y * scale)), | |||
| cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) | |||
| cv2.circle(image, (int(x * scale), int(y * scale)), 2, (0, 255, 0), | |||
| cv2.FILLED) | |||
| if save_path is not None: | |||
| cv2.imwrite(save_path, image) | |||
| return image | |||
| def draw_face_detection_no_lm_result(img_path, detection_result): | |||
| bboxes = np.array(detection_result[OutputKeys.BOXES]) | |||
| scores = np.array(detection_result[OutputKeys.SCORES]) | |||
| @@ -1,6 +1,7 @@ | |||
| addict | |||
| attrs | |||
| datasets | |||
| # version beyond 2.5.2 introduces compatbility issue and is being resolved | |||
| datasets<=2.5.2 | |||
| easydict | |||
| einops | |||
| filelock>=3.3.0 | |||
| @@ -2,6 +2,8 @@ ftfy>=6.0.3 | |||
| ofa>=0.0.2 | |||
| pycocoevalcap>=1.2 | |||
| pycocotools>=2.0.4 | |||
| # compatible with taming-transformers-rom1504 | |||
| pytorch_lightning<=1.7.7 | |||
| # rough-score was just recently updated from 0.0.4 to 0.0.7 | |||
| # which introduced compatability issues that are being investigated | |||
| rouge_score<=0.0.4 | |||
| @@ -1,4 +1,6 @@ | |||
| biopython | |||
| iopath | |||
| ipdb | |||
| lmdb | |||
| ml_collections | |||
| scipy | |||
| @@ -8,7 +8,8 @@ import zipfile | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.msdatasets.utils.dataset_utils import list_dataset_objects | |||
| from modelscope.utils import logger as logging | |||
| from modelscope.utils.constant import DEFAULT_DATASET_REVISION, ModelFile | |||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DownloadMode, | |||
| ModelFile) | |||
| from modelscope.utils.test_utils import test_level | |||
| logger = logging.get_logger(__name__) | |||
| @@ -104,7 +105,10 @@ class DatasetUploadTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_ds_download_dir(self): | |||
| test_ds = MsDataset.load(self.dataset_name, self.namespace) | |||
| test_ds = MsDataset.load( | |||
| self.dataset_name, | |||
| namespace=self.namespace, | |||
| download_mode=DownloadMode.FORCE_REDOWNLOAD) | |||
| assert test_ds.config_kwargs['split_config'].values() | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| @@ -21,9 +21,10 @@ class TestModelOutput(unittest.TestCase): | |||
| self.assertEqual(outputs['logits'], torch.Tensor([1])) | |||
| self.assertEqual(outputs[0], torch.Tensor([1])) | |||
| self.assertEqual(outputs.logits, torch.Tensor([1])) | |||
| outputs.loss = torch.Tensor([2]) | |||
| logits, loss = outputs | |||
| self.assertEqual(logits, torch.Tensor([1])) | |||
| self.assertTrue(loss is None) | |||
| self.assertTrue(loss is not None) | |||
| if __name__ == '__main__': | |||
| @@ -1,11 +1,10 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| import cv2 | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.cv.image_utils import draw_106face_keypoints | |||
| from modelscope.utils.test_utils import test_level | |||
| @@ -13,7 +12,7 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_face_2d_keypoints(self): | |||
| img_path = 'data/test/images/keypoints_detect/test_img_face_2d_keypoints.png' | |||
| img_path = 'data/test/images/face_detection.png' | |||
| model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment' | |||
| face_2d_keypoints_align = pipeline( | |||
| @@ -21,15 +20,21 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): | |||
| output = face_2d_keypoints_align(img_path) | |||
| output_keypoints = output[OutputKeys.KEYPOINTS] | |||
| output_pose = output[OutputKeys.POSES] | |||
| img = cv2.imread(img_path) | |||
| img = face_2d_keypoints_align.show_result( | |||
| img, output_keypoints, scale=2, save_path='face_keypoints.jpg') | |||
| self.assertEqual(output_keypoints.shape[0], 106) | |||
| self.assertEqual(output_keypoints.shape[1], 2) | |||
| self.assertEqual(output_pose.shape[0], 3) | |||
| output_poses = output[OutputKeys.POSES] | |||
| output_boxes = output[OutputKeys.BOXES] | |||
| draw_106face_keypoints( | |||
| img_path, | |||
| output_keypoints, | |||
| output_boxes, | |||
| scale=2, | |||
| save_path='face_keypoints.jpg') | |||
| for idx in range(len(output_keypoints)): | |||
| self.assertEqual(output_keypoints[idx].shape[0], 106) | |||
| self.assertEqual(output_keypoints[idx].shape[1], 2) | |||
| self.assertEqual(output_poses[idx].shape[0], 3) | |||
| self.assertEqual(output_boxes[idx].shape[0], 4) | |||
| if __name__ == '__main__': | |||
| @@ -19,9 +19,11 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| self.task = Tasks.named_entity_recognition | |||
| self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' | |||
| english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom' | |||
| tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' | |||
| lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news' | |||
| sentence = '这与温岭市新河镇的一个神秘的传说有关。' | |||
| sentence_en = 'pizza shovel' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_tcrf_by_direct_model_download(self): | |||
| @@ -89,6 +91,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| task=Tasks.named_entity_recognition, model=self.lcrf_model_id) | |||
| print(pipeline_ins(input=self.sentence)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_english_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.named_entity_recognition, model=self.english_model_id) | |||
| print(pipeline_ins(input='pizza shovel')) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.named_entity_recognition) | |||
| @@ -19,7 +19,7 @@ class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \ | |||
| 'NIAALKNHIDKIKPIAMQIYKKYSKNIP' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| model_dir = snapshot_download(self.model_id) | |||
| mono_pipeline_ins = pipeline(task=self.task, model=model_dir) | |||
| @@ -87,7 +87,7 @@ class TestFinetuneTokenClassification(unittest.TestCase): | |||
| cfg['dataset'] = { | |||
| 'train': { | |||
| 'labels': label_enumerate_values, | |||
| 'first_sequence': 'first_sequence', | |||
| 'first_sequence': 'tokens', | |||
| 'label': 'labels', | |||
| } | |||
| } | |||