| @@ -1,6 +1,10 @@ | |||||
| from typing import Any, Dict, Optional | |||||
| import os | |||||
| from typing import Any, Dict | |||||
| from modelscope.preprocessors.space.fields.intent_field import \ | |||||
| IntentBPETextField | |||||
| from modelscope.trainers.nlp.space.trainers.intent_trainer import IntentTrainer | from modelscope.trainers.nlp.space.trainers.intent_trainer import IntentTrainer | ||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from ...base import Model, Tensor | from ...base import Model, Tensor | ||||
| from ...builder import MODELS | from ...builder import MODELS | ||||
| @@ -10,7 +14,7 @@ from .model.model_base import ModelBase | |||||
| __all__ = ['DialogIntentModel'] | __all__ = ['DialogIntentModel'] | ||||
| @MODELS.register_module(Tasks.dialog_intent, module_name=r'space-intent') | |||||
| @MODELS.register_module(Tasks.dialog_intent, module_name=r'space') | |||||
| class DialogIntentModel(Model): | class DialogIntentModel(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -24,8 +28,14 @@ class DialogIntentModel(Model): | |||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| self.model_dir = model_dir | self.model_dir = model_dir | ||||
| self.text_field = kwargs.pop('text_field') | |||||
| self.config = kwargs.pop('config') | |||||
| self.config = kwargs.pop( | |||||
| 'config', | |||||
| Config.from_file( | |||||
| os.path.join(self.model_dir, 'configuration.json'))) | |||||
| self.text_field = kwargs.pop( | |||||
| 'text_field', | |||||
| IntentBPETextField(self.model_dir, config=self.config)) | |||||
| self.generator = Generator.create(self.config, reader=self.text_field) | self.generator = Generator.create(self.config, reader=self.text_field) | ||||
| self.model = ModelBase.create( | self.model = ModelBase.create( | ||||
| model_dir=model_dir, | model_dir=model_dir, | ||||
| @@ -63,9 +73,8 @@ class DialogIntentModel(Model): | |||||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | ||||
| } | } | ||||
| """ | """ | ||||
| from numpy import array, float32 | |||||
| import torch | |||||
| print('--forward--') | |||||
| result = self.trainer.forward(input) | |||||
| import numpy as np | |||||
| pred = self.trainer.forward(input) | |||||
| pred = np.squeeze(pred[0], 0) | |||||
| return result | |||||
| return {'pred': pred} | |||||
| @@ -9,7 +9,7 @@ from ...builder import PIPELINES | |||||
| __all__ = ['DialogIntentPipeline'] | __all__ = ['DialogIntentPipeline'] | ||||
| @PIPELINES.register_module(Tasks.dialog_intent, module_name=r'space-intent') | |||||
| @PIPELINES.register_module(Tasks.dialog_intent, module_name=r'space') | |||||
| class DialogIntentPipeline(Pipeline): | class DialogIntentPipeline(Pipeline): | ||||
| def __init__(self, model: DialogIntentModel, | def __init__(self, model: DialogIntentModel, | ||||
| @@ -34,5 +34,10 @@ class DialogIntentPipeline(Pipeline): | |||||
| Returns: | Returns: | ||||
| Dict[str, str]: the prediction results | Dict[str, str]: the prediction results | ||||
| """ | """ | ||||
| import numpy as np | |||||
| pred = inputs['pred'] | |||||
| pos = np.where(pred == np.max(pred)) | |||||
| return inputs | |||||
| result = {'pred': pred, 'label': pos[0]} | |||||
| return result | |||||
| @@ -1,13 +1,12 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | import os | ||||
| import uuid | |||||
| from typing import Any, Dict, Union | |||||
| from typing import Any, Dict | |||||
| from modelscope.preprocessors.space.fields.intent_field import \ | from modelscope.preprocessors.space.fields.intent_field import \ | ||||
| IntentBPETextField | IntentBPETextField | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import Fields, InputFields | |||||
| from modelscope.utils.constant import Fields | |||||
| from modelscope.utils.type_assert import type_assert | from modelscope.utils.type_assert import type_assert | ||||
| from ..base import Preprocessor | from ..base import Preprocessor | ||||
| from ..builder import PREPROCESSORS | from ..builder import PREPROCESSORS | ||||
| @@ -15,7 +14,7 @@ from ..builder import PREPROCESSORS | |||||
| __all__ = ['DialogIntentPreprocessor'] | __all__ = ['DialogIntentPreprocessor'] | ||||
| @PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-intent') | |||||
| @PREPROCESSORS.register_module(Fields.nlp, module_name=r'space') | |||||
| class DialogIntentPreprocessor(Preprocessor): | class DialogIntentPreprocessor(Preprocessor): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -508,7 +508,6 @@ class IntentTrainer(Trainer): | |||||
| report_for_unlabeled_data, cur_valid_metric=-accuracy) | report_for_unlabeled_data, cur_valid_metric=-accuracy) | ||||
| def forward(self, batch): | def forward(self, batch): | ||||
| outputs, labels = [], [] | |||||
| pred, true = [], [] | pred, true = [], [] | ||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| @@ -522,12 +521,10 @@ class IntentTrainer(Trainer): | |||||
| intent_probs = result['intent_probs'] | intent_probs = result['intent_probs'] | ||||
| if self.can_norm: | if self.can_norm: | ||||
| pred += [intent_probs] | pred += [intent_probs] | ||||
| true += batch['intent_label'].cpu().detach().tolist() | |||||
| else: | else: | ||||
| pred += np.argmax(intent_probs, axis=1).tolist() | pred += np.argmax(intent_probs, axis=1).tolist() | ||||
| true += batch['intent_label'].cpu().detach().tolist() | |||||
| return {'pred': pred} | |||||
| return pred | |||||
| def infer(self, data_iter, num_batches=None, ex_data_iter=None): | def infer(self, data_iter, num_batches=None, ex_data_iter=None): | ||||
| """ | """ | ||||
| @@ -1,76 +0,0 @@ | |||||
| test_case = { | |||||
| 'sng0073': { | |||||
| 'goal': { | |||||
| 'taxi': { | |||||
| 'info': { | |||||
| 'leaveat': '17:15', | |||||
| 'destination': 'pizza hut fen ditton', | |||||
| 'departure': "saint john's college" | |||||
| }, | |||||
| 'reqt': ['car', 'phone'], | |||||
| 'fail_info': {} | |||||
| } | |||||
| }, | |||||
| 'log': [{ | |||||
| 'user': | |||||
| "i would like a taxi from saint john 's college to pizza hut fen ditton .", | |||||
| 'user_delex': | |||||
| 'i would like a taxi from [value_departure] to [value_destination] .', | |||||
| 'resp': | |||||
| 'what time do you want to leave and what time do you want to arrive by ?', | |||||
| 'sys': | |||||
| 'what time do you want to leave and what time do you want to arrive by ?', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college", | |||||
| 'cons_delex': '[taxi] destination departure', | |||||
| 'sys_act': '[taxi] [request] leave arrive', | |||||
| 'turn_num': 0, | |||||
| 'turn_domain': '[taxi]' | |||||
| }, { | |||||
| 'user': 'i want to leave after 17:15 .', | |||||
| 'user_delex': 'i want to leave after [value_leave] .', | |||||
| 'resp': | |||||
| 'booking completed ! your taxi will be [value_car] contact number is [value_phone]', | |||||
| 'sys': | |||||
| 'booking completed ! your taxi will be blue honda contact number is 07218068540', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||||
| 'cons_delex': '[taxi] destination departure leave', | |||||
| 'sys_act': '[taxi] [inform] car phone', | |||||
| 'turn_num': 1, | |||||
| 'turn_domain': '[taxi]' | |||||
| }, { | |||||
| 'user': 'thank you for all the help ! i appreciate it .', | |||||
| 'user_delex': 'thank you for all the help ! i appreciate it .', | |||||
| 'resp': | |||||
| 'you are welcome . is there anything else i can help you with today ?', | |||||
| 'sys': | |||||
| 'you are welcome . is there anything else i can help you with today ?', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||||
| 'cons_delex': '[taxi] destination departure leave', | |||||
| 'sys_act': '[general] [reqmore]', | |||||
| 'turn_num': 2, | |||||
| 'turn_domain': '[general]' | |||||
| }, { | |||||
| 'user': 'no , i am all set . have a nice day . bye .', | |||||
| 'user_delex': 'no , i am all set . have a nice day . bye .', | |||||
| 'resp': 'you too ! thank you', | |||||
| 'sys': 'you too ! thank you', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||||
| 'cons_delex': '[taxi] destination departure leave', | |||||
| 'sys_act': '[general] [bye]', | |||||
| 'turn_num': 3, | |||||
| 'turn_domain': '[general]' | |||||
| }] | |||||
| } | |||||
| } | |||||
| @@ -1,4 +0,0 @@ | |||||
| test_case = [ | |||||
| 'How do I locate my card?', | |||||
| 'I still have not received my new card, I ordered over a week ago.' | |||||
| ] | |||||
| @@ -4,8 +4,6 @@ import os.path as osp | |||||
| import tempfile | import tempfile | ||||
| import unittest | import unittest | ||||
| from tests.case.nlp.dialog_generation_case import test_case | |||||
| from modelscope.models.nlp import DialogGenerationModel | from modelscope.models.nlp import DialogGenerationModel | ||||
| from modelscope.pipelines import DialogGenerationPipeline, pipeline | from modelscope.pipelines import DialogGenerationPipeline, pipeline | ||||
| from modelscope.preprocessors import DialogGenerationPreprocessor | from modelscope.preprocessors import DialogGenerationPreprocessor | ||||
| @@ -16,6 +14,82 @@ def merge(info, result): | |||||
| class DialogGenerationTest(unittest.TestCase): | class DialogGenerationTest(unittest.TestCase): | ||||
| test_case = { | |||||
| 'sng0073': { | |||||
| 'goal': { | |||||
| 'taxi': { | |||||
| 'info': { | |||||
| 'leaveat': '17:15', | |||||
| 'destination': 'pizza hut fen ditton', | |||||
| 'departure': "saint john's college" | |||||
| }, | |||||
| 'reqt': ['car', 'phone'], | |||||
| 'fail_info': {} | |||||
| } | |||||
| }, | |||||
| 'log': [{ | |||||
| 'user': | |||||
| "i would like a taxi from saint john 's college to pizza hut fen ditton .", | |||||
| 'user_delex': | |||||
| 'i would like a taxi from [value_departure] to [value_destination] .', | |||||
| 'resp': | |||||
| 'what time do you want to leave and what time do you want to arrive by ?', | |||||
| 'sys': | |||||
| 'what time do you want to leave and what time do you want to arrive by ?', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college", | |||||
| 'cons_delex': '[taxi] destination departure', | |||||
| 'sys_act': '[taxi] [request] leave arrive', | |||||
| 'turn_num': 0, | |||||
| 'turn_domain': '[taxi]' | |||||
| }, { | |||||
| 'user': 'i want to leave after 17:15 .', | |||||
| 'user_delex': 'i want to leave after [value_leave] .', | |||||
| 'resp': | |||||
| 'booking completed ! your taxi will be [value_car] contact number is [value_phone]', | |||||
| 'sys': | |||||
| 'booking completed ! your taxi will be blue honda contact number is 07218068540', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||||
| 'cons_delex': '[taxi] destination departure leave', | |||||
| 'sys_act': '[taxi] [inform] car phone', | |||||
| 'turn_num': 1, | |||||
| 'turn_domain': '[taxi]' | |||||
| }, { | |||||
| 'user': 'thank you for all the help ! i appreciate it .', | |||||
| 'user_delex': 'thank you for all the help ! i appreciate it .', | |||||
| 'resp': | |||||
| 'you are welcome . is there anything else i can help you with today ?', | |||||
| 'sys': | |||||
| 'you are welcome . is there anything else i can help you with today ?', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||||
| 'cons_delex': '[taxi] destination departure leave', | |||||
| 'sys_act': '[general] [reqmore]', | |||||
| 'turn_num': 2, | |||||
| 'turn_domain': '[general]' | |||||
| }, { | |||||
| 'user': 'no , i am all set . have a nice day . bye .', | |||||
| 'user_delex': 'no , i am all set . have a nice day . bye .', | |||||
| 'resp': 'you too ! thank you', | |||||
| 'sys': 'you too ! thank you', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||||
| 'cons_delex': '[taxi] destination departure leave', | |||||
| 'sys_act': '[general] [bye]', | |||||
| 'turn_num': 3, | |||||
| 'turn_domain': '[general]' | |||||
| }] | |||||
| } | |||||
| } | |||||
| def test_run(self): | def test_run(self): | ||||
| @@ -1,11 +1,9 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | |||||
| import os.path as osp | |||||
| import tempfile | |||||
| import unittest | import unittest | ||||
| from tests.case.nlp.dialog_intent_case import test_case | |||||
| from maas_hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import DialogIntentModel | from modelscope.models.nlp import DialogIntentModel | ||||
| from modelscope.pipelines import DialogIntentPipeline, pipeline | from modelscope.pipelines import DialogIntentPipeline, pipeline | ||||
| from modelscope.preprocessors import DialogIntentPreprocessor | from modelscope.preprocessors import DialogIntentPreprocessor | ||||
| @@ -13,22 +11,46 @@ from modelscope.utils.constant import Tasks | |||||
| class DialogGenerationTest(unittest.TestCase): | class DialogGenerationTest(unittest.TestCase): | ||||
| model_id = 'damo/nlp_space_dialog-intent' | |||||
| test_case = [ | |||||
| 'How do I locate my card?', | |||||
| 'I still have not received my new card, I ordered over a week ago.' | |||||
| ] | |||||
| @unittest.skip('test with snapshot_download') | |||||
| def test_run(self): | def test_run(self): | ||||
| modeldir = '/Users/yangliu/Desktop/space-dialog-intent' | |||||
| preprocessor = DialogIntentPreprocessor(model_dir=modeldir) | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| preprocessor = DialogIntentPreprocessor(model_dir=cache_path) | |||||
| model = DialogIntentModel( | model = DialogIntentModel( | ||||
| model_dir=modeldir, | |||||
| model_dir=cache_path, | |||||
| text_field=preprocessor.text_field, | text_field=preprocessor.text_field, | ||||
| config=preprocessor.config) | config=preprocessor.config) | ||||
| pipeline1 = DialogIntentPipeline( | |||||
| model=model, preprocessor=preprocessor) | |||||
| # pipeline1 = pipeline(task=Tasks.dialog_intent, model=model, preprocessor=preprocessor) | |||||
| for item in test_case: | |||||
| print(pipeline1(item)) | |||||
| pipelines = [ | |||||
| DialogIntentPipeline(model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_intent, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||||
| print(my_pipeline(item)) | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| preprocessor = DialogIntentPreprocessor(model_dir=model.model_dir) | |||||
| pipelines = [ | |||||
| DialogIntentPipeline(model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_intent, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||||
| print(my_pipeline(item)) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||