diff --git a/maas_lib/models/nlp/__init__.py b/maas_lib/models/nlp/__init__.py
index d85c0ba7..a8489c12 100644
--- a/maas_lib/models/nlp/__init__.py
+++ b/maas_lib/models/nlp/__init__.py
@@ -1 +1,2 @@
 from .sequence_classification_model import *  # noqa F403
+from .space.dialog_generation_model import *  # noqa F403
diff --git a/maas_lib/models/nlp/space/__init__.py b/maas_lib/models/nlp/space/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/maas_lib/models/nlp/space/dialog_generation_model.py b/maas_lib/models/nlp/space/dialog_generation_model.py
new file mode 100644
index 00000000..72a99705
--- /dev/null
+++ b/maas_lib/models/nlp/space/dialog_generation_model.py
@@ -0,0 +1,46 @@
+from typing import Dict
+
+from maas_lib.utils.constant import Tasks
+from ...base import Model, Tensor
+from ...builder import MODELS
+
+__all__ = ['DialogGenerationModel']
+
+
+@MODELS.register_module(Tasks.dialog_generation, module_name=r'space')
+class DialogGenerationModel(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the dialog generation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """Return the prediction result of the model.
+
+        Args:
+            input (Dict[str, Tensor]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: results
+                Example:
+                    {
+                        'predictions': array([1]),  # label: 0-negative, 1-positive
+                        'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
+                        'logits': array([[-0.53860897, 1.5029076]], dtype=float32)
+                    }
+        """
+        from numpy import array, float32
+
+        # Placeholder output until the real SPACE generation model is wired in.
+        return {
+            'predictions': array([1]),  # label: 0-negative, 1-positive
+            'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
+            'logits': array([[-0.53860897, 1.5029076]],
+                            dtype=float32)
+        }
diff --git a/maas_lib/pipelines/nlp/__init__.py b/maas_lib/pipelines/nlp/__init__.py
index f9d874e7..01bd0e2a 100644
--- a/maas_lib/pipelines/nlp/__init__.py
+++ b/maas_lib/pipelines/nlp/__init__.py
@@ -1 +1,2 @@
 from .sequence_classification_pipeline import *  # noqa F403
+from .space.dialog_generation_pipeline import *  # noqa F403
diff --git a/maas_lib/pipelines/nlp/space/__init__.py b/maas_lib/pipelines/nlp/space/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py b/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py
new file mode 100644
index 00000000..df193d66
--- /dev/null
+++ b/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py
@@ -0,0 +1,49 @@
+from typing import Dict
+
+from maas_lib.models.nlp import DialogGenerationModel
+from maas_lib.preprocessors import DialogGenerationPreprocessor
+from maas_lib.utils.constant import Tasks
+from ...base import Pipeline, Tensor
+from ...builder import PIPELINES
+
+__all__ = ['DialogGenerationPipeline']
+
+
+@PIPELINES.register_module(Tasks.dialog_generation, module_name=r'space')
+class DialogGenerationPipeline(Pipeline):
+
+    def __init__(self, model: DialogGenerationModel,
+                 preprocessor: DialogGenerationPreprocessor, **kwargs):
+        """Use `model` and `preprocessor` to create an nlp dialog generation pipeline for prediction.
+
+        Args:
+            model (DialogGenerationModel): a model instance
+            preprocessor (DialogGenerationPreprocessor): a preprocessor instance
+        """
+
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """Return the prediction result of the model.
+
+        Args:
+            input (Dict[str, Tensor]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: results
+                Example:
+                    {
+                        'predictions': array([1]),  # label: 0-negative, 1-positive
+                        'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
+                        'logits': array([[-0.53860897, 1.5029076]], dtype=float32)
+                    }
+        """
+        from numpy import array, float32
+
+        # Placeholder output until the real SPACE generation model is wired in.
+        return {
+            'predictions': array([1]),  # label: 0-negative, 1-positive
+            'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
+            'logits': array([[-0.53860897, 1.5029076]],
+                            dtype=float32)
+        }
diff --git a/maas_lib/preprocessors/nlp.py b/maas_lib/preprocessors/nlp.py
index bde401c2..f4877510 100644
--- a/maas_lib/preprocessors/nlp.py
+++ b/maas_lib/preprocessors/nlp.py
@@ -10,7 +10,10 @@ from maas_lib.utils.type_assert import type_assert
 from .base import Preprocessor
 from .builder import PREPROCESSORS
 
-__all__ = ['Tokenize', 'SequenceClassificationPreprocessor']
+__all__ = [
+    'Tokenize', 'SequenceClassificationPreprocessor',
+    'DialogGenerationPreprocessor'
+]
 
 
 @PREPROCESSORS.register_module(Fields.nlp)
@@ -89,3 +92,31 @@ class SequenceClassificationPreprocessor(Preprocessor):
             rst['token_type_ids'].append(feature['token_type_ids'])
 
         return rst
+
+
+@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space')
+class DialogGenerationPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data via the vocab.txt under the `model_dir` path.
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+        self.model_dir = model_dir
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        # TODO: tokenize `data` with the SPACE vocabulary; a stub for now.
+        return None
diff --git a/maas_lib/utils/constant.py b/maas_lib/utils/constant.py
index 8f808a6f..bd4f8e17 100644
--- a/maas_lib/utils/constant.py
+++ b/maas_lib/utils/constant.py
@@ -38,6 +38,7 @@ class Tasks(object):
     token_classification = 'token-classification'
     conversational = 'conversational'
     text_generation = 'text-generation'
+    dialog_generation = 'dialog-generation'
     table_question_answering = 'table-question-answering'
     feature_extraction = 'feature-extraction'
     sentence_similarity = 'sentence-similarity'
diff --git a/tests/pipelines/nlp/__init__.py b/tests/pipelines/nlp/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/pipelines/nlp/test_dialog_generation.py b/tests/pipelines/nlp/test_dialog_generation.py
new file mode 100644
index 00000000..68b82132
--- /dev/null
+++ b/tests/pipelines/nlp/test_dialog_generation.py
@@ -0,0 +1,67 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import os.path as osp
+import tempfile
+import unittest
+
+from maas_lib.fileio import File
+from maas_lib.models.nlp import DialogGenerationModel
+from maas_lib.pipelines import DialogGenerationPipeline, pipeline
+from maas_lib.preprocessors import DialogGenerationPreprocessor
+from maas_lib.utils.constant import Tasks
+
+dialog_case = [{
+    'user':
+    'am looking for a place to to stay that has cheap price range it should be in a type of hotel',
+    'sys':
+    'okay , do you have a specific area you want to stay in ?'
+}, {
+    'user':
+    'no , i just need to make sure it is cheap . oh , and i need parking',
+    'sys':
+    'i found 1 cheap hotel for you that include -s parking . do you like me to book it ?'
+}, {
+    'user':
+    'yes , please . 6 people 3 nights starting on tuesday .',
+    'sys':
+    "i am sorry but i was n't able to book that for you for tuesday . is there another day you would like "
+    'to stay or perhaps a shorter stay ? '
+}, {
+    'user':
+    'how about only 2 nights .',
+    'sys':
+    'booking was successful . reference number is : 7gawk763 . anything else i can do for you ?',
+}, {
+    'user': 'no , that will be all . goodbye .',
+    'sys': 'thank you for using our services .'
+}]
+
+
+class DialogGenerationTest(unittest.TestCase):
+
+    def test_run(self):
+        for item in dialog_case:
+            q = item['user']
+            a = item['sys']
+            print('user: {}'.format(q))
+            print('sys: {}'.format(a))
+
+        # TODO(space): enable once the SPACE components are implemented:
+        # preprocessor = DialogGenerationPreprocessor(model_dir)
+        # model = DialogGenerationModel(model_dir, preprocessor.tokenizer)
+        # pipeline = DialogGenerationPipeline(model, preprocessor)
+        #
+        # history_dialog = []
+        # for item in dialog_case:
+        #     user_question = item['user']
+        #     print('user: {}'.format(user_question))
+        #
+        #     pipeline(user_question)
+        #
+        #     sys_answer, history_dialog = pipeline()
+        #
+        #     print('sys : {}'.format(sys_answer))
+
+
+if __name__ == '__main__':
+    unittest.main()
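
Reviewer note: below is a minimal usage sketch (not part of the diff) showing how the pieces added here fit together. It assumes the `Pipeline` base class chains the preprocessor and `forward` when the pipeline instance is called, as the existing sequence-classification pipeline does, and `model_dir` is a hypothetical path to a SPACE checkpoint; with the current stubs the call simply returns the placeholder dict from `forward`.

    from maas_lib.models.nlp import DialogGenerationModel
    from maas_lib.pipelines import DialogGenerationPipeline
    from maas_lib.preprocessors import DialogGenerationPreprocessor

    model_dir = '/path/to/space_checkpoint'  # hypothetical; real weights TBD
    preprocessor = DialogGenerationPreprocessor(model_dir)
    model = DialogGenerationModel(model_dir)
    pipe = DialogGenerationPipeline(model, preprocessor)

    # With the stub implementation this prints the hard-coded placeholder dict.
    print(pipe('am looking for a place to stay that has cheap price range'))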