| @@ -26,7 +26,9 @@ class Models(object): | |||
| structbert = 'structbert' | |||
| veco = 'veco' | |||
| translation = 'csanmt-translation' | |||
| space = 'space' | |||
| space_dst = 'space-dst' | |||
| space_intent = 'space-intent' | |||
| space_modeling = 'space-modeling' | |||
| tcrf = 'transformer-crf' | |||
| bart = 'bart' | |||
| gpt3 = 'gpt3' | |||
| @@ -116,7 +118,7 @@ class Pipelines(object): | |||
| csanmt_translation = 'csanmt-translation' | |||
| nli = 'nli' | |||
| dialog_intent_prediction = 'dialog-intent-prediction' | |||
| task_oriented_conversation = 'task-oriented-conversation' | |||
| dialog_modeling = 'dialog-modeling' | |||
| dialog_state_tracking = 'dialog-state-tracking' | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| text_error_correction = 'text-error-correction' | |||
| @@ -16,7 +16,7 @@ __all__ = ['SpaceForDialogIntent'] | |||
| @MODELS.register_module( | |||
| Tasks.dialog_intent_prediction, module_name=Models.space) | |||
| Tasks.task_oriented_conversation, module_name=Models.space_intent) | |||
| class SpaceForDialogIntent(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -16,7 +16,7 @@ __all__ = ['SpaceForDialogModeling'] | |||
| @MODELS.register_module( | |||
| Tasks.task_oriented_conversation, module_name=Models.space) | |||
| Tasks.task_oriented_conversation, module_name=Models.space_modeling) | |||
| class SpaceForDialogModeling(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -34,8 +34,10 @@ class SpaceForDialogModeling(TorchModel): | |||
| Config.from_file( | |||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION))) | |||
| self.config.use_gpu = True if 'device' not in kwargs or kwargs[ | |||
| 'device'] == 'gpu' else False | |||
| import torch | |||
| self.config.use_gpu = True if ( | |||
| 'device' not in kwargs or kwargs['device'] | |||
| == 'gpu') and torch.cuda.is_available() else False | |||
| self.text_field = kwargs.pop( | |||
| 'text_field', | |||
| @@ -9,7 +9,8 @@ from modelscope.utils.constant import Tasks | |||
| __all__ = ['SpaceForDialogStateTracking'] | |||
| @MODELS.register_module(Tasks.dialog_state_tracking, module_name=Models.space) | |||
| @MODELS.register_module( | |||
| Tasks.task_oriented_conversation, module_name=Models.space_dst) | |||
| class SpaceForDialogStateTracking(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -320,7 +320,7 @@ TASK_OUTPUTS = { | |||
| Tasks.fill_mask: [OutputKeys.TEXT], | |||
| # (Deprecated) dialog intent prediction result for single sample | |||
| # {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, | |||
| # {'output': {'prediction': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, | |||
| # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, | |||
| # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, | |||
| # 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05, | |||
| @@ -339,50 +339,49 @@ TASK_OUTPUTS = { | |||
| # 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05, | |||
| # 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04, | |||
| # 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05, | |||
| # 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'} | |||
| Tasks.dialog_intent_prediction: | |||
| [OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL], | |||
| # 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'}} | |||
| # (Deprecated) dialog modeling prediction result for single sample | |||
| # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] | |||
| Tasks.task_oriented_conversation: [OutputKeys.RESPONSE], | |||
| # {'output' : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!']} | |||
| # (Deprecated) dialog state tracking result for single sample | |||
| # { | |||
| # "dialog_states": { | |||
| # "taxi-leaveAt": "none", | |||
| # "taxi-destination": "none", | |||
| # "taxi-departure": "none", | |||
| # "taxi-arriveBy": "none", | |||
| # "restaurant-book_people": "none", | |||
| # "restaurant-book_day": "none", | |||
| # "restaurant-book_time": "none", | |||
| # "restaurant-food": "none", | |||
| # "restaurant-pricerange": "none", | |||
| # "restaurant-name": "none", | |||
| # "restaurant-area": "none", | |||
| # "hotel-book_people": "none", | |||
| # "hotel-book_day": "none", | |||
| # "hotel-book_stay": "none", | |||
| # "hotel-name": "none", | |||
| # "hotel-area": "none", | |||
| # "hotel-parking": "none", | |||
| # "hotel-pricerange": "cheap", | |||
| # "hotel-stars": "none", | |||
| # "hotel-internet": "none", | |||
| # "hotel-type": "true", | |||
| # "attraction-type": "none", | |||
| # "attraction-name": "none", | |||
| # "attraction-area": "none", | |||
| # "train-book_people": "none", | |||
| # "train-leaveAt": "none", | |||
| # "train-destination": "none", | |||
| # "train-day": "none", | |||
| # "train-arriveBy": "none", | |||
| # "train-departure": "none" | |||
| # "output":{ | |||
| # "dialog_states": { | |||
| # "taxi-leaveAt": "none", | |||
| # "taxi-destination": "none", | |||
| # "taxi-departure": "none", | |||
| # "taxi-arriveBy": "none", | |||
| # "restaurant-book_people": "none", | |||
| # "restaurant-book_day": "none", | |||
| # "restaurant-book_time": "none", | |||
| # "restaurant-food": "none", | |||
| # "restaurant-pricerange": "none", | |||
| # "restaurant-name": "none", | |||
| # "restaurant-area": "none", | |||
| # "hotel-book_people": "none", | |||
| # "hotel-book_day": "none", | |||
| # "hotel-book_stay": "none", | |||
| # "hotel-name": "none", | |||
| # "hotel-area": "none", | |||
| # "hotel-parking": "none", | |||
| # "hotel-pricerange": "cheap", | |||
| # "hotel-stars": "none", | |||
| # "hotel-internet": "none", | |||
| # "hotel-type": "true", | |||
| # "attraction-type": "none", | |||
| # "attraction-name": "none", | |||
| # "attraction-area": "none", | |||
| # "train-book_people": "none", | |||
| # "train-leaveAt": "none", | |||
| # "train-destination": "none", | |||
| # "train-day": "none", | |||
| # "train-arriveBy": "none", | |||
| # "train-departure": "none" | |||
| # } | |||
| # } | |||
| # } | |||
| Tasks.dialog_state_tracking: [OutputKeys.DIALOG_STATES], | |||
| Tasks.task_oriented_conversation: [OutputKeys.OUTPUT], | |||
| # ============ audio tasks =================== | |||
| # asr result for single sample | |||
| @@ -48,13 +48,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.zero_shot_classification: | |||
| (Pipelines.zero_shot_classification, | |||
| 'damo/nlp_structbert_zero-shot-classification_chinese-base'), | |||
| Tasks.dialog_intent_prediction: | |||
| (Pipelines.dialog_intent_prediction, | |||
| 'damo/nlp_space_dialog-intent-prediction'), | |||
| Tasks.task_oriented_conversation: (Pipelines.task_oriented_conversation, | |||
| Tasks.task_oriented_conversation: (Pipelines.dialog_modeling, | |||
| 'damo/nlp_space_dialog-modeling'), | |||
| Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | |||
| 'damo/nlp_space_dialog-state-tracking'), | |||
| Tasks.text_error_correction: | |||
| (Pipelines.text_error_correction, | |||
| 'damo/nlp_bart_text-error-correction_chinese'), | |||
| @@ -5,7 +5,7 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline | |||
| from .task_oriented_conversation_pipeline import TaskOrientedConversationPipeline | |||
| from .dialog_modeling_pipeline import DialogModelingPipeline | |||
| from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline | |||
| from .fill_mask_pipeline import FillMaskPipeline | |||
| from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | |||
| @@ -24,8 +24,7 @@ else: | |||
| _import_structure = { | |||
| 'dialog_intent_prediction_pipeline': | |||
| ['DialogIntentPredictionPipeline'], | |||
| 'task_oriented_conversation_pipeline': | |||
| ['TaskOrientedConversationPipeline'], | |||
| 'dialog_modeling_pipeline': ['DialogModelingPipeline'], | |||
| 'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'], | |||
| 'fill_mask_pipeline': ['FillMaskPipeline'], | |||
| 'single_sentence_classification_pipeline': | |||
| @@ -15,7 +15,7 @@ __all__ = ['DialogIntentPredictionPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.dialog_intent_prediction, | |||
| Tasks.task_oriented_conversation, | |||
| module_name=Pipelines.dialog_intent_prediction) | |||
| class DialogIntentPredictionPipeline(Pipeline): | |||
| @@ -51,10 +51,10 @@ class DialogIntentPredictionPipeline(Pipeline): | |||
| pred = inputs['pred'] | |||
| pos = np.where(pred == np.max(pred)) | |||
| result = { | |||
| OutputKeys.PREDICTION: pred, | |||
| OutputKeys.LABEL_POS: pos[0], | |||
| OutputKeys.LABEL: self.categories[pos[0][0]] | |||
| return { | |||
| OutputKeys.OUTPUT: { | |||
| OutputKeys.PREDICTION: pred, | |||
| OutputKeys.LABEL_POS: pos[0], | |||
| OutputKeys.LABEL: self.categories[pos[0][0]] | |||
| } | |||
| } | |||
| return result | |||
| @@ -11,13 +11,12 @@ from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import DialogModelingPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['TaskOrientedConversationPipeline'] | |||
| __all__ = ['DialogModelingPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.task_oriented_conversation, | |||
| module_name=Pipelines.task_oriented_conversation) | |||
| class TaskOrientedConversationPipeline(Pipeline): | |||
| Tasks.task_oriented_conversation, module_name=Pipelines.dialog_modeling) | |||
| class DialogModelingPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[SpaceForDialogModeling, str], | |||
| @@ -51,6 +50,6 @@ class TaskOrientedConversationPipeline(Pipeline): | |||
| inputs['resp']) | |||
| assert len(sys_rsp) > 2 | |||
| sys_rsp = sys_rsp[1:len(sys_rsp) - 1] | |||
| inputs[OutputKeys.RESPONSE] = sys_rsp | |||
| inputs[OutputKeys.OUTPUT] = sys_rsp | |||
| return inputs | |||
| @@ -13,7 +13,8 @@ __all__ = ['DialogStateTrackingPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) | |||
| Tasks.task_oriented_conversation, | |||
| module_name=Pipelines.dialog_state_tracking) | |||
| class DialogStateTrackingPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -63,7 +64,7 @@ class DialogStateTrackingPipeline(Pipeline): | |||
| _outputs[5], unique_ids, input_ids_unmasked, | |||
| values, inform, prefix, ds) | |||
| return {OutputKeys.DIALOG_STATES: ds} | |||
| return {OutputKeys.OUTPUT: ds} | |||
| def predict_and_format(config, tokenizer, features, per_slot_class_logits, | |||
| @@ -20,7 +20,7 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| cache_path = snapshot_download(self.model_id, revision='update') | |||
| preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | |||
| model = SpaceForDialogIntent( | |||
| model_dir=cache_path, | |||
| @@ -31,7 +31,7 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||
| DialogIntentPredictionPipeline( | |||
| model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.dialog_intent_prediction, | |||
| task=Tasks.task_oriented_conversation, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| @@ -41,7 +41,7 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| model = Model.from_pretrained(self.model_id, revision='update') | |||
| preprocessor = DialogIntentPredictionPreprocessor( | |||
| model_dir=model.model_dir) | |||
| @@ -49,7 +49,7 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||
| DialogIntentPredictionPipeline( | |||
| model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.dialog_intent_prediction, | |||
| task=Tasks.task_oriented_conversation, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| @@ -60,17 +60,14 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipelines = [ | |||
| pipeline(task=Tasks.dialog_intent_prediction, model=self.model_id) | |||
| pipeline( | |||
| task=Tasks.task_oriented_conversation, | |||
| model=self.model_id, | |||
| model_revision='update') | |||
| ] | |||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||
| print(my_pipeline(item)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipelines = [pipeline(task=Tasks.dialog_intent_prediction)] | |||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||
| print(my_pipeline(item)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -5,14 +5,15 @@ from typing import List | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SpaceForDialogModeling | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import TaskOrientedConversationPipeline | |||
| from modelscope.pipelines.nlp import DialogModelingPipeline | |||
| from modelscope.preprocessors import DialogModelingPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class TaskOrientedConversationTest(unittest.TestCase): | |||
| class DialogModelingTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_space_dialog-modeling' | |||
| test_case = { | |||
| 'sng0073': { | |||
| @@ -92,23 +93,25 @@ class TaskOrientedConversationTest(unittest.TestCase): | |||
| } | |||
| def generate_and_print_dialog_response( | |||
| self, pipelines: List[TaskOrientedConversationPipeline]): | |||
| self, pipelines: List[DialogModelingPipeline]): | |||
| result = {} | |||
| pipeline_len = len(pipelines) | |||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||
| user = item['user'] | |||
| print('user: {}'.format(user)) | |||
| result = pipelines[step % 2]({ | |||
| result = pipelines[step % pipeline_len]({ | |||
| 'user_input': user, | |||
| 'history': result | |||
| }) | |||
| print('response : {}'.format(result['response'])) | |||
| print('response : {}'.format(result[OutputKeys.OUTPUT])) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| cache_path = snapshot_download( | |||
| self.model_id, revision='task_oriented_conversation') | |||
| preprocessor = DialogModelingPreprocessor(model_dir=cache_path) | |||
| model = SpaceForDialogModeling( | |||
| @@ -116,27 +119,18 @@ class TaskOrientedConversationTest(unittest.TestCase): | |||
| text_field=preprocessor.text_field, | |||
| config=preprocessor.config) | |||
| pipelines = [ | |||
| TaskOrientedConversationPipeline( | |||
| model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.task_oriented_conversation, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| DialogModelingPipeline(model=model, preprocessor=preprocessor) | |||
| ] | |||
| self.generate_and_print_dialog_response(pipelines) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| model = Model.from_pretrained( | |||
| self.model_id, revision='task_oriented_conversation') | |||
| preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir) | |||
| pipelines = [ | |||
| TaskOrientedConversationPipeline( | |||
| model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.task_oriented_conversation, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| DialogModelingPipeline(model=model, preprocessor=preprocessor) | |||
| ] | |||
| self.generate_and_print_dialog_response(pipelines) | |||
| @@ -145,17 +139,18 @@ class TaskOrientedConversationTest(unittest.TestCase): | |||
| def test_run_with_model_name(self): | |||
| pipelines = [ | |||
| pipeline( | |||
| task=Tasks.task_oriented_conversation, model=self.model_id), | |||
| pipeline( | |||
| task=Tasks.task_oriented_conversation, model=self.model_id) | |||
| task=Tasks.task_oriented_conversation, | |||
| model=self.model_id, | |||
| model_revision='task_oriented_conversation') | |||
| ] | |||
| self.generate_and_print_dialog_response(pipelines) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipelines = [ | |||
| pipeline(task=Tasks.task_oriented_conversation), | |||
| pipeline(task=Tasks.task_oriented_conversation) | |||
| pipeline( | |||
| task=Tasks.task_oriented_conversation, | |||
| model_revision='task_oriented_conversation') | |||
| ] | |||
| self.generate_and_print_dialog_response(pipelines) | |||
| @@ -5,6 +5,7 @@ from typing import List | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SpaceForDialogStateTracking | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import DialogStateTrackingPipeline | |||
| from modelscope.preprocessors import DialogStateTrackingPreprocessor | |||
| @@ -94,11 +95,11 @@ class DialogStateTrackingTest(unittest.TestCase): | |||
| }) | |||
| print(json.dumps(result)) | |||
| history_states.extend([result['dialog_states'], {}]) | |||
| history_states.extend([result[OutputKeys.OUTPUT], {}]) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| cache_path = snapshot_download(self.model_id, revision='update') | |||
| model = SpaceForDialogStateTracking(cache_path) | |||
| preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | |||
| @@ -106,7 +107,7 @@ class DialogStateTrackingTest(unittest.TestCase): | |||
| DialogStateTrackingPipeline( | |||
| model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.dialog_state_tracking, | |||
| task=Tasks.task_oriented_conversation, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| @@ -114,14 +115,15 @@ class DialogStateTrackingTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| model = Model.from_pretrained(self.model_id, revision='update') | |||
| preprocessor = DialogStateTrackingPreprocessor( | |||
| model_dir=model.model_dir) | |||
| pipelines = [ | |||
| DialogStateTrackingPipeline( | |||
| model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.dialog_state_tracking, | |||
| task=Tasks.task_oriented_conversation, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| @@ -131,15 +133,13 @@ class DialogStateTrackingTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipelines = [ | |||
| pipeline(task=Tasks.dialog_state_tracking, model=self.model_id) | |||
| pipeline( | |||
| task=Tasks.task_oriented_conversation, | |||
| model=self.model_id, | |||
| model_revision='update') | |||
| ] | |||
| self.tracking_and_print_dialog_states(pipelines) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipelines = [pipeline(task=Tasks.dialog_state_tracking)] | |||
| self.tracking_and_print_dialog_states(pipelines) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||