diff --git a/modelscope/models/nlp/space/dialog_state_tracking.py b/modelscope/models/nlp/space/dialog_state_tracking.py new file mode 100644 index 00000000..4b1c44d3 --- /dev/null +++ b/modelscope/models/nlp/space/dialog_state_tracking.py @@ -0,0 +1,77 @@ +import os +from typing import Any, Dict + +from modelscope.utils.config import Config +from modelscope.utils.constant import Tasks +from ...base import Model, Tensor +from ...builder import MODELS +from .model.generator import Generator +from .model.model_base import ModelBase + +__all__ = ['DialogStateTrackingModel'] + + +@MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space-dst') +class DialogStateTrackingModel(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the test generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + model_cls (Optional[Any], optional): model loader, if None, use the + default loader to load model weights, by default None. + """ + + super().__init__(model_dir, *args, **kwargs) + self.model_dir = model_dir + self.config = kwargs.pop( + 'config', + Config.from_file( + os.path.join(self.model_dir, 'configuration.json'))) + self.text_field = kwargs.pop( + 'text_field', + IntentBPETextField(self.model_dir, config=self.config)) + + self.generator = Generator.create(self.config, reader=self.text_field) + self.model = ModelBase.create( + model_dir=model_dir, + config=self.config, + reader=self.text_field, + generator=self.generator) + + def to_tensor(array): + """ + numpy array -> tensor + """ + import torch + array = torch.tensor(array) + return array.cuda() if self.config.use_gpu else array + + self.trainer = IntentTrainer( + model=self.model, + to_tensor=to_tensor, + config=self.config, + reader=self.text_field) + self.trainer.load() + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Any]): the preprocessed data + + Returns: + Dict[str, np.ndarray]: results + Example: + { + 'predictions': array([1]), # lable 0-negative 1-positive + 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), + 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value + } + """ + import numpy as np + pred = self.trainer.forward(input) + pred = np.squeeze(pred[0], 0) + + return {'pred': pred} diff --git a/modelscope/pipelines/nlp/space/dialog_state_tracking.py b/modelscope/pipelines/nlp/space/dialog_state_tracking.py new file mode 100644 index 00000000..4a943095 --- /dev/null +++ b/modelscope/pipelines/nlp/space/dialog_state_tracking.py @@ -0,0 +1,46 @@ +from typing import Any, Dict, Optional + +from modelscope.models.nlp import DialogModelingModel +from modelscope.preprocessors import DialogModelingPreprocessor +from modelscope.utils.constant import Tasks +from ...base import Pipeline, Tensor +from ...builder import PIPELINES + +__all__ = ['DialogStateTrackingPipeline'] + + +@PIPELINES.register_module( + Tasks.dialog_state_tracking, module_name=r'space-dst') +class DialogStateTrackingPipeline(Pipeline): + + def __init__(self, model: DialogModelingModel, + preprocessor: DialogModelingPreprocessor, **kwargs): + """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + + Args: + model (SequenceClassificationModel): a model instance + preprocessor (SequenceClassificationPreprocessor): a preprocessor instance + """ + + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.model = model + self.preprocessor = preprocessor + + def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + sys_rsp = self.preprocessor.text_field.tokenizer.convert_ids_to_tokens( + inputs['resp']) + assert len(sys_rsp) > 2 + sys_rsp = sys_rsp[1:len(sys_rsp) - 1] + # sys_rsp = self.preprocessor.text_field.tokenizer. + + inputs['sys'] = sys_rsp + + return inputs diff --git a/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py b/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py new file mode 100644 index 00000000..6f67d580 --- /dev/null +++ b/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py @@ -0,0 +1,49 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict + +from modelscope.preprocessors.space.fields.intent_field import \ + IntentBPETextField +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields +from modelscope.utils.type_assert import type_assert +from ..base import Preprocessor +from ..builder import PREPROCESSORS + +__all__ = ['DialogStateTrackingPreprocessor'] + + +@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-dst') +class DialogStateTrackingPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + + self.model_dir: str = model_dir + self.config = Config.from_file( + os.path.join(self.model_dir, 'configuration.json')) + self.text_field = IntentBPETextField( + self.model_dir, config=self.config) + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + samples = self.text_field.preprocessor([data]) + samples, _ = self.text_field.collate_fn_multi_turn(samples) + + return samples diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 6fe21407..d89f0496 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -42,6 +42,7 @@ class Tasks(object): text_generation = 'text-generation' dialog_modeling = 'dialog-modeling' dialog_intent_prediction = 'dialog-intent-prediction' + dialog_state_tracking = 'dialog-state-tracking' table_question_answering = 'table-question-answering' feature_extraction = 'feature-extraction' sentence_similarity = 'sentence-similarity' diff --git a/tests/pipelines/nlp/test_dialog_state_tracking.py b/tests/pipelines/nlp/test_dialog_state_tracking.py new file mode 100644 index 00000000..e69de29b