|
|
|
@@ -0,0 +1,52 @@ |
|
|
|
from typing import Any, Dict |
|
|
|
|
|
|
|
from maas_lib.utils.constant import Tasks |
|
|
|
from ..base import Model, Tensor |
|
|
|
from ..builder import MODELS |
|
|
|
|
|
|
|
__all__ = ['PalmForTextGenerationModel'] |
|
|
|
|
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.text_generation, module_name=r'palm') |
|
|
|
class PalmForTextGenerationModel(Model): |
|
|
|
|
|
|
|
def __init__(self, model_dir: str, *args, **kwargs): |
|
|
|
"""initialize the text generation model from the `model_dir` path. |
|
|
|
|
|
|
|
Args: |
|
|
|
model_dir (str): the model path. |
|
|
|
model_cls (Optional[Any], optional): model loader, if None, use the |
|
|
|
default loader to load model weights, by default None. |
|
|
|
""" |
|
|
|
from sofa import PalmTokenizer |
|
|
|
|
|
|
|
super().__init__(model_dir, *args, **kwargs) |
|
|
|
self.model_dir = model_dir |
|
|
|
|
|
|
|
from sofa.models.palm import PalmForConditionalGeneration, TextGenerator |
|
|
|
tokenizer = kwargs.pop('tokenizer', |
|
|
|
PalmTokenizer.from_pretrained(model_dir)) |
|
|
|
model = PalmForConditionalGeneration.from_pretrained(model_dir) |
|
|
|
self.generator = TextGenerator(model, tokenizer) |
|
|
|
|
|
|
|
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: |
|
|
|
"""return the result by the model |
|
|
|
|
|
|
|
Args: |
|
|
|
input (Dict[str, Any]): the preprocessed data |
|
|
|
|
|
|
|
Returns: |
|
|
|
Dict[str, np.ndarray]: results |
|
|
|
Example: |
|
|
|
{ |
|
|
|
'predictions': array([1]), # lable 0-negative 1-positive |
|
|
|
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), |
|
|
|
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value |
|
|
|
} |
|
|
|
""" |
|
|
|
|
|
|
|
encoder_inputs = [ |
|
|
|
input['input_ids'], input['token_type_ids'], |
|
|
|
input['attention_mask'] |
|
|
|
] |
|
|
|
return self.generator(encoder_inputs) |