| @@ -1,4 +1,6 @@ | |||
| from typing import Dict, Optional, Union | |||
| from typing import Dict, Optional, Union, Any | |||
| import torch | |||
| from ...models import Model | |||
| from ...models.nlp.masked_language_model import \ | |||
| @@ -35,6 +37,7 @@ class FillMaskPipeline(Pipeline): | |||
| fill_mask_model.model_dir, | |||
| first_sequence=first_sequence, | |||
| second_sequence=None) | |||
| fill_mask_model.eval() | |||
| super().__init__(model=fill_mask_model, preprocessor=preprocessor, **kwargs) | |||
| self.preprocessor = preprocessor | |||
| self.tokenizer = preprocessor.tokenizer | |||
| @@ -61,6 +64,11 @@ class FillMaskPipeline(Pipeline): | |||
| } | |||
| } | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """process the prediction results | |||
| @@ -1,6 +1,6 @@ | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| import torch | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| @@ -42,9 +42,15 @@ class NLIPipeline(Pipeline): | |||
| sc_model.model_dir, | |||
| first_sequence=first_sequence, | |||
| second_sequence=second_sequence) | |||
| sc_model.eval() | |||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
| assert len(sc_model.id2label) > 0 | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| @@ -1,7 +1,7 @@ | |||
| from typing import Any, Dict, Union | |||
| import numpy as np | |||
| import torch | |||
| from ...metainfo import Pipelines | |||
| from ...models.nlp import SbertForSentenceSimilarity | |||
| from ...preprocessors import SequenceClassificationPreprocessor | |||
| @@ -39,11 +39,17 @@ class SentenceSimilarityPipeline(Pipeline): | |||
| sc_model.model_dir, | |||
| first_sequence=first_sequence, | |||
| second_sequence=second_sequence) | |||
| sc_model.eval() | |||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
| assert hasattr(self.model, 'id2label'), \ | |||
| 'id2label map should be initalizaed in init function.' | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| @@ -1,7 +1,7 @@ | |||
| import os | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| import torch | |||
| import json | |||
| import numpy as np | |||
| @@ -43,9 +43,15 @@ class SentimentClassificationPipeline(Pipeline): | |||
| sc_model.model_dir, | |||
| first_sequence=first_sequence, | |||
| second_sequence=second_sequence) | |||
| sc_model.eval() | |||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
| assert len(sc_model.id2label) > 0 | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| @@ -1,5 +1,5 @@ | |||
| from typing import Dict, Optional, Union | |||
| from typing import Dict, Optional, Union, Any | |||
| import torch | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import PalmForTextGeneration | |||
| @@ -33,9 +33,15 @@ class TextGenerationPipeline(Pipeline): | |||
| model.tokenizer, | |||
| first_sequence='sentence', | |||
| second_sequence=None) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.tokenizer = model.tokenizer | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| @@ -1,5 +1,5 @@ | |||
| from typing import Any, Dict, Optional, Union | |||
| import torch | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import SbertForTokenClassification | |||
| @@ -30,12 +30,18 @@ class WordSegmentationPipeline(Pipeline): | |||
| SbertForTokenClassification) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = TokenClassifcationPreprocessor(model.model_dir) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.tokenizer = preprocessor.tokenizer | |||
| self.config = model.config | |||
| assert len(self.config.id2label) > 0 | |||
| self.id2label = self.config.id2label | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| @@ -1,7 +1,7 @@ | |||
| import os | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| import torch | |||
| import json | |||
| import numpy as np | |||
| from scipy.special import softmax | |||
| @@ -44,6 +44,7 @@ class ZeroShotClassificationPipeline(Pipeline): | |||
| if preprocessor is None: | |||
| preprocessor = ZeroShotClassificationPreprocessor( | |||
| sc_model.model_dir) | |||
| model.eval() | |||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
| def _sanitize_parameters(self, **kwargs): | |||
| @@ -62,6 +63,11 @@ class ZeroShotClassificationPipeline(Pipeline): | |||
| postprocess_params['multi_label'] = kwargs.pop('multi_label', False) | |||
| return preprocess_params, {}, postprocess_params | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, | |||
| inputs: Dict[str, Any], | |||
| candidate_labels, | |||