From 318ac98a7ebcdf83b82e4efe2afe68035bedb2d6 Mon Sep 17 00:00:00 2001 From: ly119399 Date: Sat, 25 Jun 2022 12:24:18 +0800 Subject: [PATCH] add init --- modelscope/models/nlp/__init__.py | 1 + modelscope/pipelines/nlp/__init__.py | 1 + modelscope/preprocessors/__init__.py | 1 + .../nlp/test_dialog_state_tracking.py | 61 +++++++++++++++++++ 4 files changed, 64 insertions(+) diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index 1406b965..e62ab404 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -4,3 +4,4 @@ from .sbert_for_sentence_similarity import * # noqa F403 from .sbert_for_token_classification import * # noqa F403 from .space.dialog_intent_prediction_model import * # noqa F403 from .space.dialog_modeling_model import * # noqa F403 +from .space.dialog_state_tracking import * # noqa F403 diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index a67b4436..adfa1d4c 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -2,5 +2,6 @@ from .sentence_similarity_pipeline import * # noqa F403 from .sequence_classification_pipeline import * # noqa F403 from .space.dialog_intent_prediction_pipeline import * # noqa F403 from .space.dialog_modeling_pipeline import * # noqa F403 +from .space.dialog_state_tracking import * # noqa F403 from .text_generation_pipeline import * # noqa F403 from .word_segmentation_pipeline import * # noqa F403 diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 133f7004..7b67507a 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -8,4 +8,5 @@ from .image import LoadImage, load_image from .nlp import * # noqa F403 from .space.dialog_intent_prediction_preprocessor import * # noqa F403 from .space.dialog_modeling_preprocessor import * # noqa F403 +from .space.dialog_state_tracking_preprocessor import * # noqa F403 from .text_to_speech import * # noqa F403 diff --git a/tests/pipelines/nlp/test_dialog_state_tracking.py b/tests/pipelines/nlp/test_dialog_state_tracking.py index e69de29b..a6c989bd 100644 --- a/tests/pipelines/nlp/test_dialog_state_tracking.py +++ b/tests/pipelines/nlp/test_dialog_state_tracking.py @@ -0,0 +1,61 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from maas_hub.snapshot_download import snapshot_download + +from modelscope.models import Model +from modelscope.models.nlp import DialogStateTrackingModel +from modelscope.pipelines import DialogStateTrackingPipeline, pipeline +from modelscope.preprocessors import DialogStateTrackingPreprocessor +from modelscope.utils.constant import Tasks + + +class DialogIntentPredictionTest(unittest.TestCase): + model_id = 'damo/nlp_space_dialog-intent-prediction' + test_case = [ + 'How do I locate my card?', + 'I still have not received my new card, I ordered over a week ago.' + ] + + @unittest.skip('test with snapshot_download') + def test_run(self): + cache_path = snapshot_download(self.model_id) + preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) + model = DialogIntentModel( + model_dir=cache_path, + text_field=preprocessor.text_field, + config=preprocessor.config) + + pipelines = [ + DialogIntentPredictionPipeline( + model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.dialog_intent_prediction, + model=model, + preprocessor=preprocessor) + ] + + for my_pipeline, item in list(zip(pipelines, self.test_case)): + print(my_pipeline(item)) + + def test_run_with_model_from_modelhub(self): + # model = Model.from_pretrained(self.model_id) + # preprocessor = DialogIntentPredictionPreprocessor( + # model_dir=model.model_dir) + # + # pipelines = [ + # DialogIntentPredictionPipeline( + # model=model, preprocessor=preprocessor), + # pipeline( + # task=Tasks.dialog_intent_prediction, + # model=model, + # preprocessor=preprocessor) + # ] + # + # for my_pipeline, item in list(zip(pipelines, self.test_case)): + # print(my_pipeline(item)) + pass + + +if __name__ == '__main__': + unittest.main()