| @@ -1,10 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os.path as osp | |||||
| from typing import List, Union | from typing import List, Union | ||||
| from attr import has | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.utils.config import Config, ConfigDict | from modelscope.utils.config import Config, ConfigDict | ||||
| @@ -44,7 +41,7 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/cv_unet_person-image-cartoon_compound-models'), | 'damo/cv_unet_person-image-cartoon_compound-models'), | ||||
| Tasks.ocr_detection: (Pipelines.ocr_detection, | Tasks.ocr_detection: (Pipelines.ocr_detection, | ||||
| 'damo/cv_resnet18_ocr-detection-line-level_damo'), | 'damo/cv_resnet18_ocr-detection-line-level_damo'), | ||||
| Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask_large'), | |||||
| Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'), | |||||
| Tasks.action_recognition: (Pipelines.action_recognition, | Tasks.action_recognition: (Pipelines.action_recognition, | ||||
| 'damo/cv_TAdaConv_action-recognition'), | 'damo/cv_TAdaConv_action-recognition'), | ||||
| } | } | ||||
| @@ -313,7 +313,6 @@ class TextGenerationPreprocessor(Preprocessor): | |||||
| rst['input_ids'].append(feature['input_ids']) | rst['input_ids'].append(feature['input_ids']) | ||||
| rst['attention_mask'].append(feature['attention_mask']) | rst['attention_mask'].append(feature['attention_mask']) | ||||
| # rst['token_type_ids'].append(feature['token_type_ids']) | |||||
| return {k: torch.tensor(v) for k, v in rst.items()} | return {k: torch.tensor(v) for k, v in rst.items()} | ||||
| @@ -1,6 +1,4 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | |||||
| import shutil | |||||
| import unittest | import unittest | ||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| @@ -14,10 +12,10 @@ from modelscope.utils.test_utils import test_level | |||||
| class FillMaskTest(unittest.TestCase): | class FillMaskTest(unittest.TestCase): | ||||
| model_id_sbert = { | model_id_sbert = { | ||||
| 'zh': 'damo/nlp_structbert_fill-mask-chinese_large', | |||||
| 'en': 'damo/nlp_structbert_fill-mask-english_large' | |||||
| 'zh': 'damo/nlp_structbert_fill-mask_chinese-large', | |||||
| 'en': 'damo/nlp_structbert_fill-mask_english-large' | |||||
| } | } | ||||
| model_id_veco = 'damo/nlp_veco_fill-mask_large' | |||||
| model_id_veco = 'damo/nlp_veco_fill-mask-large' | |||||
| ori_texts = { | ori_texts = { | ||||
| 'zh': | 'zh': | ||||