| @@ -4,3 +4,4 @@ from .sbert_for_sentence_similarity import * # noqa F403 | |||||
| from .sbert_for_token_classification import * # noqa F403 | from .sbert_for_token_classification import * # noqa F403 | ||||
| from .space.dialog_intent_prediction_model import * # noqa F403 | from .space.dialog_intent_prediction_model import * # noqa F403 | ||||
| from .space.dialog_modeling_model import * # noqa F403 | from .space.dialog_modeling_model import * # noqa F403 | ||||
| from .space.dialog_state_tracking import * # noqa F403 | |||||
| @@ -2,5 +2,6 @@ from .sentence_similarity_pipeline import * # noqa F403 | |||||
| from .sequence_classification_pipeline import * # noqa F403 | from .sequence_classification_pipeline import * # noqa F403 | ||||
| from .space.dialog_intent_prediction_pipeline import * # noqa F403 | from .space.dialog_intent_prediction_pipeline import * # noqa F403 | ||||
| from .space.dialog_modeling_pipeline import * # noqa F403 | from .space.dialog_modeling_pipeline import * # noqa F403 | ||||
| from .space.dialog_state_tracking import * # noqa F403 | |||||
| from .text_generation_pipeline import * # noqa F403 | from .text_generation_pipeline import * # noqa F403 | ||||
| from .word_segmentation_pipeline import * # noqa F403 | from .word_segmentation_pipeline import * # noqa F403 | ||||
| @@ -8,4 +8,5 @@ from .image import LoadImage, load_image | |||||
| from .nlp import * # noqa F403 | from .nlp import * # noqa F403 | ||||
| from .space.dialog_intent_prediction_preprocessor import * # noqa F403 | from .space.dialog_intent_prediction_preprocessor import * # noqa F403 | ||||
| from .space.dialog_modeling_preprocessor import * # noqa F403 | from .space.dialog_modeling_preprocessor import * # noqa F403 | ||||
| from .space.dialog_state_tracking_preprocessor import * # noqa F403 | |||||
| from .text_to_speech import * # noqa F403 | from .text_to_speech import * # noqa F403 | ||||
| @@ -0,0 +1,61 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from maas_hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import DialogStateTrackingModel | |||||
| from modelscope.pipelines import DialogStateTrackingPipeline, pipeline | |||||
| from modelscope.preprocessors import DialogStateTrackingPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| class DialogIntentPredictionTest(unittest.TestCase): | |||||
| model_id = 'damo/nlp_space_dialog-intent-prediction' | |||||
| test_case = [ | |||||
| 'How do I locate my card?', | |||||
| 'I still have not received my new card, I ordered over a week ago.' | |||||
| ] | |||||
| @unittest.skip('test with snapshot_download') | |||||
| def test_run(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | |||||
| model = DialogIntentModel( | |||||
| model_dir=cache_path, | |||||
| text_field=preprocessor.text_field, | |||||
| config=preprocessor.config) | |||||
| pipelines = [ | |||||
| DialogIntentPredictionPipeline( | |||||
| model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_intent_prediction, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||||
| print(my_pipeline(item)) | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| # model = Model.from_pretrained(self.model_id) | |||||
| # preprocessor = DialogIntentPredictionPreprocessor( | |||||
| # model_dir=model.model_dir) | |||||
| # | |||||
| # pipelines = [ | |||||
| # DialogIntentPredictionPipeline( | |||||
| # model=model, preprocessor=preprocessor), | |||||
| # pipeline( | |||||
| # task=Tasks.dialog_intent_prediction, | |||||
| # model=model, | |||||
| # preprocessor=preprocessor) | |||||
| # ] | |||||
| # | |||||
| # for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||||
| # print(my_pipeline(item)) | |||||
| pass | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||