Browse Source

merge with fill mask

master
智丞 3 years ago
parent
commit
37901ad696
3 changed files with 4 additions and 10 deletions
  1. +1
    -4
      modelscope/pipelines/builder.py
  2. +0
    -1
      modelscope/preprocessors/nlp.py
  3. +3
    -5
      tests/pipelines/test_fill_mask.py

+ 1
- 4
modelscope/pipelines/builder.py View File

@@ -1,10 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.


import os.path as osp
from typing import List, Union from typing import List, Union


from attr import has

from modelscope.metainfo import Pipelines from modelscope.metainfo import Pipelines
from modelscope.models.base import Model from modelscope.models.base import Model
from modelscope.utils.config import Config, ConfigDict from modelscope.utils.config import Config, ConfigDict
@@ -44,7 +41,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_unet_person-image-cartoon_compound-models'), 'damo/cv_unet_person-image-cartoon_compound-models'),
Tasks.ocr_detection: (Pipelines.ocr_detection, Tasks.ocr_detection: (Pipelines.ocr_detection,
'damo/cv_resnet18_ocr-detection-line-level_damo'), 'damo/cv_resnet18_ocr-detection-line-level_damo'),
Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask_large'),
Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
Tasks.action_recognition: (Pipelines.action_recognition, Tasks.action_recognition: (Pipelines.action_recognition,
'damo/cv_TAdaConv_action-recognition'), 'damo/cv_TAdaConv_action-recognition'),
} }


+ 0
- 1
modelscope/preprocessors/nlp.py View File

@@ -313,7 +313,6 @@ class TextGenerationPreprocessor(Preprocessor):


rst['input_ids'].append(feature['input_ids']) rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask']) rst['attention_mask'].append(feature['attention_mask'])
# rst['token_type_ids'].append(feature['token_type_ids'])
return {k: torch.tensor(v) for k, v in rst.items()} return {k: torch.tensor(v) for k, v in rst.items()}






+ 3
- 5
tests/pipelines/test_fill_mask.py View File

@@ -1,6 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest import unittest


from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
@@ -14,10 +12,10 @@ from modelscope.utils.test_utils import test_level


class FillMaskTest(unittest.TestCase): class FillMaskTest(unittest.TestCase):
model_id_sbert = { model_id_sbert = {
'zh': 'damo/nlp_structbert_fill-mask-chinese_large',
'en': 'damo/nlp_structbert_fill-mask-english_large'
'zh': 'damo/nlp_structbert_fill-mask_chinese-large',
'en': 'damo/nlp_structbert_fill-mask_english-large'
} }
model_id_veco = 'damo/nlp_veco_fill-mask_large'
model_id_veco = 'damo/nlp_veco_fill-mask-large'


ori_texts = { ori_texts = {
'zh': 'zh':


Loading…
Cancel
Save