@@ -128,7 +128,7 @@ class TorchModelExporter(Exporter):
             args_list = list(args)
         else:
             args_list = [args]
-        if isinstance(args_list[-1], dict):
+        if isinstance(args_list[-1], Mapping):
             args_dict = args_list[-1]
             args_list = args_list[:-1]
         n_nonkeyword = len(args_list)
@@ -284,9 +284,8 @@ class TorchModelExporter(Exporter):
                 'Model property dummy_inputs must be set.')
         dummy_inputs = collate_fn(dummy_inputs, device)
         if isinstance(dummy_inputs, Mapping):
-            dummy_inputs = self._decide_input_format(model, dummy_inputs)
             dummy_inputs_filter = []
-            for _input in dummy_inputs:
+            for _input in self._decide_input_format(model, dummy_inputs):
                 if _input is not None:
                     dummy_inputs_filter.append(_input)
                 else:
@@ -23,7 +23,8 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
                                       API_RESPONSE_FIELD_MESSAGE,
                                       API_RESPONSE_FIELD_USERNAME,
                                       DEFAULT_CREDENTIALS_PATH,
-                                      MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS,
+                                      MODELSCOPE_ENVIRONMENT,
+                                      MODELSCOPE_USERNAME, ONE_YEAR_SECONDS,
                                       Licenses, ModelVisibility)
 from modelscope.hub.errors import (InvalidParameter, NotExistError,
                                    NotLoginException, NoValidRevisionError,
@@ -38,8 +39,8 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
                                        DEFAULT_MODEL_REVISION,
                                        DEFAULT_REPOSITORY_REVISION,
                                        MASTER_MODEL_BRANCH, DatasetFormations,
-                                       DatasetMetaFormats, DownloadMode,
-                                       ModelFile)
+                                       DatasetMetaFormats, DownloadChannel,
+                                       DownloadMode, ModelFile)
 from modelscope.utils.logger import get_logger
 from .utils.utils import (get_endpoint, get_release_datetime,
                           model_id_to_group_owner_name)
@@ -645,6 +646,25 @@ class HubApi:
     def check_local_cookies(self, use_cookies) -> CookieJar:
         return self._check_cookie(use_cookies=use_cookies)

+    def dataset_download_uv(self, dataset_name: str, namespace: str):
+        if not dataset_name or not namespace:
+            raise ValueError('dataset_name or namespace cannot be empty!')
+
+        # get channel and user_name
+        channel = DownloadChannel.LOCAL.value
+        user_name = ''
+        if MODELSCOPE_ENVIRONMENT in os.environ:
+            channel = os.environ[MODELSCOPE_ENVIRONMENT]
+        if MODELSCOPE_USERNAME in os.environ:
+            user_name = os.environ[MODELSCOPE_USERNAME]
+
+        url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/{channel}?user={user_name}'
+        cookies = ModelScopeConfig.get_cookies()
+        r = requests.post(url, cookies=cookies, headers=self.headers)
+        resp = r.json()
+        raise_on_error(resp)
+        return resp['Message']
+

 class ModelScopeConfig:
     path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
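The new `dataset_download_uv` call is wired into `MsDataset.load` in a later hunk; for reference, a minimal sketch of invoking it directly (illustrative only, assuming network access to the default endpoint and reusing the `ocr_fudanvi_zh` dataset that the OFA trainer test below relies on):

```python
# Illustrative usage only, not part of the patch.
from modelscope.hub.api import HubApi

api = HubApi()
# Records one download "unique visitor" event for the dataset and returns
# the server's message field; raise_on_error handles failure responses.
msg = api.dataset_download_uv(dataset_name='ocr_fudanvi_zh', namespace='modelscope')
print(msg)
```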
@@ -760,14 +780,18 @@ class ModelScopeConfig:
         env = 'custom'
         if MODELSCOPE_ENVIRONMENT in os.environ:
             env = os.environ[MODELSCOPE_ENVIRONMENT]
+        user_name = 'unknown'
+        if MODELSCOPE_USERNAME in os.environ:
+            user_name = os.environ[MODELSCOPE_USERNAME]

-        ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % (
+        ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
             __version__,
             platform.python_version(),
             ModelScopeConfig.get_user_session_id(),
             platform.platform(),
             platform.processor(),
             env,
+            user_name,
         )
         if isinstance(user_agent, dict):
             ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items())
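The effect of the two environment variables on the reported user agent, as a sketch (it assumes the surrounding method is `ModelScopeConfig.get_user_agent`, which is not shown in full in this hunk):

```python
# Illustrative only: how MODELSCOPE_ENVIRONMENT / MODELSCOPE_USERNAME surface in the UA.
import os

os.environ['MODELSCOPE_ENVIRONMENT'] = 'dsw'
os.environ['MODELSCOPE_USERNAME'] = 'alice'

from modelscope.hub.api import ModelScopeConfig

# Expected to end with '...; env/dsw; user/alice'
print(ModelScopeConfig.get_user_agent())
```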
@@ -18,6 +18,7 @@ API_RESPONSE_FIELD_EMAIL = 'Email'
 API_RESPONSE_FIELD_MESSAGE = 'Message'
 MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
 MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
+MODELSCOPE_USERNAME = 'MODELSCOPE_USERNAME'
 ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
@@ -349,11 +349,13 @@ class CLIP(nn.Module):
                  text_num_hidden_layers: int,
                  text_type_vocab_size: int,
                  tokenizer: FullTokenizer,
+                 # vision_head_width, added this param for ViT-H
+                 vision_head_width: int = 64,
                  ):
         super().__init__()

         if isinstance(vision_layers, (tuple, list)):
-            vision_heads = vision_width * 32 // 64
+            vision_heads = vision_width * 32 // vision_head_width
             self.visual = ModifiedResNet(
                 layers=vision_layers,
                 output_dim=embed_dim,
@@ -361,7 +363,7 @@ class CLIP(nn.Module):
                 input_resolution=image_resolution,
                 width=vision_width)
         else:
-            vision_heads = vision_width // 64
+            vision_heads = vision_width // vision_head_width
             self.visual = VisualTransformer(
                 input_resolution=image_resolution,
                 patch_size=vision_patch_size,
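The arithmetic the new parameter unlocks, as a sketch (the ViT-H numbers are illustrative assumptions about the intended configuration, not taken from this patch):

```python
# Head-count computation before and after the change (illustrative only).
vision_width = 768
assert vision_width // 64 == 12               # old hard-coded head width: 12 heads

vision_width, vision_head_width = 1280, 80    # a ViT-H-style tower wants 80-wide heads
assert vision_width // vision_head_width == 16   # now expressible: 16 heads
```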
@@ -0,0 +1,3 @@
+# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
+# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
+"""Unifold Modules."""
@@ -274,6 +274,8 @@ class MsDataset:
             try:
                 api.on_dataset_download(
                     dataset_name=download_dataset, namespace=namespace)
+                api.dataset_download_uv(
+                    dataset_name=download_dataset, namespace=namespace)
             except Exception as e:
                 logger.error(e)
@@ -491,17 +491,8 @@ TASK_OUTPUTS = {
     # word segmentation result for single sample
     # {
     #     "output": "今天 天气 不错 , 适合 出去 游玩"
-    #     "labels": [
-    #         {'word': '今天', 'label': 'PROPN'},
-    #         {'word': '天气', 'label': 'PROPN'},
-    #         {'word': '不错', 'label': 'VERB'},
-    #         {'word': ',', 'label': 'NUM'},
-    #         {'word': '适合', 'label': 'NOUN'},
-    #         {'word': '出去', 'label': 'PART'},
-    #         {'word': '游玩', 'label': 'ADV'},
-    #     ]
     # }
-    Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS],
+    Tasks.word_segmentation: [OutputKeys.OUTPUT],

     # TODO @wenmeng.zwm support list of result check
     # named entity recognition result for single sample
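With `OutputKeys.LABELS` dropped from the expected schema, a word-segmentation pipeline now returns only the segmented string; a quick sketch of the new shape (default model resolution is assumed):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(task=Tasks.word_segmentation)
print(p('今天天气不错,适合出去游玩'))
# expected form: {'output': '今天 天气 不错 , 适合 出去 游玩'}
```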
@@ -93,9 +93,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                           'damo/cv_resnet50_live-category'),
     Tasks.video_category: (Pipelines.video_category,
                            'damo/cv_resnet50_video-category'),
-    Tasks.multi_modal_embedding:
-    (Pipelines.multi_modal_embedding,
-     'damo/multi-modal_clip-vit-large-patch14_zh'),
+    Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding,
+                                  'damo/multi-modal_clip-vit-base-patch16_zh'),
     Tasks.generative_multi_modal_embedding:
     (Pipelines.generative_multi_modal_embedding,
      'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
@@ -109,13 +109,13 @@ class TokenClassificationPipeline(Pipeline):
                 chunk['span'] = text[chunk['start']:chunk['end']]
                 chunks.append(chunk)

-        # for cws output
+        # for cws outputs
         if len(chunks) > 0 and chunks[0]['type'] == 'cws':
             spans = [
                 chunk['span'] for chunk in chunks if chunk['span'].strip()
             ]
             seg_result = ' '.join(spans)
-            outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
+            outputs = {OutputKeys.OUTPUT: seg_result}

         # for ner outputs
         else:
@@ -115,15 +115,15 @@ class WordSegmentationPipeline(Pipeline):
                 chunk['span'] = text[chunk['start']:chunk['end']]
                 chunks.append(chunk)

-        # for cws output
+        # for cws outputs
         if len(chunks) > 0 and chunks[0]['type'] == 'cws':
             spans = [
                 chunk['span'] for chunk in chunks if chunk['span'].strip()
             ]
             seg_result = ' '.join(spans)
-            outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
+            outputs = {OutputKeys.OUTPUT: seg_result}

-        # for ner outpus
+        # for ner output
         else:
             outputs = {OutputKeys.OUTPUT: chunks}
         return outputs
@@ -96,7 +96,6 @@ class OfaPreprocessor(Preprocessor):
             data = input
         else:
             data = self._build_dict(input)
-        data = self._ofa_input_compatibility_conversion(data)
         sample = self.preprocess(data)
         str_data = dict()
         for k, v in data.items():
@@ -34,6 +34,7 @@ class NLPBasePreprocessor(Preprocessor, ABC):
                  label=None,
                  label2id=None,
                  mode=ModeKeys.INFERENCE,
+                 use_fast=None,
                  **kwargs):
         """The NLP preprocessor base class.

@@ -45,14 +46,18 @@ class NLPBasePreprocessor(Preprocessor, ABC):
             label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping
                 if this mapping is not supplied.
             mode: Run this preprocessor in either 'train'/'eval'/'inference' mode
+            use_fast: use the fast version of tokenizer
         """
         self.model_dir = model_dir
         self.first_sequence = first_sequence
         self.second_sequence = second_sequence
         self.label = label
-        self.use_fast = kwargs.pop('use_fast', None)
-        if self.use_fast is None and os.path.isfile(
+        self.use_fast = use_fast
+        if self.use_fast is None and model_dir is None:
+            self.use_fast = False
+        elif self.use_fast is None and os.path.isfile(
                 os.path.join(model_dir, 'tokenizer_config.json')):
             with open(os.path.join(model_dir, 'tokenizer_config.json'),
                       'r') as f:
@@ -61,8 +66,8 @@ class NLPBasePreprocessor(Preprocessor, ABC):
         self.use_fast = False if self.use_fast is None else self.use_fast

         self.label2id = label2id
-        if self.label2id is None:
-            self.label2id = parse_label_mapping(self.model_dir)
+        if self.label2id is None and model_dir is not None:
+            self.label2id = parse_label_mapping(model_dir)
         super().__init__(mode, **kwargs)

     @property
@@ -106,6 +111,7 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor):
                  label: str = 'label',
                  label2id: dict = None,
                  mode: str = ModeKeys.INFERENCE,
+                 use_fast: bool = None,
                  **kwargs):
         """The NLP tokenizer preprocessor base class.

@@ -122,11 +128,12 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor):
                 - config.json label2id/id2label
                 - label_mapping.json
             mode: Run this preprocessor in either 'train'/'eval'/'inference' mode, the behavior may be different.
+            use_fast: use the fast version of tokenizer
             kwargs: These kwargs will be directly fed into the tokenizer.
         """
         super().__init__(model_dir, first_sequence, second_sequence, label,
-                         label2id, mode)
+                         label2id, mode, use_fast, **kwargs)
         self.model_dir = model_dir
         self.tokenize_kwargs = kwargs
         self.tokenizer = self.build_tokenizer(model_dir)
@@ -2,6 +2,7 @@
 from typing import Any, Dict, Tuple, Union

+import numpy as np
 import torch

 from modelscope.metainfo import Preprocessors
@@ -20,9 +21,7 @@ class WordSegmentationBlankSetToLabelPreprocessor(NLPBasePreprocessor):
     """

     def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.first_sequence: str = kwargs.pop('first_sequence',
-                                              'first_sequence')
+        self.first_sequence: str = kwargs.pop('first_sequence', 'tokens')
         self.label = kwargs.pop('label', OutputKeys.LABELS)

     def __call__(self, data: str) -> Union[Dict[str, Any], Tuple]:
@@ -80,10 +79,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
             'is_split_into_words', False)
         if 'label2id' in kwargs:
             kwargs.pop('label2id')
-        self.tokenize_kwargs = kwargs

-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
+    @type_assert(object, (str, dict))
+    def __call__(self, data: Union[dict, str]) -> Dict[str, Any]:
         """process the raw input data

         Args:
@@ -99,18 +97,24 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         text = None
         labels_list = None
         if isinstance(data, str):
+            # for inference inputs without label
             text = data
+            self.tokenize_kwargs['add_special_tokens'] = False
         elif isinstance(data, dict):
+            # for finetune inputs with label
             text = data.get(self.first_sequence)
             labels_list = data.get(self.label)
+            if isinstance(text, list):
+                self.tokenize_kwargs['is_split_into_words'] = True

         input_ids = []
         label_mask = []
         offset_mapping = []
-        if self.is_split_into_words:
-            for offset, token in enumerate(list(data)):
-                subtoken_ids = self.tokenizer.encode(
-                    token, add_special_tokens=False)
+        token_type_ids = []
+        if self.is_split_into_words and self._mode == ModeKeys.INFERENCE:
+            for offset, token in enumerate(list(text)):
+                subtoken_ids = self.tokenizer.encode(token,
+                                                     **self.tokenize_kwargs)
                 if len(subtoken_ids) == 0:
                     subtoken_ids = [self.tokenizer.unk_token_id]
                 input_ids.extend(subtoken_ids)
@@ -119,10 +123,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         else:
             if self.tokenizer.is_fast:
                 encodings = self.tokenizer(
-                    text,
-                    add_special_tokens=False,
-                    return_offsets_mapping=True,
-                    **self.tokenize_kwargs)
+                    text, return_offsets_mapping=True, **self.tokenize_kwargs)
+                attention_mask = encodings['attention_mask']
+                token_type_ids = encodings['token_type_ids']
                 input_ids = encodings['input_ids']
                 word_ids = encodings.word_ids()
                 for i in range(len(word_ids)):
@@ -137,75 +140,85 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
                         label_mask.append(1)
                         offset_mapping.append(encodings['offset_mapping'][i])
             else:
-                encodings = self.tokenizer(
-                    text, add_special_tokens=False, **self.tokenize_kwargs)
+                encodings = self.tokenizer(text, **self.tokenize_kwargs)
                 input_ids = encodings['input_ids']
                 label_mask, offset_mapping = self.get_label_mask_and_offset_mapping(
                     text)

-        if len(input_ids) >= self.sequence_length - 2:
-            input_ids = input_ids[:self.sequence_length - 2]
-            label_mask = label_mask[:self.sequence_length - 2]
-        input_ids = [self.tokenizer.cls_token_id
-                     ] + input_ids + [self.tokenizer.sep_token_id]
-        label_mask = [0] + label_mask + [0]
-        attention_mask = [1] * len(input_ids)
-        offset_mapping = offset_mapping[:sum(label_mask)]
-
-        if not self.is_transformer_based_model:
-            input_ids = input_ids[1:-1]
-            attention_mask = attention_mask[1:-1]
-            label_mask = label_mask[1:-1]
-
         if self._mode == ModeKeys.INFERENCE:
+            if len(input_ids) >= self.sequence_length - 2:
+                input_ids = input_ids[:self.sequence_length - 2]
+                label_mask = label_mask[:self.sequence_length - 2]
+            input_ids = [self.tokenizer.cls_token_id
+                         ] + input_ids + [self.tokenizer.sep_token_id]
+            label_mask = [0] + label_mask + [0]
+            attention_mask = [1] * len(input_ids)
+            offset_mapping = offset_mapping[:sum(label_mask)]
+
+            if not self.is_transformer_based_model:
+                input_ids = input_ids[1:-1]
+                attention_mask = attention_mask[1:-1]
+                label_mask = label_mask[1:-1]
+
             input_ids = torch.tensor(input_ids).unsqueeze(0)
             attention_mask = torch.tensor(attention_mask).unsqueeze(0)
             label_mask = torch.tensor(
                 label_mask, dtype=torch.bool).unsqueeze(0)

-        # the token classification
-        output = {
-            'text': text,
-            'input_ids': input_ids,
-            'attention_mask': attention_mask,
-            'label_mask': label_mask,
-            'offset_mapping': offset_mapping
-        }
-
-        # align the labels with tokenized text
-        if labels_list is not None:
-            assert self.label2id is not None
-            # Map that sends B-Xxx label to its I-Xxx counterpart
-            b_to_i_label = []
-            label_enumerate_values = [
-                k for k, v in sorted(
-                    self.label2id.items(), key=lambda item: item[1])
-            ]
-            for idx, label in enumerate(label_enumerate_values):
-                if label.startswith('B-') and label.replace(
-                        'B-', 'I-') in label_enumerate_values:
-                    b_to_i_label.append(
-                        label_enumerate_values.index(
-                            label.replace('B-', 'I-')))
-                else:
-                    b_to_i_label.append(idx)
-
-            label_row = [self.label2id[lb] for lb in labels_list]
-            previous_word_idx = None
-            label_ids = []
-            for word_idx in word_ids:
-                if word_idx is None:
-                    label_ids.append(-100)
-                elif word_idx != previous_word_idx:
-                    label_ids.append(label_row[word_idx])
-                else:
-                    if self.label_all_tokens:
-                        label_ids.append(b_to_i_label[label_row[word_idx]])
-                    else:
-                        label_ids.append(-100)
-                previous_word_idx = word_idx
-            labels = label_ids
-            output['labels'] = labels
+            # the token classification
+            output = {
+                'text': text,
+                'input_ids': input_ids,
+                'attention_mask': attention_mask,
+                'label_mask': label_mask,
+                'offset_mapping': offset_mapping
+            }
+        else:
+            output = {
+                'input_ids': input_ids,
+                'token_type_ids': token_type_ids,
+                'attention_mask': attention_mask,
+                'label_mask': label_mask,
+            }
+
+            # align the labels with tokenized text
+            if labels_list is not None:
+                assert self.label2id is not None
+                # Map that sends B-Xxx label to its I-Xxx counterpart
+                b_to_i_label = []
+                label_enumerate_values = [
+                    k for k, v in sorted(
+                        self.label2id.items(), key=lambda item: item[1])
+                ]
+                for idx, label in enumerate(label_enumerate_values):
+                    if label.startswith('B-') and label.replace(
+                            'B-', 'I-') in label_enumerate_values:
+                        b_to_i_label.append(
+                            label_enumerate_values.index(
+                                label.replace('B-', 'I-')))
+                    else:
+                        b_to_i_label.append(idx)
+
+                label_row = [self.label2id[lb] for lb in labels_list]
+                previous_word_idx = None
+                label_ids = []
+                for word_idx in word_ids:
+                    if word_idx is None:
+                        label_ids.append(-100)
+                    elif word_idx != previous_word_idx:
+                        label_ids.append(label_row[word_idx])
+                    else:
+                        if self.label_all_tokens:
+                            label_ids.append(b_to_i_label[label_row[word_idx]])
+                        else:
+                            label_ids.append(-100)
+                    previous_word_idx = word_idx
+                labels = label_ids
+                output['labels'] = labels
+
+        output = {
+            k: np.array(v) if isinstance(v, list) else v
+            for k, v in output.items()
+        }
         return output

     def get_tokenizer_class(self):
@@ -2,12 +2,12 @@
 from typing import Any, Dict

 import torch
-from PIL import Image
+import unicodedata2
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms import functional as F
+from zhconv import convert

-from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import ModeKeys

 from .base import OfaBasePreprocessor
@@ -98,8 +98,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        target = data[self.column_map['text']]
-        target = target.translate(self.transtab).strip()
+        target = sample['label']
         target_token_list = target.strip().split()
         target = ' '.join(target_token_list[:self.max_tgt_length])
         sample['target'] = self.tokenize_text(target, add_bos=False)
@@ -119,5 +118,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True])
         }
         if 'text' in self.column_map and self.column_map['text'] in data:
-            sample['label'] = data[self.column_map['text']]
+            target = data[self.column_map['text']]
+            target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
+            sample['label'] = target
         return sample
@@ -18,7 +18,7 @@ class TextGenerationTrainer(NlpEpochBasedTrainer):
         return tokenizer.decode(tokens.tolist(), skip_special_tokens=True)

     def evaluation_step(self, data):
-        model = self.model
+        model = self.model.module if self._dist else self.model
         model.eval()

         with torch.no_grad():
@@ -586,14 +586,16 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
             preprocessor_mode=ModeKeys.TRAIN,
             **model_args,
             **self.train_keys,
-            mode=ModeKeys.TRAIN)
+            mode=ModeKeys.TRAIN,
+            use_fast=True)
         eval_preprocessor = Preprocessor.from_pretrained(
             self.model_dir,
             cfg_dict=self.cfg,
             preprocessor_mode=ModeKeys.EVAL,
             **model_args,
             **self.eval_keys,
-            mode=ModeKeys.EVAL)
+            mode=ModeKeys.EVAL,
+            use_fast=True)
         return train_preprocessor, eval_preprocessor
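Outside the trainer, the same switch can be passed when building a preprocessor by hand; a minimal sketch (the model id and import path are placeholders/assumptions, and `use_fast` simply flows through the new keyword added to the NLP preprocessor bases above):

```python
from modelscope.preprocessors import Preprocessor

# Request a fast (Rust-backed) tokenizer explicitly instead of relying on the
# tokenizer_config.json auto-detection in NLPBasePreprocessor.
preprocessor = Preprocessor.from_pretrained('damo/some_nlp_model', use_fast=True)
```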
@@ -876,7 +876,7 @@ class EpochBasedTrainer(BaseTrainer):
         Subclass and override to inject custom behavior.
         """
-        model = self.model
+        model = self.model.module if self._dist else self.model
         model.eval()

         if is_parallel(model):
@@ -238,6 +238,14 @@ class DownloadMode(enum.Enum):
     FORCE_REDOWNLOAD = 'force_redownload'


+class DownloadChannel(enum.Enum):
+    """ Channels of datasets downloading for uv/pv counting.
+    """
+    LOCAL = 'local'
+    DSW = 'dsw'
+    EAIS = 'eais'
+
+
 class UploadMode(enum.Enum):
     """ How to upload object to remote.
     """
@@ -1,6 +1,7 @@
 addict
 attrs
-datasets
+# version beyond 2.5.2 introduces compatibility issue and is being resolved
+datasets<=2.5.2
 easydict
 einops
 filelock>=3.3.0
@@ -11,3 +11,5 @@ timm
 tokenizers
 torchvision
 transformers>=4.12.0
+unicodedata2
+zhconv
@@ -1,4 +1,6 @@
+biopython
 iopath
+ipdb
 lmdb
 ml_collections
 scipy
@@ -8,7 +8,8 @@ import zipfile
 from modelscope.msdatasets import MsDataset
 from modelscope.msdatasets.utils.dataset_utils import list_dataset_objects
 from modelscope.utils import logger as logging
-from modelscope.utils.constant import DEFAULT_DATASET_REVISION, ModelFile
+from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DownloadMode,
+                                       ModelFile)
 from modelscope.utils.test_utils import test_level

 logger = logging.get_logger(__name__)
@@ -104,7 +105,10 @@ class DatasetUploadTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_ds_download_dir(self):
-        test_ds = MsDataset.load(self.dataset_name, self.namespace)
+        test_ds = MsDataset.load(
+            self.dataset_name,
+            namespace=self.namespace,
+            download_mode=DownloadMode.FORCE_REDOWNLOAD)
         assert test_ds.config_kwargs['split_config'].values()

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -21,9 +21,10 @@ class TestModelOutput(unittest.TestCase):
         self.assertEqual(outputs['logits'], torch.Tensor([1]))
         self.assertEqual(outputs[0], torch.Tensor([1]))
         self.assertEqual(outputs.logits, torch.Tensor([1]))
+        outputs.loss = torch.Tensor([2])
         logits, loss = outputs
         self.assertEqual(logits, torch.Tensor([1]))
-        self.assertTrue(loss is None)
+        self.assertTrue(loss is not None)


 if __name__ == '__main__':
@@ -19,9 +19,11 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
         self.task = Tasks.named_entity_recognition
         self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'

+    english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom'
     tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
     lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
     sentence = '这与温岭市新河镇的一个神秘的传说有关。'
+    sentence_en = 'pizza shovel'

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_tcrf_by_direct_model_download(self):
@@ -89,6 +91,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
             task=Tasks.named_entity_recognition, model=self.lcrf_model_id)
         print(pipeline_ins(input=self.sentence))

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_english_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition, model=self.english_model_id)
+        print(pipeline_ins(input='pizza shovel'))
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_default_model(self):
         pipeline_ins = pipeline(task=Tasks.named_entity_recognition)
@@ -19,7 +19,7 @@ class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck):
         self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \
                                 'NIAALKNHIDKIKPIAMQIYKKYSKNIP'

-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_by_direct_model_download(self):
         model_dir = snapshot_download(self.model_id)
         mono_pipeline_ins = pipeline(task=self.task, model=model_dir)
@@ -87,7 +87,7 @@ class TestFinetuneTokenClassification(unittest.TestCase):
         cfg['dataset'] = {
             'train': {
                 'labels': label_enumerate_values,
-                'first_sequence': 'first_sequence',
+                'first_sequence': 'tokens',
                 'label': 'labels',
             }
         }
@@ -85,7 +85,7 @@ class TestOfaTrainer(unittest.TestCase):
                 'ocr_fudanvi_zh',
                 subset_name='scene',
                 namespace='modelscope',
-                split='train[:200]',
+                split='train[800:900]',
                 download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             eval_dataset=MsDataset.load(
                 'ocr_fudanvi_zh',