|
|
|
@@ -4,7 +4,7 @@ import torch |
|
|
|
|
|
|
|
from ...metainfo import Pipelines |
|
|
|
from ...models import Model |
|
|
|
from ...models.nlp.masked_language_model import MaskedLMModelBase |
|
|
|
from ...models.nlp.masked_language_model import MaskedLanguageModelBase |
|
|
|
from ...preprocessors import FillMaskPreprocessor |
|
|
|
from ...utils.constant import Tasks |
|
|
|
from ..base import Pipeline, Tensor |
|
|
|
@@ -17,18 +17,18 @@ __all__ = ['FillMaskPipeline'] |
|
|
|
class FillMaskPipeline(Pipeline): |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
model: Union[MaskedLMModelBase, str], |
|
|
|
model: Union[MaskedLanguageModelBase, str], |
|
|
|
preprocessor: Optional[FillMaskPreprocessor] = None, |
|
|
|
first_sequence='sentense', |
|
|
|
**kwargs): |
|
|
|
"""use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction |
|
|
|
|
|
|
|
Args: |
|
|
|
model (MaskedLMModelBase): a model instance |
|
|
|
model (MaskedLanguageModelBase): a model instance |
|
|
|
preprocessor (FillMaskPreprocessor): a preprocessor instance |
|
|
|
""" |
|
|
|
fill_mask_model = model if isinstance( |
|
|
|
model, MaskedLMModelBase) else Model.from_pretrained(model) |
|
|
|
model, MaskedLanguageModelBase) else Model.from_pretrained(model) |
|
|
|
assert fill_mask_model.config is not None |
|
|
|
|
|
|
|
if preprocessor is None: |
|
|
|
|