| @@ -63,15 +63,16 @@ class SpaceForDialogIntent(Model): | |||||
| """return the result by the model | """return the result by the model | ||||
| Args: | Args: | ||||
| input (Dict[str, Any]): the preprocessed data | |||||
| input (Dict[str, Tensor]): the preprocessed data | |||||
| Returns: | Returns: | ||||
| Dict[str, np.ndarray]: results | |||||
| Dict[str, Tensor]: results | |||||
| Example: | Example: | ||||
| { | { | ||||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||||
| 'pred': array([2.62349960e-03 4.12110658e-03 4.12748595e-05 3.77560973e-05 | |||||
| 1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04 | |||||
| 6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01 | |||||
| 2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32) | |||||
| } | } | ||||
| """ | """ | ||||
| import numpy as np | import numpy as np | ||||
| @@ -62,15 +62,17 @@ class SpaceForDialogModeling(Model): | |||||
| """return the result by the model | """return the result by the model | ||||
| Args: | Args: | ||||
| input (Dict[str, Any]): the preprocessed data | |||||
| input (Dict[str, Tensor]): the preprocessed data | |||||
| Returns: | Returns: | ||||
| Dict[str, np.ndarray]: results | |||||
| Dict[str, Tensor]: results | |||||
| Example: | Example: | ||||
| { | { | ||||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||||
| 'labels': array([1,192,321,12]), # lable | |||||
| 'resp': array([293,1023,123,1123]), #vocab label for response | |||||
| 'bspn': array([123,321,2,24,1 ]), | |||||
| 'aspn': array([47,8345,32,29,1983]), | |||||
| 'db': array([19, 24, 20]), | |||||
| } | } | ||||
| """ | """ | ||||
| @@ -2,6 +2,7 @@ import os | |||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from ....metainfo import Models | |||||
| from ....utils.nlp.space.utils_dst import batch_to_device | from ....utils.nlp.space.utils_dst import batch_to_device | ||||
| from ...base import Model, Tensor | from ...base import Model, Tensor | ||||
| from ...builder import MODELS | from ...builder import MODELS | ||||
| @@ -9,7 +10,7 @@ from ...builder import MODELS | |||||
| __all__ = ['SpaceForDialogStateTracking'] | __all__ = ['SpaceForDialogStateTracking'] | ||||
| @MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space') | |||||
| @MODELS.register_module(Tasks.dialog_state_tracking, module_name=Models.space) | |||||
| class SpaceForDialogStateTracking(Model): | class SpaceForDialogStateTracking(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -17,8 +18,6 @@ class SpaceForDialogStateTracking(Model): | |||||
| Args: | Args: | ||||
| model_dir (str): the model path. | model_dir (str): the model path. | ||||
| model_cls (Optional[Any], optional): model loader, if None, use the | |||||
| default loader to load model weights, by default None. | |||||
| """ | """ | ||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| @@ -27,7 +26,6 @@ class SpaceForDialogStateTracking(Model): | |||||
| self.model_dir = model_dir | self.model_dir = model_dir | ||||
| self.config = SpaceConfig.from_pretrained(self.model_dir) | self.config = SpaceConfig.from_pretrained(self.model_dir) | ||||
| # self.model = SpaceForDST(self.config) | |||||
| self.model = SpaceForDST.from_pretrained(self.model_dir) | self.model = SpaceForDST.from_pretrained(self.model_dir) | ||||
| self.model.to(self.config.device) | self.model.to(self.config.device) | ||||
| @@ -35,15 +33,20 @@ class SpaceForDialogStateTracking(Model): | |||||
| """return the result by the model | """return the result by the model | ||||
| Args: | Args: | ||||
| input (Dict[str, Any]): the preprocessed data | |||||
| input (Dict[str, Tensor]): the preprocessed data | |||||
| Returns: | Returns: | ||||
| Dict[str, np.ndarray]: results | |||||
| Dict[str, Tensor]: results | |||||
| Example: | Example: | ||||
| { | { | ||||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||||
| 'inputs': dict(input_ids, input_masks,start_pos), # tracking states | |||||
| 'outputs': dict(slots_logits), | |||||
| 'unique_ids': str(test-example.json-0), # default value | |||||
| 'input_ids_unmasked': array([101, 7632, 1010,0,0,0]) | |||||
| 'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), | |||||
| 'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), | |||||
| 'prefix': str('final'), #default value | |||||
| 'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]) | |||||
| } | } | ||||
| """ | """ | ||||
| import numpy as np | import numpy as np | ||||
| @@ -88,8 +91,6 @@ class SpaceForDialogStateTracking(Model): | |||||
| if u != 0: | if u != 0: | ||||
| diag_state[slot][i] = u | diag_state[slot][i] = u | ||||
| # print(outputs) | |||||
| return { | return { | ||||
| 'inputs': inputs, | 'inputs': inputs, | ||||
| 'outputs': outputs, | 'outputs': outputs, | ||||
| @@ -41,6 +41,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/nlp_space_dialog-intent-prediction'), | 'damo/nlp_space_dialog-intent-prediction'), | ||||
| Tasks.dialog_modeling: (Pipelines.dialog_modeling, | Tasks.dialog_modeling: (Pipelines.dialog_modeling, | ||||
| 'damo/nlp_space_dialog-modeling'), | 'damo/nlp_space_dialog-modeling'), | ||||
| Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | |||||
| 'damo/nlp_space_dialog-state-tracking'), | |||||
| Tasks.image_captioning: (Pipelines.image_caption, | Tasks.image_captioning: (Pipelines.image_caption, | ||||
| 'damo/ofa_image-caption_coco_large_en'), | 'damo/ofa_image-caption_coco_large_en'), | ||||
| Tasks.image_generation: | Tasks.image_generation: | ||||
| @@ -1,8 +1,9 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Any, Dict | |||||
| from typing import Any, Dict, Union | |||||
| from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
| from ...models import Model | |||||
| from ...models.nlp import SpaceForDialogIntent | from ...models.nlp import SpaceForDialogIntent | ||||
| from ...preprocessors import DialogIntentPredictionPreprocessor | from ...preprocessors import DialogIntentPredictionPreprocessor | ||||
| from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
| @@ -18,17 +19,22 @@ __all__ = ['DialogIntentPredictionPipeline'] | |||||
| module_name=Pipelines.dialog_intent_prediction) | module_name=Pipelines.dialog_intent_prediction) | ||||
| class DialogIntentPredictionPipeline(Pipeline): | class DialogIntentPredictionPipeline(Pipeline): | ||||
| def __init__(self, model: SpaceForDialogIntent, | |||||
| preprocessor: DialogIntentPredictionPreprocessor, **kwargs): | |||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
| def __init__(self, | |||||
| model: Union[SpaceForDialogIntent, str], | |||||
| preprocessor: DialogIntentPredictionPreprocessor = None, | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create a dialog intent prediction pipeline | |||||
| Args: | Args: | ||||
| model (SequenceClassificationModel): a model instance | |||||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||||
| model (SpaceForDialogIntent): a model instance | |||||
| preprocessor (DialogIntentPredictionPreprocessor): a preprocessor instance | |||||
| """ | """ | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| model = model if isinstance( | |||||
| model, SpaceForDialogIntent) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | |||||
| preprocessor = DialogIntentPredictionPreprocessor(model.model_dir) | |||||
| self.model = model | self.model = model | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.categories = preprocessor.categories | self.categories = preprocessor.categories | ||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | ||||
| @@ -1,8 +1,9 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Any, Dict, Optional | |||||
| from typing import Any, Dict, Union | |||||
| from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
| from ...models import Model | |||||
| from ...models.nlp import SpaceForDialogModeling | from ...models.nlp import SpaceForDialogModeling | ||||
| from ...preprocessors import DialogModelingPreprocessor | from ...preprocessors import DialogModelingPreprocessor | ||||
| from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
| @@ -17,17 +18,22 @@ __all__ = ['DialogModelingPipeline'] | |||||
| Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) | Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) | ||||
| class DialogModelingPipeline(Pipeline): | class DialogModelingPipeline(Pipeline): | ||||
| def __init__(self, model: SpaceForDialogModeling, | |||||
| preprocessor: DialogModelingPreprocessor, **kwargs): | |||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
| def __init__(self, | |||||
| model: Union[SpaceForDialogModeling, str], | |||||
| preprocessor: DialogModelingPreprocessor = None, | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create a dialog modleing pipeline for dialog response generation | |||||
| Args: | Args: | ||||
| model (SequenceClassificationModel): a model instance | |||||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||||
| model (SpaceForDialogModeling): a model instance | |||||
| preprocessor (DialogModelingPreprocessor): a preprocessor instance | |||||
| """ | """ | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| model = model if isinstance( | |||||
| model, SpaceForDialogModeling) else Model.from_pretrained(model) | |||||
| self.model = model | self.model = model | ||||
| if preprocessor is None: | |||||
| preprocessor = DialogModelingPreprocessor(model.model_dir) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.preprocessor = preprocessor | self.preprocessor = preprocessor | ||||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | ||||
| @@ -1,7 +1,7 @@ | |||||
| from typing import Any, Dict | |||||
| from typing import Any, Dict, Union | |||||
| from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
| from ...models import SpaceForDialogStateTracking | |||||
| from ...models import Model, SpaceForDialogStateTracking | |||||
| from ...preprocessors import DialogStateTrackingPreprocessor | from ...preprocessors import DialogStateTrackingPreprocessor | ||||
| from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
| from ..base import Pipeline | from ..base import Pipeline | ||||
| @@ -15,17 +15,26 @@ __all__ = ['DialogStateTrackingPipeline'] | |||||
| Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) | Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) | ||||
| class DialogStateTrackingPipeline(Pipeline): | class DialogStateTrackingPipeline(Pipeline): | ||||
| def __init__(self, model: SpaceForDialogStateTracking, | |||||
| preprocessor: DialogStateTrackingPreprocessor, **kwargs): | |||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
| def __init__(self, | |||||
| model: Union[SpaceForDialogStateTracking, str], | |||||
| preprocessor: DialogStateTrackingPreprocessor = None, | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create a dialog state tracking pipeline for | |||||
| observation of dialog states tracking after many turns of open domain dialogue | |||||
| Args: | Args: | ||||
| model (SequenceClassificationModel): a model instance | |||||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||||
| model (SpaceForDialogStateTracking): a model instance | |||||
| preprocessor (DialogStateTrackingPreprocessor): a preprocessor instance | |||||
| """ | """ | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| model = model if isinstance( | |||||
| model, | |||||
| SpaceForDialogStateTracking) else Model.from_pretrained(model) | |||||
| self.model = model | self.model = model | ||||
| if preprocessor is None: | |||||
| preprocessor = DialogStateTrackingPreprocessor(model.model_dir) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.tokenizer = preprocessor.tokenizer | self.tokenizer = preprocessor.tokenizer | ||||
| self.config = preprocessor.config | self.config = preprocessor.config | ||||
| @@ -46,9 +55,7 @@ class DialogStateTrackingPipeline(Pipeline): | |||||
| values = inputs['values'] | values = inputs['values'] | ||||
| inform = inputs['inform'] | inform = inputs['inform'] | ||||
| prefix = inputs['prefix'] | prefix = inputs['prefix'] | ||||
| # ds = {slot: 'none' for slot in self.config.dst_slot_list} | |||||
| ds = inputs['ds'] | ds = inputs['ds'] | ||||
| ds = predict_and_format(self.config, self.tokenizer, _inputs, | ds = predict_and_format(self.config, self.tokenizer, _inputs, | ||||
| _outputs[2], _outputs[3], _outputs[4], | _outputs[2], _outputs[3], _outputs[4], | ||||
| _outputs[5], unique_ids, input_ids_unmasked, | _outputs[5], unique_ids, input_ids_unmasked, | ||||
| @@ -138,13 +138,6 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS], | Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS], | ||||
| # sentiment classification result for single sample | |||||
| # { | |||||
| # "labels": ["happy", "sad", "calm", "angry"], | |||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||||
| # } | |||||
| Tasks.sentiment_classification: ['scores', 'labels'], | |||||
| # zero-shot classification result for single sample | # zero-shot classification result for single sample | ||||
| # { | # { | ||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | # "scores": [0.9, 0.1, 0.05, 0.05] | ||||
| @@ -18,7 +18,7 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||||
| ] | ] | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| def test_run(self): | |||||
| def test_run_by_direct_model_download(self): | |||||
| cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
| preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | ||||
| model = SpaceForDialogIntent( | model = SpaceForDialogIntent( | ||||
| @@ -56,6 +56,20 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | for my_pipeline, item in list(zip(pipelines, self.test_case)): | ||||
| print(my_pipeline(item)) | print(my_pipeline(item)) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_name(self): | |||||
| pipelines = [ | |||||
| pipeline(task=Tasks.dialog_intent_prediction, model=self.model_id) | |||||
| ] | |||||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||||
| print(my_pipeline(item)) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_default_model(self): | |||||
| pipelines = [pipeline(task=Tasks.dialog_intent_prediction)] | |||||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||||
| print(my_pipeline(item)) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||
| @@ -1,5 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import unittest | import unittest | ||||
| from typing import List | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| @@ -89,8 +90,22 @@ class DialogModelingTest(unittest.TestCase): | |||||
| } | } | ||||
| } | } | ||||
| def generate_and_print_dialog_response( | |||||
| self, pipelines: List[DialogModelingPipeline]): | |||||
| result = {} | |||||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||||
| user = item['user'] | |||||
| print('user: {}'.format(user)) | |||||
| result = pipelines[step % 2]({ | |||||
| 'user_input': user, | |||||
| 'history': result | |||||
| }) | |||||
| print('response : {}'.format(result['response'])) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| def test_run(self): | |||||
| def test_run_by_direct_model_download(self): | |||||
| cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
| @@ -106,17 +121,7 @@ class DialogModelingTest(unittest.TestCase): | |||||
| model=model, | model=model, | ||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| ] | ] | ||||
| result = {} | |||||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||||
| user = item['user'] | |||||
| print('user: {}'.format(user)) | |||||
| result = pipelines[step % 2]({ | |||||
| 'user_input': user, | |||||
| 'history': result | |||||
| }) | |||||
| print('response : {}'.format(result['response'])) | |||||
| self.generate_and_print_dialog_response(pipelines) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
| @@ -131,16 +136,23 @@ class DialogModelingTest(unittest.TestCase): | |||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| ] | ] | ||||
| result = {} | |||||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||||
| user = item['user'] | |||||
| print('user: {}'.format(user)) | |||||
| self.generate_and_print_dialog_response(pipelines) | |||||
| result = pipelines[step % 2]({ | |||||
| 'user_input': user, | |||||
| 'history': result | |||||
| }) | |||||
| print('response : {}'.format(result['response'])) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_name(self): | |||||
| pipelines = [ | |||||
| pipeline(task=Tasks.dialog_modeling, model=self.model_id), | |||||
| pipeline(task=Tasks.dialog_modeling, model=self.model_id) | |||||
| ] | |||||
| self.generate_and_print_dialog_response(pipelines) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_default_model(self): | |||||
| pipelines = [ | |||||
| pipeline(task=Tasks.dialog_modeling), | |||||
| pipeline(task=Tasks.dialog_modeling) | |||||
| ] | |||||
| self.generate_and_print_dialog_response(pipelines) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| @@ -1,5 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import unittest | import unittest | ||||
| from typing import List | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models import Model, SpaceForDialogStateTracking | from modelscope.models import Model, SpaceForDialogStateTracking | ||||
| @@ -75,23 +76,10 @@ class DialogStateTrackingTest(unittest.TestCase): | |||||
| 'User-8': 'Thank you, goodbye', | 'User-8': 'Thank you, goodbye', | ||||
| }] | }] | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| model = SpaceForDialogStateTracking(cache_path) | |||||
| preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | |||||
| pipelines = [ | |||||
| DialogStateTrackingPipeline( | |||||
| model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_state_tracking, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| pipelines_len = len(pipelines) | |||||
| def tracking_and_print_dialog_states( | |||||
| self, pipelines: List[DialogStateTrackingPipeline]): | |||||
| import json | import json | ||||
| pipelines_len = len(pipelines) | |||||
| history_states = [{}] | history_states = [{}] | ||||
| utter = {} | utter = {} | ||||
| for step, item in enumerate(self.test_case): | for step, item in enumerate(self.test_case): | ||||
| @@ -106,6 +94,22 @@ class DialogStateTrackingTest(unittest.TestCase): | |||||
| history_states.extend([result['dialog_states'], {}]) | history_states.extend([result['dialog_states'], {}]) | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_by_direct_model_download(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| model = SpaceForDialogStateTracking(cache_path) | |||||
| preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | |||||
| pipelines = [ | |||||
| DialogStateTrackingPipeline( | |||||
| model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_state_tracking, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| self.tracking_and_print_dialog_states(pipelines) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
| model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||
| @@ -120,21 +124,19 @@ class DialogStateTrackingTest(unittest.TestCase): | |||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| ] | ] | ||||
| pipelines_len = len(pipelines) | |||||
| import json | |||||
| history_states = [{}] | |||||
| utter = {} | |||||
| for step, item in enumerate(self.test_case): | |||||
| utter.update(item) | |||||
| result = pipelines[step % pipelines_len]({ | |||||
| 'utter': | |||||
| utter, | |||||
| 'history_states': | |||||
| history_states | |||||
| }) | |||||
| print(json.dumps(result)) | |||||
| self.tracking_and_print_dialog_states(pipelines) | |||||
| history_states.extend([result['dialog_states'], {}]) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_name(self): | |||||
| pipelines = [ | |||||
| pipeline(task=Tasks.dialog_state_tracking, model=self.model_id) | |||||
| ] | |||||
| self.tracking_and_print_dialog_states(pipelines) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_default_model(self): | |||||
| pipelines = [pipeline(task=Tasks.dialog_state_tracking)] | |||||
| self.tracking_and_print_dialog_states(pipelines) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||