diff --git a/.dev_scripts/dockerci.sh b/.dev_scripts/dockerci.sh index c502175b..07ea947a 100644 --- a/.dev_scripts/dockerci.sh +++ b/.dev_scripts/dockerci.sh @@ -37,6 +37,7 @@ do -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ -e TEST_LEVEL=$TEST_LEVEL \ + -e MODELSCOPE_ENVIRONMENT='ci' \ -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ -e MODEL_TAG_URL=$MODEL_TAG_URL \ --workdir=$CODE_DIR_IN_CONTAINER \ @@ -59,6 +60,7 @@ do -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ -e TEST_LEVEL=$TEST_LEVEL \ + -e MODELSCOPE_ENVIRONMENT='ci' \ -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ -e MODEL_TAG_URL=$MODEL_TAG_URL \ --workdir=$CODE_DIR_IN_CONTAINER \ diff --git a/README.md b/README.md index 944c1f07..1da48ef2 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,26 @@ # Introduction -ModelScope library is targeted to support training, evaluation and inference for the state of the art models provided by Mind and further support third-party models provided by users outside alibaba. +[ModelScope](https://www.modelscope.cn) is a “Model-as-a-Service” (MaaS) platform that seeks to bring together the most advanced machine learning models from the AI community, and to streamline the process of leveraging and applying AI models. The core ModelScope library enables developers to perform model inference, training and evaluation through rich layers of API design that facilitate a unified experience across state-of-the-art models from different AI domains. -# Design doc +The Python library offers the layered APIs necessary for model contributors to integrate models from CV, NLP, Speech, Multi-Modality, as well as Scientific-computation, into the ModelScope ecosystem. Implementations for all these different models are encapsulated within the library in a way that allows easy and unified access. With such integration, model inference, finetuning, and evaluation can be done with only a few lines of code. At the same time, flexibility is provided so that different components of a model application can be customized where necessary. -Please refer to alidoc [link](https://alidocs.dingtalk.com/i/nodes/OBldywvrKxo89xmAO05yJQk2ngpNbLz4?nav=spaces&navQuery=spaceId%3Dnb9XJNlZxbgrOXyA&iframeQuery=utm_source%3Dportal%26utm_medium%3Dportal_space_file_tree) +Apart from harboring implementations of various models, the ModelScope library also enables the necessary interactions with the backend services of ModelScope, particularly with the Model-Hub and Dataset-Hub. Such interactions allow the management of various entities (models and datasets) to be performed seamlessly under the hood, including entity lookup, version control, and cache management. -# Development doc +# Installation -Please refer to [develop.md](docs/source/develop.md) +Please refer to [installation](https://modelscope.cn/docs/%E7%8E%AF%E5%A2%83%E5%AE%89%E8%A3%85). -# ChangeLog -* 20/05/2022 First release version +# Get Started -Refer to [change_log.md](docs/source/change_log.md) for more details +You can refer to [quick_start](https://modelscope.cn/docs/%E5%BF%AB%E9%80%9F%E5%BC%80%E5%A7%8B) for a quick start.
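+
+For example, the pipeline API described above can run model inference in just a few lines. The snippet below is a minimal sketch: the model id is illustrative and may be replaced with any model hosted on the ModelScope hub that supports the chosen task.
+
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+# Build an inference pipeline for Chinese word segmentation.
+# The model id is an example; any compatible hub model can be substituted.
+word_segmentation = pipeline(
+    Tasks.word_segmentation,
+    model='damo/nlp_structbert_word-segmentation_chinese-base')
+
+print(word_segmentation('今天天气不错，适合出去游玩'))
+# expected form: {'output': '今天 天气 不错 ， 适合 出去 游玩'}
+```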
+ +We also provide other documentation, including: +* [Introduction to tasks](https://modelscope.cn/docs/%E4%BB%BB%E5%8A%A1%E7%9A%84%E4%BB%8B%E7%BB%8D) +* [Use pipeline for model inference](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E6%8E%A8%E7%90%86Pipeline) +* [Finetune example](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AE%AD%E7%BB%83Train) +* [Preprocessing of data](https://modelscope.cn/docs/%E6%95%B0%E6%8D%AE%E7%9A%84%E9%A2%84%E5%A4%84%E7%90%86) +* [Evaluation metrics](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AF%84%E4%BC%B0) + +# License + +This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). diff --git a/modelscope/exporters/torch_model_exporter.py b/modelscope/exporters/torch_model_exporter.py index 7bf6c0c0..1d332591 100644 --- a/modelscope/exporters/torch_model_exporter.py +++ b/modelscope/exporters/torch_model_exporter.py @@ -128,7 +128,7 @@ class TorchModelExporter(Exporter): args_list = list(args) else: args_list = [args] - if isinstance(args_list[-1], dict): + if isinstance(args_list[-1], Mapping): args_dict = args_list[-1] args_list = args_list[:-1] n_nonkeyword = len(args_list) @@ -284,9 +284,8 @@ class TorchModelExporter(Exporter): 'Model property dummy_inputs must be set.') dummy_inputs = collate_fn(dummy_inputs, device) if isinstance(dummy_inputs, Mapping): - dummy_inputs = self._decide_input_format(model, dummy_inputs) dummy_inputs_filter = [] - for _input in dummy_inputs: + for _input in self._decide_input_format(model, dummy_inputs): if _input is not None: dummy_inputs_filter.append(_input) else: diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index dca6d099..f2ff822d 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -23,7 +23,8 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, API_RESPONSE_FIELD_MESSAGE, API_RESPONSE_FIELD_USERNAME, DEFAULT_CREDENTIALS_PATH, - MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS, + MODELSCOPE_ENVIRONMENT, + MODELSCOPE_USERNAME, ONE_YEAR_SECONDS, Licenses, ModelVisibility) from modelscope.hub.errors import (InvalidParameter, NotExistError, NotLoginException, NoValidRevisionError, @@ -38,8 +39,8 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DEFAULT_MODEL_REVISION, DEFAULT_REPOSITORY_REVISION, MASTER_MODEL_BRANCH, DatasetFormations, - DatasetMetaFormats, DownloadMode, - ModelFile) + DatasetMetaFormats, DownloadChannel, + DownloadMode, ModelFile) from modelscope.utils.logger import get_logger from .utils.utils import (get_endpoint, get_release_datetime, model_id_to_group_owner_name) @@ -645,6 +646,25 @@ class HubApi: def check_local_cookies(self, use_cookies) -> CookieJar: return self._check_cookie(use_cookies=use_cookies) + def dataset_download_uv(self, dataset_name: str, namespace: str): + if not dataset_name or not namespace: + raise ValueError('dataset_name or namespace cannot be empty!') + + # get channel and user_name + channel = DownloadChannel.LOCAL.value + user_name = '' + if MODELSCOPE_ENVIRONMENT in os.environ: + channel = os.environ[MODELSCOPE_ENVIRONMENT] + if MODELSCOPE_USERNAME in os.environ: + user_name = os.environ[MODELSCOPE_USERNAME] + + url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/{channel}?user={user_name}' + cookies = ModelScopeConfig.get_cookies() + r = requests.post(url, cookies=cookies, headers=self.headers) + resp = r.json() + raise_on_error(resp) + return resp['Message'] + class ModelScopeConfig:
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) @@ -760,14 +780,18 @@ class ModelScopeConfig: env = 'custom' if MODELSCOPE_ENVIRONMENT in os.environ: env = os.environ[MODELSCOPE_ENVIRONMENT] + user_name = 'unknown' + if MODELSCOPE_USERNAME in os.environ: + user_name = os.environ[MODELSCOPE_USERNAME] - ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % ( + ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % ( __version__, platform.python_version(), ModelScopeConfig.get_user_session_id(), platform.platform(), platform.processor(), env, + user_name, ) if isinstance(user_agent, dict): ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py index 730702c1..373a0cf4 100644 --- a/modelscope/hub/constants.py +++ b/modelscope/hub/constants.py @@ -18,6 +18,7 @@ API_RESPONSE_FIELD_EMAIL = 'Email' API_RESPONSE_FIELD_MESSAGE = 'Message' MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG' +MODELSCOPE_USERNAME = 'MODELSCOPE_USERNAME' ONE_YEAR_SECONDS = 24 * 365 * 60 * 60 diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py index a54f3413..61d560fa 100644 --- a/modelscope/hub/utils/utils.py +++ b/modelscope/hub/utils/utils.py @@ -5,6 +5,8 @@ import os from datetime import datetime from typing import Optional +import requests + from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, DEFAULT_MODELSCOPE_GROUP, MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, @@ -85,3 +87,16 @@ def file_integrity_validation(file_path, expected_sha256): msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path logger.error(msg) raise FileIntegrityError(msg) + + +def create_library_statistics(method: str, name: str, cn_name: Optional[str]): + try: + from modelscope.hub.api import ModelScopeConfig + path = f'{get_endpoint()}/api/v1/statistics/library' + headers = {'user-agent': ModelScopeConfig.get_user_agent()} + params = {'Method': method, 'Name': name, 'CnName': cn_name} + r = requests.post(path, params=params, headers=headers) + r.raise_for_status() + except Exception: + pass + return diff --git a/modelscope/models/audio/kws/farfield/model.py b/modelscope/models/audio/kws/farfield/model.py index d63d1e2a..af1c0a27 100644 --- a/modelscope/models/audio/kws/farfield/model.py +++ b/modelscope/models/audio/kws/farfield/model.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import tempfile from typing import Dict, Optional from modelscope.metainfo import Models @@ -36,12 +37,15 @@ class FSMNSeleNetV2Decorator(TorchModel): else: sc_config_file = os.path.join(model_dir, self.SC_CONFIG) model_txt_file = os.path.join(model_dir, self.MODEL_TXT) + self.tmp_dir = tempfile.TemporaryDirectory() + new_config_file = os.path.join(self.tmp_dir.name, self.SC_CONFIG) + self._sc = None if os.path.exists(model_txt_file): conf_dict = dict(mode=56542, kws_model=model_txt_file) - update_conf(sc_config_file, sc_config_file, conf_dict) + update_conf(sc_config_file, new_config_file, conf_dict) import py_sound_connect - self._sc = py_sound_connect.SoundConnect(sc_config_file) + self._sc = py_sound_connect.SoundConnect(new_config_file) self.size_in = self._sc.bytesPerBlockIn() self.size_out = self._sc.bytesPerBlockOut() else: @@ -49,6 +53,9 @@ class FSMNSeleNetV2Decorator(TorchModel): f'Invalid model directory! Failed to load model file: {model_txt_file}.' 
) + def __del__(self): + self.tmp_dir.cleanup() + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: return self.model.forward(input) diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 1ca7e030..721478c3 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -131,6 +131,8 @@ class Model(ABC): if not hasattr(model, 'cfg'): model.cfg = cfg + + model.name = model_name_or_path return model def save_pretrained(self, diff --git a/modelscope/models/multi_modal/clip/model.py b/modelscope/models/multi_modal/clip/model.py index b1c84292..9b82e4a1 100644 --- a/modelscope/models/multi_modal/clip/model.py +++ b/modelscope/models/multi_modal/clip/model.py @@ -349,11 +349,13 @@ class CLIP(nn.Module): text_num_hidden_layers: int, text_type_vocab_size: int, tokenizer: FullTokenizer, + # vision_head_width, added this param for ViT-H + vision_head_width: int = 64, ): super().__init__() if isinstance(vision_layers, (tuple, list)): - vision_heads = vision_width * 32 // 64 + vision_heads = vision_width * 32 // vision_head_width self.visual = ModifiedResNet( layers=vision_layers, output_dim=embed_dim, @@ -361,7 +363,7 @@ class CLIP(nn.Module): input_resolution=image_resolution, width=vision_width) else: - vision_heads = vision_width // 64 + vision_heads = vision_width // vision_head_width self.visual = VisualTransformer( input_resolution=image_resolution, patch_size=vision_patch_size, diff --git a/modelscope/models/multi_modal/ofa/configuration_ofa.py b/modelscope/models/multi_modal/ofa/configuration_ofa.py index 4899f416..2edc651e 100644 --- a/modelscope/models/multi_modal/ofa/configuration_ofa.py +++ b/modelscope/models/multi_modal/ofa/configuration_ofa.py @@ -136,6 +136,12 @@ class OFAConfig(PretrainedConfig): entangle_position_embedding=False, interpolate_position=False, orig_patch_image_size=224, + share_attn_bias=False, + use_image_feature=True, + disable_entangle=False, + use_ofasys=False, + vit_type='vit_base', + vit_drop_path_rate=0.0, **kwargs): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -178,6 +184,13 @@ class OFAConfig(PretrainedConfig): self.interpolate_position = interpolate_position self.orig_patch_image_size = orig_patch_image_size + self.share_attn_bias = share_attn_bias + self.use_image_feature = use_image_feature + self.disable_entangle = disable_entangle + self.use_ofasys = use_ofasys + self.vit_type = vit_type + self.vit_drop_path_rate = vit_drop_path_rate + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/modelscope/models/multi_modal/ofa/modeling_ofa.py b/modelscope/models/multi_modal/ofa/modeling_ofa.py old mode 100755 new mode 100644 index 0a7a2ce6..69005ef0 --- a/modelscope/models/multi_modal/ofa/modeling_ofa.py +++ b/modelscope/models/multi_modal/ofa/modeling_ofa.py @@ -35,6 +35,8 @@ from transformers.utils import logging from .configuration_ofa import OFAConfig from .generate import utils from .resnet import ResNet +from .utils.utils import DropPath +from .vit import vit_base, vit_huge, vit_large, vit_large_336 logger = logging.get_logger(__name__) @@ -249,45 +251,6 @@ class LayerDropModuleList(nn.ModuleList): yield m -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - r""" - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - - Args: - x (`nn.Modules`): input nn layers. - drop_prob (`float`): drop path ratio. 
- training (`bool`): whether is training or inference. - """ - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (1, x.shape[1], 1) - random_tensor = keep_prob + torch.rand( - shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class DropPath(nn.Module): - r""" - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - - Args: - drop_prob: drop path ratio. - """ - - def __init__(self, drop_prob=None): - super().__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - def extra_repr(self) -> str: - return 'p={}'.format(self.drop_prob) - - class OFAAttention(nn.Module): r""" Multi-headed attention, with additional implementation for NormFormer. @@ -898,31 +861,49 @@ class OFAEncoder(OFAPreTrainedModel): self.padding_idx) if config.add_type_embedding: - self.type_embedding = Embedding(2, embed_dim, padding_idx=None) + if config.use_image_feature: + self.type_embedding = Embedding(2, embed_dim, padding_idx=None) + else: + self.type_embedding = Embedding(1, embed_dim, padding_idx=None) else: self.type_embedding = None - if config.resnet_type == 'resnet18': - self.embed_images = ResNet( - [2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) - elif config.resnet_type == 'resnet34': - self.embed_images = ResNet( - [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) - elif config.resnet_type == 'resnet50': - self.embed_images = ResNet( - [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) - elif config.resnet_type == 'resnet101': - self.embed_images = ResNet( - [3, 4, 23], drop_path_rate=config.resnet_drop_path_rate) - elif config.resnet_type == 'resnet152': - self.embed_images = ResNet( - [3, 8, 36], drop_path_rate=config.resnet_drop_path_rate) - else: - raise NotImplementedError + if config.use_image_feature: + if config.use_ofasys: + vit_backbone = { + 'vit_base': vit_base, + 'vit_large': vit_large, + 'vit_large_336': vit_large_336, + 'vit_huge': vit_huge, + }[config.vit_type] + self.embed_images = vit_backbone(config.vit_drop_path_rate) - self.image_proj = Linear(1024, embed_dim) + self.image_proj = Linear(self.embed_images.width, embed_dim) - if config.resnet_model_path: + else: + if config.resnet_type == 'resnet18': + self.embed_images = ResNet( + [2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet34': + self.embed_images = ResNet( + [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet50': + self.embed_images = ResNet( + [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet101': + self.embed_images = ResNet( + [3, 4, 23], + drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet152': + self.embed_images = ResNet( + [3, 8, 36], + drop_path_rate=config.resnet_drop_path_rate) + else: + raise NotImplementedError + + self.image_proj = Linear(1024, embed_dim) + + if not config.use_ofasys and config.resnet_model_path: print('load resnet {}'.format(config.resnet_model_path)) resnet_state_dict = torch.load(config.resnet_model_path) self.embed_images.load_state_dict(resnet_state_dict) @@ -933,14 +914,21 @@ class OFAEncoder(OFAPreTrainedModel): self.embed_positions = Embedding(self.max_source_positions + 2, embed_dim) - self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, - embed_dim) - self.pos_ln = 
LayerNorm(embed_dim) - self.image_pos_ln = LayerNorm(embed_dim) + + if config.use_image_feature: + self.embed_image_positions = Embedding( + config.image_bucket_size**2 + 1, embed_dim) + if not config.use_ofasys: + self.pos_ln = LayerNorm(embed_dim) + + if config.use_image_feature: + self.image_pos_ln = LayerNorm(embed_dim) self.pos_scaling = float(embed_dim / self.num_attention_heads * config.attn_scale_factor)**-0.5 - self.pos_q_linear = nn.Linear(embed_dim, embed_dim) - self.pos_k_linear = nn.Linear(embed_dim, embed_dim) + + if not (config.use_ofasys and config.entangle_position_embedding): + self.pos_q_linear = nn.Linear(embed_dim, embed_dim) + self.pos_k_linear = nn.Linear(embed_dim, embed_dim) if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) @@ -965,22 +953,28 @@ class OFAEncoder(OFAPreTrainedModel): self.token_bucket_size = config.token_bucket_size token_num_rel_dis = 2 * config.token_bucket_size - 1 token_rp_bucket = make_token_bucket_position(config.token_bucket_size) + self.share_attn_bias = config.share_attn_bias + num_rel_pos_tables = 1 if config.share_attn_bias else config.encoder_layers self.token_rel_pos_table_list = nn.ModuleList([ Embedding( token_num_rel_dis, self.num_attention_heads, zero_init=True) - for _ in range(config.encoder_layers) + for _ in range(num_rel_pos_tables) ]) - self.image_bucket_size = config.image_bucket_size - image_num_rel_dis = (2 * config.image_bucket_size - - 1) * (2 * config.image_bucket_size - 1) + 3 - image_rp_bucket = make_image_bucket_position(config.image_bucket_size, - image_num_rel_dis) - self.image_rel_pos_table_list = nn.ModuleList([ - Embedding( - image_num_rel_dis, self.num_attention_heads, zero_init=True) - for _ in range(config.encoder_layers) - ]) + if config.use_image_feature: + self.image_bucket_size = config.image_bucket_size + image_num_rel_dis = (2 * config.image_bucket_size + - 1) * (2 * config.image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position( + config.image_bucket_size, image_num_rel_dis) + self.image_rel_pos_table_list = nn.ModuleList([ + Embedding( + image_num_rel_dis, + self.num_attention_heads, + zero_init=True) for _ in range(num_rel_pos_tables) + ]) + + self.register_buffer('image_rp_bucket', image_rp_bucket) if config.layernorm_embedding: self.layernorm_embedding = LayerNorm(embed_dim) @@ -988,12 +982,12 @@ class OFAEncoder(OFAPreTrainedModel): self.layernorm_embedding = None self.register_buffer('token_rp_bucket', token_rp_bucket) - self.register_buffer('image_rp_bucket', image_rp_bucket) self.entangle_position_embedding = config.entangle_position_embedding self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() + self.use_ofasys = config.use_ofasys def get_input_embeddings(self): r""" @@ -1305,21 +1299,41 @@ class OFAEncoder(OFAPreTrainedModel): if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) - pos_embed = self.pos_ln(pos_embed) - if patch_images is not None: - image_pos_embed = self.image_pos_ln(image_pos_embed) - pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) - if patch_images_2 is not None: - image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) - pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) + if self.use_ofasys: + if patch_images is not None: + pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) + if patch_images_2 is not None: + pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) + else: + pos_embed = self.pos_ln(pos_embed) + if 
patch_images is not None: + image_pos_embed = self.image_pos_ln(image_pos_embed) + pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) + if patch_images_2 is not None: + image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) + pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) + + def build_abs_pos_bias(pos_embed): + batch_size, seq_length = pos_embed.size(0), pos_embed.size(1) + if not (self.use_ofasys and self.entangle_position_embedding): + pos_q = self.pos_q_linear(pos_embed).view( + batch_size, seq_length, self.num_attention_heads, + -1).transpose(1, 2) * self.pos_scaling + pos_k = self.pos_k_linear(pos_embed).view( + batch_size, seq_length, self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + else: + abs_pos_bias = torch.zeros( + batch_size, + self.num_attention_heads, + seq_length, + seq_length, + dtype=pos_embed.dtype, + device=pos_embed.device) + return abs_pos_bias - pos_q = self.pos_q_linear(pos_embed).view( - x.size(0), x.size(1), self.num_attention_heads, -1).transpose( - 1, 2) * self.pos_scaling - pos_k = self.pos_k_linear(pos_embed).view( - x.size(0), x.size(1), self.num_attention_heads, - -1).transpose(1, 2) - abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + abs_pos_bias = build_abs_pos_bias(pos_embed) # expand attention_mask if has_pads: @@ -1334,19 +1348,22 @@ class OFAEncoder(OFAPreTrainedModel): if output_hidden_states: encoder_states += (x, ) self_attn_bias = abs_pos_bias.clone() + + real_idx = 0 if self.share_attn_bias else idx + self_attn_bias[:, :, -input_ids.size(1):, -input_ids.size(1):] += self.get_rel_pos_bias( - input_ids, idx) + input_ids, real_idx) if patch_images_2 is not None: self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ - self.get_image_rel_pos_bias(image_position_ids_2, idx) + self.get_image_rel_pos_bias(image_position_ids_2, real_idx) self_attn_bias[:, :, image_num_patches_2:image_num_patches_2 + image_num_patches, # noqa image_num_patches_2:image_num_patches_2 + image_num_patches] += \ - self.get_image_rel_pos_bias(image_position_ids, idx) # noqa + self.get_image_rel_pos_bias(image_position_ids, real_idx) # noqa elif patch_images is not None: self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \ - self.get_image_rel_pos_bias(image_position_ids, idx) + self.get_image_rel_pos_bias(image_position_ids, real_idx) self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1)) hidden_outputs = layer( @@ -1398,6 +1415,8 @@ class OFADecoder(OFAPreTrainedModel): self._future_mask = torch.empty(0) self.share_input_output_embed = config.share_decoder_input_output_embed self.num_attention_heads = config.decoder_attention_heads + self.use_ofasys = config.use_ofasys + self.disable_entangle = config.disable_entangle if embed_tokens is not None: self.embed_tokens = embed_tokens @@ -1415,18 +1434,31 @@ class OFADecoder(OFAPreTrainedModel): else: self.layernorm_embedding = None + if config.use_ofasys: + if config.add_type_embedding: + self.type_embedding = Embedding( + 1, self.embed_dim, padding_idx=None) + else: + self.type_embedding = None + self.window_size = config.code_image_size // 8 self.embed_positions = Embedding(self.max_target_positions + 2, self.embed_dim) - self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, - self.embed_dim) - self.pos_ln = LayerNorm(self.embed_dim) - self.image_pos_ln = LayerNorm(self.embed_dim) + + if not config.use_ofasys: + self.embed_image_positions = Embedding( + 
config.image_bucket_size**2 + 1, self.embed_dim) + if not config.use_ofasys: + self.pos_ln = LayerNorm(self.embed_dim) + self.image_pos_ln = LayerNorm(self.embed_dim) self.pos_scaling = float(self.embed_dim / self.num_attention_heads * config.attn_scale_factor)**-0.5 - self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) - self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) + + if not (config.use_ofasys and config.entangle_position_embedding): + self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) + self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) + self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) @@ -1463,33 +1495,41 @@ class OFADecoder(OFAPreTrainedModel): self.token_bucket_size = config.token_bucket_size token_num_rel_dis = 2 * config.token_bucket_size - 1 token_rp_bucket = make_token_bucket_position(config.token_bucket_size) + + self.share_attn_bias = config.share_attn_bias + num_rel_pos_tables = 1 if config.share_attn_bias else config.decoder_layers self.token_rel_pos_table_list = nn.ModuleList([ Embedding( token_num_rel_dis, self.num_attention_heads, zero_init=True) - for _ in range(config.decoder_layers) + for _ in range(num_rel_pos_tables) ]) - self.image_bucket_size = config.image_bucket_size - image_num_rel_dis = (2 * config.image_bucket_size - - 1) * (2 * config.image_bucket_size - 1) + 3 - image_rp_bucket = make_image_bucket_position(config.image_bucket_size, - image_num_rel_dis) - image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ - torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa - image_position_idx = torch.cat( - [torch.tensor([0]), image_position_idx.view(-1)]) - image_position_idx = torch.cat( - [image_position_idx, - torch.tensor([1024] * 768)]) - self.image_rel_pos_table_list = nn.ModuleList([ - Embedding( - image_num_rel_dis, self.num_attention_heads, zero_init=True) - for _ in range(config.decoder_layers) - ]) + if config.use_image_feature: + if not config.use_ofasys: + self.image_bucket_size = config.image_bucket_size + image_num_rel_dis = (2 * config.image_bucket_size - 1) * ( + 2 * config.image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position( + config.image_bucket_size, image_num_rel_dis) + image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ + torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa + image_position_idx = torch.cat( + [torch.tensor([0]), + image_position_idx.view(-1)]) + image_position_idx = torch.cat( + [image_position_idx, + torch.tensor([1024] * 768)]) + self.register_buffer('image_position_idx', image_position_idx) + + self.image_rel_pos_table_list = nn.ModuleList([ + Embedding( + image_num_rel_dis, + self.num_attention_heads, + zero_init=True) for _ in range(num_rel_pos_tables) + ]) + self.register_buffer('image_rp_bucket', image_rp_bucket) self.register_buffer('token_rp_bucket', token_rp_bucket) - self.register_buffer('image_rp_bucket', image_rp_bucket) - self.register_buffer('image_position_idx', image_position_idx) self.entangle_position_embedding = config.entangle_position_embedding self.gradient_checkpointing = False @@ -1556,26 +1596,46 @@ class OFADecoder(OFAPreTrainedModel): batch_size = tgt_pos_embed.size(0) tgt_len = tgt_pos_embed.size(1) - tgt_pos_embed = self.image_pos_ln( - tgt_pos_embed) 
if use_image else self.pos_ln(tgt_pos_embed) + if not self.use_ofasys: + tgt_pos_embed = self.image_pos_ln( + tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) if src_pos_embed is not None: src_len = src_pos_embed.size(1) - pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( - batch_size, tgt_len, self.num_attention_heads, -1).transpose( - 1, 2) * self.pos_scaling - pos_k = self.cross_pos_k_linear(src_pos_embed).view( - batch_size, src_len, self.num_attention_heads, - -1).transpose(1, 2) + if not (self.entangle_position_embedding and self.use_ofasys): + pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, + -1).transpose(1, 2) * self.pos_scaling + pos_k = self.cross_pos_k_linear(src_pos_embed).view( + batch_size, src_len, self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + else: + abs_pos_bias = torch.zeros( + batch_size, + self.num_attention_heads, + tgt_len, + src_len, + dtype=tgt_pos_embed.dtype, + device=tgt_pos_embed.device) else: - src_len = tgt_pos_embed.size(1) - pos_q = self.self_pos_q_linear(tgt_pos_embed).view( - batch_size, tgt_len, self.num_attention_heads, -1).transpose( - 1, 2) * self.pos_scaling - pos_k = self.self_pos_k_linear(tgt_pos_embed).view( - batch_size, src_len, self.num_attention_heads, - -1).transpose(1, 2) - abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + # batch_size, seq_length = tgt_pos_embed.size(0), tgt_pos_embed.size(1) + if not (self.entangle_position_embedding and self.use_ofasys): + pos_q = self.self_pos_q_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, + -1).transpose(1, 2) * self.pos_scaling + pos_k = self.self_pos_k_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + else: + abs_pos_bias = torch.zeros( + batch_size, + self.num_attention_heads, + tgt_len, + tgt_len, + dtype=tgt_pos_embed.dtype, + device=tgt_pos_embed.device) return abs_pos_bias @@ -1809,17 +1869,18 @@ class OFADecoder(OFAPreTrainedModel): past_key_values) > 0 else None self_attn_bias = self_abs_pos_bias.clone() + real_idx = 0 if self.share_attn_bias else idx if code_masks is None or not code_masks.any(): self_attn_bias += self.get_rel_pos_bias( - all_prev_output_tokens, idx).unsqueeze(0) + all_prev_output_tokens, real_idx).unsqueeze(0) elif code_masks is not None and code_masks.all(): self_attn_bias += self.get_image_rel_pos_bias( - all_prev_output_tokens, idx).unsqueeze(0) + all_prev_output_tokens, real_idx).unsqueeze(0) else: self_attn_bias[~code_masks] += self.get_rel_pos_bias( - all_prev_output_tokens, idx).unsqueeze(0) + all_prev_output_tokens, real_idx).unsqueeze(0) self_attn_bias[code_masks] += self.get_image_rel_pos_bias( - all_prev_output_tokens, idx).unsqueeze(0) + all_prev_output_tokens, real_idx).unsqueeze(0) self_attn_bias = self_attn_bias.reshape( -1, *self_attn_bias.size()[-2:]) @@ -1892,6 +1953,7 @@ class OFAModel(OFAPreTrainedModel): self.encoder = OFAEncoder(config, shared) self.decoder = OFADecoder(config, shared) + self.use_ofasys = config.use_ofasys # Initialize weights and apply final processing self.post_init() diff --git a/modelscope/models/multi_modal/ofa/utils/utils.py b/modelscope/models/multi_modal/ofa/utils/utils.py index 6d8943a1..c5aa8483 100644 --- a/modelscope/models/multi_modal/ofa/utils/utils.py +++ b/modelscope/models/multi_modal/ofa/utils/utils.py @@ -2,6 +2,7 @@ from typing import 
Optional import torch +import torch.nn as nn def expand_mask(mask: torch.Tensor, @@ -17,3 +18,42 @@ def expand_mask(mask: torch.Tensor, src_len).to(dtype) return expanded_mask.masked_fill(expanded_mask.bool(), torch.finfo(dtype).min) + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + r""" + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Args: + x (`nn.Modules`): input nn layers. + drop_prob (`float`): drop path ratio. + training (`bool`): whether is training or inference. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (1, x.shape[1], 1) + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + r""" + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Args: + drop_prob: drop path ratio. + """ + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) diff --git a/modelscope/models/multi_modal/ofa/vit.py b/modelscope/models/multi_modal/ofa/vit.py new file mode 100644 index 00000000..b6bba7ee --- /dev/null +++ b/modelscope/models/multi_modal/ofa/vit.py @@ -0,0 +1,155 @@ +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from fairseq.modules import LayerNorm +from torch import nn + +from .utils.utils import DropPath + +__all__ = [ + 'vit_base', + 'vit_large', + 'vit_large_336', + 'vit_huge', +] + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None, + drop_path_rate=0.0): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([ + ('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model)), + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + self.drop_path = DropPath(drop_path_rate) + + def attention(self, x: torch.Tensor): + self.attn_mask = ( + self.attn_mask.to(dtype=x.dtype, device=x.device) + if self.attn_mask is not None else None) + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.drop_path(self.attention(self.ln_1(x))) + x = x + self.drop_path(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + + def __init__( + self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask, drop_path_rate) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + + def __init__( + self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.input_resolution = input_resolution + self.patch_size = patch_size + self.conv1 = nn.Conv2d( + in_channels=3, + 
out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.width = width + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + self.transformer = Transformer( + width, layers, heads, drop_path_rate=drop_path_rate) + + def forward(self, x: torch.Tensor): + resolution = x.shape[-2] + height, width = x.shape[-2] // self.patch_size, x.shape[ + -1] // self.patch_size + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + if resolution != self.input_resolution: + old_pe = self.positional_embedding[1:] + patch_num = self.input_resolution // self.patch_size + old_pe = old_pe.reshape(1, patch_num, patch_num, + -1).permute(0, 3, 1, 2) + new_pe = F.interpolate( + old_pe, size=(height, width), mode='bilinear') + new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1) + x = x + new_pe.to(x.dtype) + else: + x = x + self.positional_embedding[1:].to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + bz, seq, hidden = x.shape + x = x.transpose(1, 2).reshape(bz, hidden, height, width) + + return x + + +def vit_base(drop_path_rate: float = 0.0): + return VisionTransformer(224, 16, 768, 9, 12, drop_path_rate) + + +def vit_large(drop_path_rate: float = 0.0): + return VisionTransformer(224, 14, 1024, 18, 16, drop_path_rate) + + +def vit_large_336(drop_path_rate: float = 0.0): + return VisionTransformer(336, 14, 1024, 18, 16, drop_path_rate) + + +def vit_huge(drop_path_rate: float = 0.0): + return VisionTransformer(224, 14, 1280, 24, 16, drop_path_rate) diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py index 56d19ad8..fc578b25 100644 --- a/modelscope/models/multi_modal/ofa_for_all_tasks.py +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import math import os +import re import string from functools import partial from os import path as osp @@ -53,8 +54,11 @@ class OfaForAllTasks(TorchModel): raise NotImplementedError # there is some diff between here and our ofa code, # there will be no need to use param: use_bpe - self.tokenizer.add_tokens([''.format(i) for i in range(8192)]) - self.tokenizer.add_tokens([''.format(i) for i in range(1000)]) + if not model.use_ofasys: + self.tokenizer.add_tokens( + [''.format(i) for i in range(8192)]) + self.tokenizer.add_tokens( + [''.format(i) for i in range(1000)]) self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) self.batch_size = self.cfg.model.get('batch_size', 1) self.patch_image_size = self.cfg.model.get('patch_image_size', 480) @@ -107,6 +111,8 @@ class OfaForAllTasks(TorchModel): Tasks.text_classification: inference_d[self.gen_type], Tasks.image_classification: inference_d[self.gen_type], } + pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' + self.pattern = re.compile(pattern_str) def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: input = move_to_device(input, self.model.device) @@ -132,8 +138,18 @@ class OfaForAllTasks(TorchModel): caption = input[OutputKeys.CAPTION] result_l = list() for cap in caption: - result_l.append(cap.translate(self.transtab).strip()) + if self.language == 'en': + result_l.append(cap.translate(self.transtab).strip()) + else: + result_l.append(cap) input[OutputKeys.CAPTION] = result_l + if self.gen_type == 'generation' and self.language in [ + 'zh', 'cn' + ] and self.cfg.task != Tasks.visual_grounding: + ret_l = list() + for text in input[OFA_TASK_KEY_MAPPING[self.cfg.task]]: + ret_l.append(self.detokenizer(text)) + input[OFA_TASK_KEY_MAPPING[self.cfg.task]] = ret_l return input def _text_gen_inference(self, input): @@ -311,3 +327,6 @@ class OfaForAllTasks(TorchModel): save_function=partial(save_function, with_meta=False), config=config, **kwargs) + + def detokenizer(self, text): + return self.pattern.sub('', text) diff --git a/modelscope/models/science/unifold/dataset.py b/modelscope/models/science/unifold/dataset.py index 05803f2c..29e1a8b0 100644 --- a/modelscope/models/science/unifold/dataset.py +++ b/modelscope/models/science/unifold/dataset.py @@ -1,3 +1,6 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + import copy import logging import os diff --git a/modelscope/models/science/unifold/model.py b/modelscope/models/science/unifold/model.py index 6632751a..7f28f18d 100644 --- a/modelscope/models/science/unifold/model.py +++ b/modelscope/models/science/unifold/model.py @@ -1,3 +1,6 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + import argparse import os from typing import Any diff --git a/modelscope/models/science/unifold/modules/__init__.py b/modelscope/models/science/unifold/modules/__init__.py new file mode 100644 index 00000000..63aa84ed --- /dev/null +++ b/modelscope/models/science/unifold/modules/__init__.py @@ -0,0 +1,3 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
+"""Unifold Modules.""" diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index 0c537df7..5c8ea59f 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -274,6 +274,8 @@ class MsDataset: try: api.on_dataset_download( dataset_name=download_dataset, namespace=namespace) + api.dataset_download_uv( + dataset_name=download_dataset, namespace=namespace) except Exception as e: logger.error(e) diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index b983125a..2c6dd85a 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -69,11 +69,23 @@ TASK_OUTPUTS = { # face 2d keypoint result for single sample # { # "keypoints": [ - # [x1, y1]*106 + # [[x, y]*106], + # [[x, y]*106], + # [[x, y]*106], # ], - # "poses": [pitch, roll, yaw] + # "poses": [ + # [pitch, roll, yaw], + # [pitch, roll, yaw], + # [pitch, roll, yaw], + # ], + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ] # } - Tasks.face_2d_keypoints: [OutputKeys.KEYPOINTS, OutputKeys.POSES], + Tasks.face_2d_keypoints: + [OutputKeys.KEYPOINTS, OutputKeys.POSES, OutputKeys.BOXES], # face detection result for single sample # { @@ -479,17 +491,8 @@ TASK_OUTPUTS = { # word segmentation result for single sample # { # "output": "今天 天气 不错 , 适合 出去 游玩" - # "labels": [ - # {'word': '今天', 'label': 'PROPN'}, - # {'word': '天气', 'label': 'PROPN'}, - # {'word': '不错', 'label': 'VERB'}, - # {'word': ',', 'label': 'NUM'}, - # {'word': '适合', 'label': 'NOUN'}, - # {'word': '出去', 'label': 'PART'}, - # {'word': '游玩', 'label': 'ADV'}, - # ] # } - Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS], + Tasks.word_segmentation: [OutputKeys.OUTPUT], # TODO @wenmeng.zwm support list of result check # named entity recognition result for single sample @@ -699,8 +702,9 @@ TASK_OUTPUTS = { # "text_embedding": np.array with shape [1, D], # "caption": "this is an image caption text." 
# } - Tasks.generative_multi_modal_embedding: - [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION], + Tasks.generative_multi_modal_embedding: [ + OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION + ], # multi-modal similarity result for single sample # { diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index bca80502..68010012 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -10,6 +10,7 @@ from typing import Any, Dict, Generator, List, Mapping, Union import numpy as np +from modelscope.hub.utils.utils import create_library_statistics from modelscope.models.base import Model from modelscope.msdatasets import MsDataset from modelscope.outputs import TASK_OUTPUTS @@ -151,7 +152,9 @@ class Pipeline(ABC): **kwargs) -> Union[Dict[str, Any], Generator]: # model provider should leave it as it is # modelscope library developer will handle this function - + for single_model in self.models: + if hasattr(single_model, 'name'): + create_library_statistics('pipeline', single_model.name, None) # place model to cpu or gpu if (self.model or (self.has_multiple_models and self.models[0])): if not self._model_prepare: diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 498c9ed8..70f8f11c 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -93,9 +93,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_resnet50_live-category'), Tasks.video_category: (Pipelines.video_category, 'damo/cv_resnet50_video-category'), - Tasks.multi_modal_embedding: - (Pipelines.multi_modal_embedding, - 'damo/multi-modal_clip-vit-large-patch14_zh'), + Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding, + 'damo/multi-modal_clip-vit-base-patch16_zh'), Tasks.generative_multi_modal_embedding: (Pipelines.generative_multi_modal_embedding, 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding' diff --git a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py index 8522ceff..d113fb3c 100644 --- a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py @@ -132,8 +132,8 @@ class Body3DKeypointsPipeline(Pipeline): device='gpu' if torch.cuda.is_available() else 'cpu') def preprocess(self, input: Input) -> Dict[str, Any]: - video_url = input - video_frames = self.read_video_frames(video_url) + self.video_url = input + video_frames = self.read_video_frames(self.video_url) if 0 == len(video_frames): res = {'success': False, 'msg': 'get video frame failed.'} return res @@ -198,7 +198,7 @@ class Body3DKeypointsPipeline(Pipeline): } if not input['success']: - pass + res[OutputKeys.OUTPUT_VIDEO] = self.video_url else: poses = input[KeypointsTypes.POSES_CAMERA] pred_3d_pose = poses.data.cpu().numpy()[ diff --git a/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py index b48d013e..29a96a5f 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py @@ -1,12 +1,22 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+import copy +import math from typing import Any +import cv2 +import numpy as np + from modelscope.metainfo import Pipelines from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger from .base import EasyCVPipeline +logger = get_logger() + @PIPELINES.register_module( Tasks.face_2d_keypoints, module_name=Pipelines.face_2d_keypoints) @@ -29,18 +39,206 @@ class Face2DKeypointsPipeline(EasyCVPipeline): *args, **kwargs) + # face detect pipeline + det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' + self.face_detection = pipeline( + Tasks.face_detection, model=det_model_id) + def show_result(self, img, points, scale=2, save_path=None): return self.predict_op.show_result(img, points, scale, save_path) + def _choose_face(self, det_result, min_face=10): + """ + choose face with maximum area + Args: + det_result: output of face detection pipeline + min_face: minimum size of valid face w/h + """ + bboxes = np.array(det_result[OutputKeys.BOXES]) + landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) + if bboxes.shape[0] == 0: + logger.warn('No face detected!') + return None + # face idx with enough size + face_idx = [] + for i in range(bboxes.shape[0]): + box = bboxes[i] + if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: + face_idx += [i] + if len(face_idx) == 0: + logger.warn( + f'Face size not enough, less than {min_face}x{min_face}!') + return None + bboxes = bboxes[face_idx] + landmarks = landmarks[face_idx] + + return bboxes, landmarks + + def expend_box(self, box, w, h, scalex=0.3, scaley=0.5): + x1 = box[0] + y1 = box[1] + wb = box[2] - x1 + hb = box[3] - y1 + deltax = int(wb * scalex) + deltay1 = int(hb * scaley) + deltay2 = int(hb * scalex) + x1 = x1 - deltax + y1 = y1 - deltay1 + if x1 < 0: + deltax = deltax + x1 + x1 = 0 + if y1 < 0: + deltay1 = deltay1 + y1 + y1 = 0 + x2 = x1 + wb + 2 * deltax + y2 = y1 + hb + deltay1 + deltay2 + x2 = np.clip(x2, 0, w - 1) + y2 = np.clip(y2, 0, h - 1) + return [x1, y1, x2, y2] + + def rotate_point(self, angle, center, landmark): + rad = angle * np.pi / 180.0 + alpha = np.cos(rad) + beta = np.sin(rad) + M = np.zeros((2, 3), dtype=np.float32) + M[0, 0] = alpha + M[0, 1] = beta + M[0, 2] = (1 - alpha) * center[0] - beta * center[1] + M[1, 0] = -beta + M[1, 1] = alpha + M[1, 2] = beta * center[0] + (1 - alpha) * center[1] + + landmark_ = np.asarray([(M[0, 0] * x + M[0, 1] * y + M[0, 2], + M[1, 0] * x + M[1, 1] * y + M[1, 2]) + for (x, y) in landmark]) + return M, landmark_ + + def rotate_crop_img(self, img, pts, M): + imgT = cv2.warpAffine(img, M, (int(img.shape[1]), int(img.shape[0]))) + + x1 = pts[5][0] + x2 = pts[5][0] + y1 = pts[5][1] + y2 = pts[5][1] + for i in range(0, 9): + x1 = min(x1, pts[i][0]) + x2 = max(x2, pts[i][0]) + y1 = min(y1, pts[i][1]) + y2 = max(y2, pts[i][1]) + + height, width, _ = imgT.shape + x1 = min(max(0, int(x1)), width) + y1 = min(max(0, int(y1)), height) + x2 = min(max(0, int(x2)), width) + y2 = min(max(0, int(y2)), height) + sub_imgT = imgT[y1:y2, x1:x2] + + return sub_imgT, imgT, [x1, y1, x2, y2] + + def crop_img(self, imgT, pts): + enlarge_ratio = 1.1 + + x1 = np.min(pts[:, 0]) + x2 = np.max(pts[:, 0]) + y1 = np.min(pts[:, 1]) + y2 = np.max(pts[:, 1]) + w = x2 - x1 + 1 + h = y2 - y1 + 1 + x1 = int(x1 - (enlarge_ratio - 1.0) / 2.0 * w) + y1 = int(y1 - (enlarge_ratio - 1.0) / 
2.0 * h) + x1 = max(0, x1) + y1 = max(0, y1) + + new_w = int(enlarge_ratio * w) + new_h = int(enlarge_ratio * h) + new_x1 = x1 + new_y1 = y1 + new_x2 = new_x1 + new_w + new_y2 = new_y1 + new_h + + height, width, _ = imgT.shape + + new_x1 = min(max(0, new_x1), width) + new_y1 = min(max(0, new_y1), height) + new_x2 = max(min(width, new_x2), 0) + new_y2 = max(min(height, new_y2), 0) + + sub_imgT = imgT[new_y1:new_y2, new_x1:new_x2] + + return sub_imgT, [new_x1, new_y1, new_x2, new_y2] + def __call__(self, inputs) -> Any: - outputs = self.predict_op(inputs) + img = LoadImage.convert_to_ndarray(inputs) + h, w, c = img.shape + img_rgb = copy.deepcopy(img) + img_rgb = img_rgb[:, :, ::-1] + det_result = self.face_detection(img_rgb) + + bboxes = np.array(det_result[OutputKeys.BOXES]) + if bboxes.shape[0] == 0: + logger.warn('No face detected!') + results = { + OutputKeys.KEYPOINTS: [], + OutputKeys.POSES: [], + OutputKeys.BOXES: [] + } + return results + + boxes, keypoints = self._choose_face(det_result) + + output_boxes = [] + output_keypoints = [] + output_poses = [] + for index, box_ori in enumerate(boxes): + box = self.expend_box(box_ori, w, h, scalex=0.1, scaley=0.1) + y0 = int(box[1]) + y1 = int(box[3]) + x0 = int(box[0]) + x1 = int(box[2]) + sub_img = img[y0:y1, x0:x1] + + keypoint = keypoints[index] + pts = [[keypoint[0], keypoint[1]], [keypoint[2], keypoint[3]], + [keypoint[4], keypoint[5]], [keypoint[6], keypoint[7]], + [keypoint[8], keypoint[9]], [box[0], box[1]], + [box[2], box[1]], [box[0], box[3]], [box[2], box[3]]] + # radian + angle = math.atan2((pts[1][1] - pts[0][1]), + (pts[1][0] - pts[0][0])) + # angle + theta = angle * (180 / np.pi) + + center = [w // 2, h // 2] + cx, cy = center + M, landmark_ = self.rotate_point(theta, (cx, cy), pts) + sub_imgT, imgT, bbox = self.rotate_crop_img(img, landmark_, M) + + outputs = self.predict_op([sub_imgT])[0] + tmp_keypoints = outputs['point'] + + for idx in range(0, len(tmp_keypoints)): + tmp_keypoints[idx][0] += bbox[0] + tmp_keypoints[idx][1] += bbox[1] + + for idx in range(0, 6): + sub_img, bbox = self.crop_img(imgT, tmp_keypoints) + outputs = self.predict_op([sub_img])[0] + tmp_keypoints = outputs['point'] + for idx in range(0, len(tmp_keypoints)): + tmp_keypoints[idx][0] += bbox[0] + tmp_keypoints[idx][1] += bbox[1] + + M2, tmp_keypoints = self.rotate_point(-theta, (cx, cy), + tmp_keypoints) - results = [{ - OutputKeys.KEYPOINTS: output['point'], - OutputKeys.POSES: output['pose'] - } for output in outputs] + output_keypoints.append(np.array(tmp_keypoints)) + output_poses.append(np.array(outputs['pose'])) + output_boxes.append(np.array(box_ori)) - if self._is_single_inputs(inputs): - results = results[0] + results = { + OutputKeys.KEYPOINTS: output_keypoints, + OutputKeys.POSES: output_poses, + OutputKeys.BOXES: output_boxes + } return results diff --git a/modelscope/pipelines/nlp/token_classification_pipeline.py b/modelscope/pipelines/nlp/token_classification_pipeline.py index 75bc538d..4af187ee 100644 --- a/modelscope/pipelines/nlp/token_classification_pipeline.py +++ b/modelscope/pipelines/nlp/token_classification_pipeline.py @@ -109,13 +109,13 @@ class TokenClassificationPipeline(Pipeline): chunk['span'] = text[chunk['start']:chunk['end']] chunks.append(chunk) - # for cws output + # for cws outputs if len(chunks) > 0 and chunks[0]['type'] == 'cws': spans = [ chunk['span'] for chunk in chunks if chunk['span'].strip() ] seg_result = ' '.join(spans) - outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} + outputs = 
{OutputKeys.OUTPUT: seg_result} # for ner outputs else: diff --git a/modelscope/pipelines/nlp/word_segmentation_pipeline.py b/modelscope/pipelines/nlp/word_segmentation_pipeline.py index 0df8f1ad..c57f6b93 100644 --- a/modelscope/pipelines/nlp/word_segmentation_pipeline.py +++ b/modelscope/pipelines/nlp/word_segmentation_pipeline.py @@ -115,15 +115,15 @@ class WordSegmentationPipeline(Pipeline): chunk['span'] = text[chunk['start']:chunk['end']] chunks.append(chunk) - # for cws output + # for cws outputs if len(chunks) > 0 and chunks[0]['type'] == 'cws': spans = [ chunk['span'] for chunk in chunks if chunk['span'].strip() ] seg_result = ' '.join(spans) - outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} + outputs = {OutputKeys.OUTPUT: seg_result} - # for ner outpus + # for ner output else: outputs = {OutputKeys.OUTPUT: chunks} return outputs diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 13876058..3a3ae820 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -77,7 +77,7 @@ class OfaPreprocessor(Preprocessor): data[key] = item return data - def _ofa_input_compatibility_conversion(self, data): + def _ofa_input_compatibility_conversion(self, data): # fake if 'image' in data and self.cfg.model.get('type', None) == 'ofa': if isinstance(data['image'], str): image = load_image(data['image']) diff --git a/modelscope/preprocessors/nlp/nlp_base.py b/modelscope/preprocessors/nlp/nlp_base.py index 48a04d7a..45efc6e7 100644 --- a/modelscope/preprocessors/nlp/nlp_base.py +++ b/modelscope/preprocessors/nlp/nlp_base.py @@ -34,6 +34,7 @@ class NLPBasePreprocessor(Preprocessor, ABC): label=None, label2id=None, mode=ModeKeys.INFERENCE, + use_fast=None, **kwargs): """The NLP preprocessor base class. @@ -45,14 +46,18 @@ class NLPBasePreprocessor(Preprocessor, ABC): label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping if this mapping is not supplied. mode: Run this preprocessor in either 'train'/'eval'/'inference' mode + use_fast: use the fast version of tokenizer + """ self.model_dir = model_dir self.first_sequence = first_sequence self.second_sequence = second_sequence self.label = label - self.use_fast = kwargs.pop('use_fast', None) - if self.use_fast is None and os.path.isfile( + self.use_fast = use_fast + if self.use_fast is None and model_dir is None: + self.use_fast = False + elif self.use_fast is None and os.path.isfile( os.path.join(model_dir, 'tokenizer_config.json')): with open(os.path.join(model_dir, 'tokenizer_config.json'), 'r') as f: @@ -61,8 +66,8 @@ class NLPBasePreprocessor(Preprocessor, ABC): self.use_fast = False if self.use_fast is None else self.use_fast self.label2id = label2id - if self.label2id is None: - self.label2id = parse_label_mapping(self.model_dir) + if self.label2id is None and model_dir is not None: + self.label2id = parse_label_mapping(model_dir) super().__init__(mode, **kwargs) @property @@ -106,6 +111,7 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor): label: str = 'label', label2id: dict = None, mode: str = ModeKeys.INFERENCE, + use_fast: bool = None, **kwargs): """The NLP tokenizer preprocessor base class. @@ -122,11 +128,12 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor): - config.json label2id/id2label - label_mapping.json mode: Run this preprocessor in either 'train'/'eval'/'inference' mode, the behavior may be different. 
+ use_fast: use the fast version of tokenizer kwargs: These kwargs will be directly fed into the tokenizer. """ super().__init__(model_dir, first_sequence, second_sequence, label, - label2id, mode) + label2id, mode, use_fast, **kwargs) self.model_dir = model_dir self.tokenize_kwargs = kwargs self.tokenizer = self.build_tokenizer(model_dir) diff --git a/modelscope/preprocessors/nlp/token_classification_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_preprocessor.py index 2de0c806..92b7c46b 100644 --- a/modelscope/preprocessors/nlp/token_classification_preprocessor.py +++ b/modelscope/preprocessors/nlp/token_classification_preprocessor.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Tuple, Union +import numpy as np import torch from modelscope.metainfo import Preprocessors @@ -20,9 +21,7 @@ class WordSegmentationBlankSetToLabelPreprocessor(NLPBasePreprocessor): """ def __init__(self, **kwargs): - super().__init__(**kwargs) - self.first_sequence: str = kwargs.pop('first_sequence', - 'first_sequence') + self.first_sequence: str = kwargs.pop('first_sequence', 'tokens') self.label = kwargs.pop('label', OutputKeys.LABELS) def __call__(self, data: str) -> Union[Dict[str, Any], Tuple]: @@ -80,10 +79,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): 'is_split_into_words', False) if 'label2id' in kwargs: kwargs.pop('label2id') - self.tokenize_kwargs = kwargs - @type_assert(object, str) - def __call__(self, data: str) -> Dict[str, Any]: + @type_assert(object, (str, dict)) + def __call__(self, data: Union[dict, str]) -> Dict[str, Any]: """process the raw input data Args: @@ -99,18 +97,24 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): text = None labels_list = None if isinstance(data, str): + # for inference inputs without label text = data + self.tokenize_kwargs['add_special_tokens'] = False elif isinstance(data, dict): + # for finetune inputs with label text = data.get(self.first_sequence) labels_list = data.get(self.label) + if isinstance(text, list): + self.tokenize_kwargs['is_split_into_words'] = True input_ids = [] label_mask = [] offset_mapping = [] - if self.is_split_into_words: - for offset, token in enumerate(list(data)): - subtoken_ids = self.tokenizer.encode( - token, add_special_tokens=False) + token_type_ids = [] + if self.is_split_into_words and self._mode == ModeKeys.INFERENCE: + for offset, token in enumerate(list(text)): + subtoken_ids = self.tokenizer.encode(token, + **self.tokenize_kwargs) if len(subtoken_ids) == 0: subtoken_ids = [self.tokenizer.unk_token_id] input_ids.extend(subtoken_ids) @@ -119,10 +123,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): else: if self.tokenizer.is_fast: encodings = self.tokenizer( - text, - add_special_tokens=False, - return_offsets_mapping=True, - **self.tokenize_kwargs) + text, return_offsets_mapping=True, **self.tokenize_kwargs) + attention_mask = encodings['attention_mask'] + token_type_ids = encodings['token_type_ids'] input_ids = encodings['input_ids'] word_ids = encodings.word_ids() for i in range(len(word_ids)): @@ -137,75 +140,85 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): label_mask.append(1) offset_mapping.append(encodings['offset_mapping'][i]) else: - encodings = self.tokenizer( - text, add_special_tokens=False, **self.tokenize_kwargs) + encodings = self.tokenizer(text, **self.tokenize_kwargs) input_ids = encodings['input_ids'] label_mask, offset_mapping = self.get_label_mask_and_offset_mapping( text) - if 
len(input_ids) >= self.sequence_length - 2: - input_ids = input_ids[:self.sequence_length - 2] - label_mask = label_mask[:self.sequence_length - 2] - input_ids = [self.tokenizer.cls_token_id - ] + input_ids + [self.tokenizer.sep_token_id] - label_mask = [0] + label_mask + [0] - attention_mask = [1] * len(input_ids) - offset_mapping = offset_mapping[:sum(label_mask)] + if self._mode == ModeKeys.INFERENCE: + if len(input_ids) >= self.sequence_length - 2: + input_ids = input_ids[:self.sequence_length - 2] + label_mask = label_mask[:self.sequence_length - 2] + input_ids = [self.tokenizer.cls_token_id + ] + input_ids + [self.tokenizer.sep_token_id] + label_mask = [0] + label_mask + [0] + attention_mask = [1] * len(input_ids) + offset_mapping = offset_mapping[:sum(label_mask)] - if not self.is_transformer_based_model: - input_ids = input_ids[1:-1] - attention_mask = attention_mask[1:-1] - label_mask = label_mask[1:-1] + if not self.is_transformer_based_model: + input_ids = input_ids[1:-1] + attention_mask = attention_mask[1:-1] + label_mask = label_mask[1:-1] - if self._mode == ModeKeys.INFERENCE: input_ids = torch.tensor(input_ids).unsqueeze(0) attention_mask = torch.tensor(attention_mask).unsqueeze(0) label_mask = torch.tensor( label_mask, dtype=torch.bool).unsqueeze(0) - # the token classification - output = { - 'text': text, - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'label_mask': label_mask, - 'offset_mapping': offset_mapping - } - - # align the labels with tokenized text - if labels_list is not None: - assert self.label2id is not None - # Map that sends B-Xxx label to its I-Xxx counterpart - b_to_i_label = [] - label_enumerate_values = [ - k for k, v in sorted( - self.label2id.items(), key=lambda item: item[1]) - ] - for idx, label in enumerate(label_enumerate_values): - if label.startswith('B-') and label.replace( - 'B-', 'I-') in label_enumerate_values: - b_to_i_label.append( - label_enumerate_values.index( - label.replace('B-', 'I-'))) - else: - b_to_i_label.append(idx) + # the token classification + output = { + 'text': text, + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'label_mask': label_mask, + 'offset_mapping': offset_mapping + } + else: + output = { + 'input_ids': input_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'label_mask': label_mask, + } - label_row = [self.label2id[lb] for lb in labels_list] - previous_word_idx = None - label_ids = [] - for word_idx in word_ids: - if word_idx is None: - label_ids.append(-100) - elif word_idx != previous_word_idx: - label_ids.append(label_row[word_idx]) - else: - if self.label_all_tokens: - label_ids.append(b_to_i_label[label_row[word_idx]]) + # align the labels with tokenized text + if labels_list is not None: + assert self.label2id is not None + # Map that sends B-Xxx label to its I-Xxx counterpart + b_to_i_label = [] + label_enumerate_values = [ + k for k, v in sorted( + self.label2id.items(), key=lambda item: item[1]) + ] + for idx, label in enumerate(label_enumerate_values): + if label.startswith('B-') and label.replace( + 'B-', 'I-') in label_enumerate_values: + b_to_i_label.append( + label_enumerate_values.index( + label.replace('B-', 'I-'))) else: + b_to_i_label.append(idx) + + label_row = [self.label2id[lb] for lb in labels_list] + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + if word_idx is None: label_ids.append(-100) - previous_word_idx = word_idx - labels = label_ids - output['labels'] = labels + elif word_idx != 
previous_word_idx: + label_ids.append(label_row[word_idx]) + else: + if self.label_all_tokens: + label_ids.append(b_to_i_label[label_row[word_idx]]) + else: + label_ids.append(-100) + previous_word_idx = word_idx + labels = label_ids + output['labels'] = labels + output = { + k: np.array(v) if isinstance(v, list) else v + for k, v in output.items() + } return output def get_tokenizer_class(self): diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py index 95dab492..e15be93f 100644 --- a/modelscope/preprocessors/ofa/ocr_recognition.py +++ b/modelscope/preprocessors/ofa/ocr_recognition.py @@ -74,8 +74,8 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): self.patch_resize_transform = transforms.Compose([ lambda image: ocr_resize( image, - self.cfg.model.patch_image_size, - is_document=self.cfg.model.is_document), + self.patch_image_size, + is_document=self.cfg.model.get('is_document', False)), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) diff --git a/modelscope/trainers/audio/kws_farfield_trainer.py b/modelscope/trainers/audio/kws_farfield_trainer.py index a720ced5..85c1a496 100644 --- a/modelscope/trainers/audio/kws_farfield_trainer.py +++ b/modelscope/trainers/audio/kws_farfield_trainer.py @@ -69,11 +69,14 @@ class KWSFarfieldTrainer(BaseTrainer): super().__init__(cfg_file, arg_parse_fn) - self.model = self.build_model() - self.work_dir = work_dir # the number of model output dimension # should update config outside the trainer, if user need more wake word + num_syn = kwargs.get('num_syn', None) + if num_syn: + self.cfg.model.num_syn = num_syn self._num_classes = self.cfg.model.num_syn + self.model = self.build_model() + self.work_dir = work_dir if kwargs.get('launcher', None) is not None: init_dist(kwargs['launcher']) diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py index 3c38884c..3930febb 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py @@ -103,20 +103,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): def __init__(self, args): super().__init__() - self.sentence_avg = args.sentence_avg - self.eps = args.label_smoothing - self.ignore_prefix_size = args.ignore_prefix_size - self.ignore_eos = args.ignore_eos - self.report_accuracy = args.report_accuracy - self.drop_worst_ratio = args.drop_worst_ratio - self.drop_worst_after = args.drop_worst_after - self.use_rdrop = args.use_rdrop - self.reg_alpha = args.reg_alpha - self.sample_patch_num = args.sample_patch_num + self.sentence_avg = args.get('sentence_avg', False) + self.eps = args.get('label_smoothing', 0.1) + self.ignore_prefix_size = args.get('ignore_prefix_size', 0) + self.ignore_eos = args.get('ignore_eos', False) + self.report_accuracy = args.get('report_accuracy', False) + self.drop_worst_ratio = args.get('drop_worst_ratio', 0.0) + self.drop_worst_after = args.get('drop_worst_after', 0) + self.use_rdrop = args.get('use_rdrop', False) + self.reg_alpha = args.get('reg_alpha', 1.0) + self.sample_patch_num = args.get('sample_patch_num', 196) self.constraint_start = None self.constraint_end = None - if args.constraint_range: + if args.get('constraint_range', None): constraint_start, constraint_end = args.constraint_range.split(',') self.constraint_start = int(constraint_start) self.constraint_end = int(constraint_end) diff --git 
a/modelscope/trainers/nlp/text_generation_trainer.py b/modelscope/trainers/nlp/text_generation_trainer.py index 0e26f153..f02faf71 100644 --- a/modelscope/trainers/nlp/text_generation_trainer.py +++ b/modelscope/trainers/nlp/text_generation_trainer.py @@ -18,7 +18,7 @@ class TextGenerationTrainer(NlpEpochBasedTrainer): return tokenizer.decode(tokens.tolist(), skip_special_tokens=True) def evaluation_step(self, data): - model = self.model + model = self.model.module if self._dist else self.model model.eval() with torch.no_grad(): diff --git a/modelscope/trainers/nlp_trainer.py b/modelscope/trainers/nlp_trainer.py index a92a3706..5ff6f62f 100644 --- a/modelscope/trainers/nlp_trainer.py +++ b/modelscope/trainers/nlp_trainer.py @@ -586,14 +586,16 @@ class NlpEpochBasedTrainer(EpochBasedTrainer): preprocessor_mode=ModeKeys.TRAIN, **model_args, **self.train_keys, - mode=ModeKeys.TRAIN) + mode=ModeKeys.TRAIN, + use_fast=True) eval_preprocessor = Preprocessor.from_pretrained( self.model_dir, cfg_dict=self.cfg, preprocessor_mode=ModeKeys.EVAL, **model_args, **self.eval_keys, - mode=ModeKeys.EVAL) + mode=ModeKeys.EVAL, + use_fast=True) return train_preprocessor, eval_preprocessor diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index 7478d8e4..12c25f30 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -15,6 +15,7 @@ from torch.utils.data.dataloader import default_collate from torch.utils.data.distributed import DistributedSampler from modelscope.hub.snapshot_download import snapshot_download +from modelscope.hub.utils.utils import create_library_statistics from modelscope.metainfo import Trainers from modelscope.metrics import build_metric, task_default_metrics from modelscope.models.base import Model, TorchModel @@ -436,6 +437,8 @@ class EpochBasedTrainer(BaseTrainer): def train(self, checkpoint_path=None, *args, **kwargs): self._mode = ModeKeys.TRAIN + if hasattr(self.model, 'name'): + create_library_statistics('train', self.model.name, None) if self.train_dataset is None: self.train_dataloader = self.get_train_dataloader() @@ -456,6 +459,8 @@ class EpochBasedTrainer(BaseTrainer): self.train_loop(self.train_dataloader) def evaluate(self, checkpoint_path=None): + if hasattr(self.model, 'name'): + create_library_statistics('evaluate', self.model.name, None) if checkpoint_path is not None and os.path.isfile(checkpoint_path): from modelscope.trainers.hooks import CheckpointHook CheckpointHook.load_checkpoint(checkpoint_path, self) @@ -876,7 +881,7 @@ class EpochBasedTrainer(BaseTrainer): Subclass and override to inject custom behavior. """ - model = self.model + model = self.model.module if self._dist else self.model model.eval() if is_parallel(model): diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 2729b75a..f0a97dbd 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -238,6 +238,14 @@ class DownloadMode(enum.Enum): FORCE_REDOWNLOAD = 'force_redownload' +class DownloadChannel(enum.Enum): + """ Channels of datasets downloading for uv/pv counting. + """ + LOCAL = 'local' + DSW = 'dsw' + EAIS = 'eais' + + class UploadMode(enum.Enum): """ How to upload object to remote. 
""" diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py index 34dc2348..095c36ec 100644 --- a/modelscope/utils/cv/image_utils.py +++ b/modelscope/utils/cv/image_utils.py @@ -91,6 +91,71 @@ def draw_keypoints(output, original_image): return image +def draw_106face_keypoints(in_path, + keypoints, + boxes, + scale=4.0, + save_path=None): + face_contour_point_index = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 + ] + left_eye_brow_point_index = [33, 34, 35, 36, 37, 38, 39, 40, 41, 33] + right_eye_brow_point_index = [42, 43, 44, 45, 46, 47, 48, 49, 50, 42] + left_eye_point_index = [66, 67, 68, 69, 70, 71, 72, 73, 66] + right_eye_point_index = [75, 76, 77, 78, 79, 80, 81, 82, 75] + nose_bridge_point_index = [51, 52, 53, 54] + nose_contour_point_index = [55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65] + mouth_outer_point_index = [ + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 84 + ] + mouth_inter_point_index = [96, 97, 98, 99, 100, 101, 102, 103, 96] + + img = cv2.imread(in_path) + + for i in range(len(boxes)): + draw_box(img, np.array(boxes[i])) + + image = cv2.resize(img, dsize=None, fx=scale, fy=scale) + + def draw_line(point_index, image, point): + for i in range(len(point_index) - 1): + cur_index = point_index[i] + next_index = point_index[i + 1] + cur_pt = (int(point[cur_index][0] * scale), + int(point[cur_index][1] * scale)) + next_pt = (int(point[next_index][0] * scale), + int(point[next_index][1] * scale)) + cv2.line(image, cur_pt, next_pt, (0, 0, 255), thickness=2) + + for i in range(len(keypoints)): + points = keypoints[i] + + draw_line(face_contour_point_index, image, points) + draw_line(left_eye_brow_point_index, image, points) + draw_line(right_eye_brow_point_index, image, points) + draw_line(left_eye_point_index, image, points) + draw_line(right_eye_point_index, image, points) + draw_line(nose_bridge_point_index, image, points) + draw_line(nose_contour_point_index, image, points) + draw_line(mouth_outer_point_index, image, points) + draw_line(mouth_inter_point_index, image, points) + + size = len(points) + for i in range(size): + x = int(points[i][0]) + y = int(points[i][1]) + cv2.putText(image, str(i), (int(x * scale), int(y * scale)), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) + cv2.circle(image, (int(x * scale), int(y * scale)), 2, (0, 255, 0), + cv2.FILLED) + + if save_path is not None: + cv2.imwrite(save_path, image) + + return image + + def draw_face_detection_no_lm_result(img_path, detection_result): bboxes = np.array(detection_result[OutputKeys.BOXES]) scores = np.array(detection_result[OutputKeys.SCORES]) diff --git a/requirements/framework.txt b/requirements/framework.txt index 2408cda6..a86c0cc5 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -1,6 +1,7 @@ addict attrs -datasets +# version beyond 2.5.2 introduces compatbility issue and is being resolved +datasets<=2.5.2 easydict einops filelock>=3.3.0 diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt index 578f0b54..31e9601d 100644 --- a/requirements/multi-modal.txt +++ b/requirements/multi-modal.txt @@ -2,6 +2,8 @@ ftfy>=6.0.3 ofa>=0.0.2 pycocoevalcap>=1.2 pycocotools>=2.0.4 +# compatible with taming-transformers-rom1504 +pytorch_lightning<=1.7.7 # rough-score was just recently updated from 0.0.4 to 0.0.7 # which introduced compatability issues that are being investigated rouge_score<=0.0.4 diff --git a/requirements/science.txt 
b/requirements/science.txt index 72994f72..c30ff644 100644 --- a/requirements/science.txt +++ b/requirements/science.txt @@ -1,4 +1,6 @@ +biopython iopath +ipdb lmdb ml_collections scipy diff --git a/tests/msdatasets/test_dataset_upload.py b/tests/msdatasets/test_dataset_upload.py index 3d35d480..d91f24d7 100644 --- a/tests/msdatasets/test_dataset_upload.py +++ b/tests/msdatasets/test_dataset_upload.py @@ -8,7 +8,8 @@ import zipfile from modelscope.msdatasets import MsDataset from modelscope.msdatasets.utils.dataset_utils import list_dataset_objects from modelscope.utils import logger as logging -from modelscope.utils.constant import DEFAULT_DATASET_REVISION, ModelFile +from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DownloadMode, + ModelFile) from modelscope.utils.test_utils import test_level logger = logging.get_logger(__name__) @@ -104,7 +105,10 @@ class DatasetUploadTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_ds_download_dir(self): - test_ds = MsDataset.load(self.dataset_name, self.namespace) + test_ds = MsDataset.load( + self.dataset_name, + namespace=self.namespace, + download_mode=DownloadMode.FORCE_REDOWNLOAD) assert test_ds.config_kwargs['split_config'].values() @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') diff --git a/tests/outputs/test_model_outputs.py b/tests/outputs/test_model_outputs.py index 31271869..311ce201 100644 --- a/tests/outputs/test_model_outputs.py +++ b/tests/outputs/test_model_outputs.py @@ -21,9 +21,10 @@ class TestModelOutput(unittest.TestCase): self.assertEqual(outputs['logits'], torch.Tensor([1])) self.assertEqual(outputs[0], torch.Tensor([1])) self.assertEqual(outputs.logits, torch.Tensor([1])) + outputs.loss = torch.Tensor([2]) logits, loss = outputs self.assertEqual(logits, torch.Tensor([1])) - self.assertTrue(loss is None) + self.assertTrue(loss is not None) if __name__ == '__main__': diff --git a/tests/pipelines/test_face_2d_keypoints.py b/tests/pipelines/test_face_2d_keypoints.py index a5e347e8..7ccc8a59 100644 --- a/tests/pipelines/test_face_2d_keypoints.py +++ b/tests/pipelines/test_face_2d_keypoints.py @@ -1,11 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import unittest -import cv2 - from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_106face_keypoints from modelscope.utils.test_utils import test_level @@ -13,7 +12,7 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_face_2d_keypoints(self): - img_path = 'data/test/images/keypoints_detect/test_img_face_2d_keypoints.png' + img_path = 'data/test/images/face_detection.png' model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment' face_2d_keypoints_align = pipeline( @@ -21,15 +20,21 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): output = face_2d_keypoints_align(img_path) output_keypoints = output[OutputKeys.KEYPOINTS] - output_pose = output[OutputKeys.POSES] - - img = cv2.imread(img_path) - img = face_2d_keypoints_align.show_result( - img, output_keypoints, scale=2, save_path='face_keypoints.jpg') - - self.assertEqual(output_keypoints.shape[0], 106) - self.assertEqual(output_keypoints.shape[1], 2) - self.assertEqual(output_pose.shape[0], 3) + output_poses = output[OutputKeys.POSES] + output_boxes = output[OutputKeys.BOXES] + + draw_106face_keypoints( + img_path, + output_keypoints, + output_boxes, + scale=2, + save_path='face_keypoints.jpg') + + for idx in range(len(output_keypoints)): + self.assertEqual(output_keypoints[idx].shape[0], 106) + self.assertEqual(output_keypoints[idx].shape[1], 2) + self.assertEqual(output_poses[idx].shape[0], 3) + self.assertEqual(output_boxes[idx].shape[0], 4) if __name__ == '__main__': diff --git a/tests/pipelines/test_named_entity_recognition.py b/tests/pipelines/test_named_entity_recognition.py index 3658cf3f..aef4aaed 100644 --- a/tests/pipelines/test_named_entity_recognition.py +++ b/tests/pipelines/test_named_entity_recognition.py @@ -19,9 +19,11 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): self.task = Tasks.named_entity_recognition self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' + english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom' tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news' sentence = '这与温岭市新河镇的一个神秘的传说有关。' + sentence_en = 'pizza shovel' @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_tcrf_by_direct_model_download(self): @@ -89,6 +91,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): task=Tasks.named_entity_recognition, model=self.lcrf_model_id) print(pipeline_ins(input=self.sentence)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_english_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.english_model_id) + print(pipeline_ins(input='pizza shovel')) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): pipeline_ins = pipeline(task=Tasks.named_entity_recognition) diff --git a/tests/pipelines/test_unifold.py b/tests/pipelines/test_unifold.py index df35dc5e..47bb7874 100644 --- a/tests/pipelines/test_unifold.py +++ b/tests/pipelines/test_unifold.py @@ -19,7 +19,7 @@ class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck): self.protein_multimer = 
'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \ 'NIAALKNHIDKIKPIAMQIYKKYSKNIP' - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_by_direct_model_download(self): model_dir = snapshot_download(self.model_id) mono_pipeline_ins = pipeline(task=self.task, model=model_dir) diff --git a/tests/trainers/test_finetune_token_classificatin.py b/tests/trainers/test_finetune_token_classificatin.py index 9bdab9b7..a92cee7b 100644 --- a/tests/trainers/test_finetune_token_classificatin.py +++ b/tests/trainers/test_finetune_token_classificatin.py @@ -87,7 +87,7 @@ class TestFinetuneTokenClassification(unittest.TestCase): cfg['dataset'] = { 'train': { 'labels': label_enumerate_values, - 'first_sequence': 'first_sequence', + 'first_sequence': 'tokens', 'label': 'labels', } }
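Note on the new `use_fast` preprocessor argument: the NLP preprocessors now accept the flag explicitly and `NlpEpochBasedTrainer` passes `use_fast=True` when it builds the train/eval preprocessors. A minimal sketch of what forwarding such a flag to a Hugging Face tokenizer usually looks like; this is an assumption for illustration, not the library's actual `build_tokenizer` implementation, and the checkpoint path is a placeholder.

```python
# Minimal sketch (assumption): forwarding a `use_fast` flag to a
# Hugging Face tokenizer. Not the library's build_tokenizer code.
from transformers import AutoTokenizer


def build_tokenizer_sketch(model_dir: str, use_fast: bool = None):
    # An unset flag falls back to the slow tokenizer, mirroring the
    # "use_fast is None -> False" default in the preprocessor base class.
    return AutoTokenizer.from_pretrained(model_dir, use_fast=bool(use_fast))


# Hypothetical local checkpoint directory. Fast tokenizers expose
# word_ids() and offset mappings, which the token-classification
# preprocessor relies on below.
tokenizer = build_tokenizer_sketch('/path/to/model_dir', use_fast=True)
```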
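The label-alignment logic in `TokenClassificationPreprocessor` maps word-level labels onto sub-tokens via the fast tokenizer's `word_ids()`, keeping the label on the first sub-token of each word and masking the rest (and the special tokens) with -100. A standalone sketch of that rule, assuming a generic Hugging Face fast tokenizer and a toy `label2id` map:

```python
# Standalone sketch of the label-alignment rule (assumptions: a Hugging Face
# fast tokenizer and a toy label2id dict; -100 is the index ignored by
# torch's cross-entropy loss).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', use_fast=True)
label2id = {'O': 0, 'B-LOC': 1, 'I-LOC': 2}

tokens = ['New', 'York', 'is', 'big']
labels = ['B-LOC', 'I-LOC', 'O', 'O']

encodings = tokenizer(tokens, is_split_into_words=True)
label_ids, previous_word_idx = [], None
for word_idx in encodings.word_ids():
    if word_idx is None:                 # special tokens such as [CLS]/[SEP]
        label_ids.append(-100)
    elif word_idx != previous_word_idx:  # first sub-token of each word
        label_ids.append(label2id[labels[word_idx]])
    else:                                # trailing sub-tokens are ignored
        label_ids.append(-100)           # (or mapped B- -> I- when label_all_tokens is set)
    previous_word_idx = word_idx

print(label_ids)  # e.g. [-100, 1, 2, 0, 0, -100], depending on tokenization
```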
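The OFA criterion now reads optional hyper-parameters with `args.get(key, default)`, so configurations that omit a field fall back to the defaults shown in the diff instead of raising an attribute error. Roughly, with `cfg` standing in for the dict/attribute-style config object the trainer passes:

```python
# Illustration of the defensive config access; `cfg` is a stand-in for the
# config object handed to AdjustLabelSmoothedCrossEntropyCriterion.
cfg = {'label_smoothing': 0.1, 'use_rdrop': False}

sentence_avg = cfg.get('sentence_avg', False)          # missing -> default False
eps = cfg.get('label_smoothing', 0.1)                  # present  -> 0.1
constraint_range = cfg.get('constraint_range', None)   # missing -> None
if constraint_range:
    # Only parsed when a 'start,end' string is actually configured.
    constraint_start, constraint_end = (int(x) for x in constraint_range.split(','))
```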
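Both trainer changes replace `self.model` with `self.model.module` when running distributed. The reason is that `DistributedDataParallel` wraps the user model and keeps the original network under `.module`, so evaluation code that calls model-specific methods needs the inner object. A generic PyTorch sketch of the pattern; the trainers themselves gate on their own `self._dist` flag rather than an isinstance check:

```python
# Generic PyTorch sketch of the DDP unwrapping pattern; not ModelScope code.
import torch
from torch.nn.parallel import DistributedDataParallel


def unwrap_module(model: torch.nn.Module) -> torch.nn.Module:
    # DDP keeps the original network under `.module`; evaluation code that
    # calls model-specific methods (e.g. generation helpers) needs it.
    return model.module if isinstance(model, DistributedDataParallel) else model
```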
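Finally, a condensed usage sketch for the new `draw_106face_keypoints` helper, following `tests/pipelines/test_face_2d_keypoints.py`: the pipeline returns per-face keypoints (106x2), poses (3) and boxes (4), and the helper writes an annotated image. The pipeline construction arguments are assumed from the usual `pipeline(task, model=...)` form, since that part of the test is unchanged context not shown in the diff.

```python
# Usage sketch condensed from tests/pipelines/test_face_2d_keypoints.py;
# pipeline construction arguments are assumed.
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_106face_keypoints

img_path = 'data/test/images/face_detection.png'
face_2d_keypoints = pipeline(
    Tasks.face_2d_keypoints,
    model='damo/cv_mobilenet_face-2d-keypoints_alignment')
output = face_2d_keypoints(img_path)

# One entry per detected face: 106x2 keypoints, 3 pose angles, 4 box coords.
draw_106face_keypoints(
    img_path,
    output[OutputKeys.KEYPOINTS],
    output[OutputKeys.BOXES],
    scale=2,
    save_path='face_keypoints.jpg')
```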