Browse Source

solve comment: 1. change MaskedLMModelBase to MaskedLanguageModelBase 2. remove a useless import

master
雨泓 3 years ago
parent
commit
c376d59143
2 changed files with 8 additions and 8 deletions
  1. +4
    -4
      modelscope/models/nlp/masked_language_model.py
  2. +4
    -4
      modelscope/pipelines/nlp/fill_mask_pipeline.py

+ 4
- 4
modelscope/models/nlp/masked_language_model.py View File

@@ -7,10 +7,10 @@ from ...utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM', 'MaskedLMModelBase']
__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM', 'MaskedLanguageModelBase']


class MaskedLMModelBase(Model):
class MaskedLanguageModelBase(Model):

def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
@@ -48,7 +48,7 @@ class MaskedLMModelBase(Model):


@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)
class StructBertForMaskedLM(MaskedLMModelBase):
class StructBertForMaskedLM(MaskedLanguageModelBase):

def build_model(self):
from sofa import SbertForMaskedLM
@@ -56,7 +56,7 @@ class StructBertForMaskedLM(MaskedLMModelBase):


@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco)
class VecoForMaskedLM(MaskedLMModelBase):
class VecoForMaskedLM(MaskedLanguageModelBase):

def build_model(self):
from sofa import VecoForMaskedLM


+ 4
- 4
modelscope/pipelines/nlp/fill_mask_pipeline.py View File

@@ -4,7 +4,7 @@ import torch

from ...metainfo import Pipelines
from ...models import Model
from ...models.nlp.masked_language_model import MaskedLMModelBase
from ...models.nlp.masked_language_model import MaskedLanguageModelBase
from ...preprocessors import FillMaskPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline, Tensor
@@ -17,18 +17,18 @@ __all__ = ['FillMaskPipeline']
class FillMaskPipeline(Pipeline):

def __init__(self,
model: Union[MaskedLMModelBase, str],
model: Union[MaskedLanguageModelBase, str],
preprocessor: Optional[FillMaskPreprocessor] = None,
first_sequence='sentense',
**kwargs):
"""use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction

Args:
model (MaskedLMModelBase): a model instance
model (MaskedLanguageModelBase): a model instance
preprocessor (FillMaskPreprocessor): a preprocessor instance
"""
fill_mask_model = model if isinstance(
model, MaskedLMModelBase) else Model.from_pretrained(model)
model, MaskedLanguageModelBase) else Model.from_pretrained(model)
assert fill_mask_model.config is not None

if preprocessor is None:


Loading…
Cancel
Save