| @@ -1,3 +1,4 @@ | |||
| from .sequence_classification_model import * # noqa F403 | |||
| from .space.dialog_generation_model import * # noqa F403 | |||
| from .space.dialog_intent_model import * # noqa F403 | |||
| from .text_generation_model import * # noqa F403 | |||
| @@ -0,0 +1,103 @@ | |||
| from typing import Any, Dict, Optional | |||
| from modelscope.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer | |||
| from modelscope.utils.constant import Tasks | |||
| from ...base import Model, Tensor | |||
| from ...builder import MODELS | |||
| from .model.generator import Generator | |||
| from .model.model_base import ModelBase | |||
| __all__ = ['DialogGenerationModel'] | |||
| @MODELS.register_module( | |||
| Tasks.dialog_generation, module_name=r'space-generation') | |||
| class DialogGenerationModel(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the test generation model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| model_cls (Optional[Any], optional): model loader, if None, use the | |||
| default loader to load model weights, by default None. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.model_dir = model_dir | |||
| self.text_field = kwargs.pop('text_field') | |||
| self.config = kwargs.pop('config') | |||
| self.generator = Generator.create(self.config, reader=self.text_field) | |||
| self.model = ModelBase.create( | |||
| model_dir=model_dir, | |||
| config=self.config, | |||
| reader=self.text_field, | |||
| generator=self.generator) | |||
| def to_tensor(array): | |||
| """ | |||
| numpy array -> tensor | |||
| """ | |||
| import torch | |||
| array = torch.tensor(array) | |||
| return array.cuda() if self.config.use_gpu else array | |||
| self.trainer = MultiWOZTrainer( | |||
| model=self.model, | |||
| to_tensor=to_tensor, | |||
| config=self.config, | |||
| reader=self.text_field, | |||
| evaluator=None) | |||
| self.trainer.load() | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Example: | |||
| { | |||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| } | |||
| """ | |||
| from numpy import array, float32 | |||
| import torch | |||
| # turn_1 = { | |||
| # 'user': [ | |||
| # 13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005, | |||
| # 1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7 | |||
| # ] | |||
| # } | |||
| # old_pv_turn_1 = {} | |||
| turn_2 = { | |||
| 'user': | |||
| [13, 1045, 2215, 2000, 2681, 2044, 2459, 1024, 2321, 1012, 7] | |||
| } | |||
| old_pv_turn_2 = { | |||
| 'labels': [[ | |||
| 13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005, | |||
| 1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7 | |||
| ]], | |||
| 'resp': [ | |||
| 14, 1045, 2052, 2022, 3407, 2000, 2393, 2007, 2115, 5227, 1010, | |||
| 2079, 2017, 2031, 1037, 2051, 2017, 2052, 2066, 2000, 2681, | |||
| 2030, 7180, 2011, 1029, 8 | |||
| ], | |||
| 'bspn': [ | |||
| 15, 43, 7688, 10733, 12570, 21713, 4487, 15474, 6712, 3002, | |||
| 2198, 1005, 1055, 2267, 9 | |||
| ], | |||
| 'db': [19, 24, 21, 20], | |||
| 'aspn': [16, 43, 48, 2681, 7180, 10] | |||
| } | |||
| pv_turn = self.trainer.forward(turn=turn_2, old_pv_turn=old_pv_turn_2) | |||
| return pv_turn | |||
| @@ -1,3 +1,4 @@ | |||
| from .sequence_classification_pipeline import * # noqa F403 | |||
| from .space.dialog_generation_pipeline import * # noqa F403 | |||
| from .space.dialog_intent_pipeline import * # noqa F403 | |||
| from .text_generation_pipeline import * # noqa F403 | |||
| @@ -0,0 +1,50 @@ | |||
| from typing import Any, Dict, Optional | |||
| from modelscope.models.nlp import DialogGenerationModel | |||
| from modelscope.preprocessors import DialogGenerationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from ...base import Pipeline, Tensor | |||
| from ...builder import PIPELINES | |||
| __all__ = ['DialogGenerationPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.dialog_generation, module_name=r'space-generation') | |||
| class DialogGenerationPipeline(Pipeline): | |||
| def __init__(self, model: DialogGenerationModel, | |||
| preprocessor: DialogGenerationPreprocessor, **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model (SequenceClassificationModel): a model instance | |||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
| """ | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.model = model | |||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| vocab_size = len(self.tokenizer.vocab) | |||
| pred_list = inputs['predictions'] | |||
| pred_ids = pred_list[0][0].cpu().numpy().tolist() | |||
| for j in range(len(pred_ids)): | |||
| if pred_ids[j] >= vocab_size: | |||
| pred_ids[j] = 100 | |||
| pred = self.tokenizer.convert_ids_to_tokens(pred_ids) | |||
| pred_string = ''.join(pred).replace( | |||
| '##', | |||
| '').split('[SEP]')[0].replace('[CLS]', | |||
| '').replace('[SEP]', | |||
| '').replace('[UNK]', '') | |||
| return {'pred_string': pred_string} | |||
| @@ -6,4 +6,5 @@ from .common import Compose | |||
| from .image import LoadImage, load_image | |||
| from .nlp import * # noqa F403 | |||
| from .nlp import TextGenerationPreprocessor | |||
| from .space.dialog_generation_preprocessor import * # noqa F403 | |||
| from .space.dialog_intent_preprocessor import * # noqa F403 | |||
| @@ -0,0 +1,50 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| from modelscope.preprocessors.space.fields.gen_field import \ | |||
| MultiWOZBPETextField | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import Fields, InputFields | |||
| from modelscope.utils.type_assert import type_assert | |||
| from ..base import Preprocessor | |||
| from ..builder import PREPROCESSORS | |||
| __all__ = ['DialogGenerationPreprocessor'] | |||
| @PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-generation') | |||
| class DialogGenerationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.model_dir: str = model_dir | |||
| self.config = Config.from_file( | |||
| os.path.join(self.model_dir, 'configuration.json')) | |||
| self.text_field = MultiWOZBPETextField( | |||
| self.model_dir, config=self.config) | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| idx = self.text_field.get_ids(data) | |||
| return {'user_idx': idx} | |||
| @@ -668,6 +668,11 @@ class MultiWOZTrainer(Trainer): | |||
| return | |||
| def _get_turn_doamin(self, constraint_ids, bspn_gen_ids): | |||
| # constraint_token = self.tokenizer.convert_ids_to_tokens(constraint_ids) | |||
| # bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen_ids) | |||
| return [] | |||
| def forward(self, turn, old_pv_turn): | |||
| with torch.no_grad(): | |||
| first_turn = True if len(old_pv_turn) == 0 else False | |||
| @@ -678,7 +683,6 @@ class MultiWOZTrainer(Trainer): | |||
| batch = type(batch)( | |||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) | |||
| pv_turn = {} | |||
| print(batch) | |||
| outputs = self.func_model.infer( | |||
| inputs=batch, | |||
| @@ -687,29 +691,24 @@ class MultiWOZTrainer(Trainer): | |||
| max_gen_len=60) | |||
| generated_bs = outputs[0].cpu().numpy().tolist() | |||
| bspn_gen = self.decode_generated_bspn(generated_bs) | |||
| bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen) | |||
| print(bspn_gen) | |||
| print(bspn_token) | |||
| turn_domain = [] | |||
| for item in bspn_token: | |||
| if item.startswith('[') and item.endswith(']'): | |||
| turn_domain.append(item) | |||
| turn_domain = self._get_turn_doamin(old_pv_turn['constraint_ids'], | |||
| bspn_gen) | |||
| print(turn_domain) | |||
| db_result = self.reader.bspan_to_DBpointer( | |||
| self.tokenizer.decode(bspn_gen), ['[taxi]']) | |||
| self.tokenizer.decode(bspn_gen), turn_domain) | |||
| print(db_result) | |||
| book_result = 21 | |||
| assert len(turn['db']) == 3 | |||
| assert isinstance(db_result, str) | |||
| db = \ | |||
| [self.reader.sos_db_id] + \ | |||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||
| [book_result] + \ | |||
| [self.reader.eos_db_id] | |||
| prompt_id = self.reader.sos_a_id | |||
| prev_input = torch.tensor(bspn_gen + db) | |||
| if self.func_model.use_gpu: | |||
| prev_input = prev_input.cuda() | |||
| outputs_db = self.func_model.infer( | |||
| inputs=batch, | |||
| start_id=prompt_id, | |||
| @@ -727,5 +726,6 @@ class MultiWOZTrainer(Trainer): | |||
| pv_turn['bspn'] = decoded['bspn'] | |||
| pv_turn['db'] = None | |||
| pv_turn['aspn'] = None | |||
| pv_turn['constraint_ids'] = bspn_gen | |||
| return pv_turn | |||
| @@ -0,0 +1,121 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import tempfile | |||
| import unittest | |||
| from modelscope.models.nlp import DialogGenerationModel | |||
| from modelscope.pipelines import DialogGenerationPipeline, pipeline | |||
| from modelscope.preprocessors import DialogGenerationPreprocessor | |||
| def merge(info, result): | |||
| return info | |||
| class DialogGenerationTest(unittest.TestCase): | |||
| test_case = { | |||
| 'sng0073': { | |||
| 'goal': { | |||
| 'taxi': { | |||
| 'info': { | |||
| 'leaveat': '17:15', | |||
| 'destination': 'pizza hut fen ditton', | |||
| 'departure': "saint john's college" | |||
| }, | |||
| 'reqt': ['car', 'phone'], | |||
| 'fail_info': {} | |||
| } | |||
| }, | |||
| 'log': [{ | |||
| 'user': | |||
| "i would like a taxi from saint john 's college to pizza hut fen ditton .", | |||
| 'user_delex': | |||
| 'i would like a taxi from [value_departure] to [value_destination] .', | |||
| 'resp': | |||
| 'what time do you want to leave and what time do you want to arrive by ?', | |||
| 'sys': | |||
| 'what time do you want to leave and what time do you want to arrive by ?', | |||
| 'pointer': '0,0,0,0,0,0', | |||
| 'match': '', | |||
| 'constraint': | |||
| "[taxi] destination pizza hut fen ditton departure saint john 's college", | |||
| 'cons_delex': '[taxi] destination departure', | |||
| 'sys_act': '[taxi] [request] leave arrive', | |||
| 'turn_num': 0, | |||
| 'turn_domain': '[taxi]' | |||
| }, { | |||
| 'user': 'i want to leave after 17:15 .', | |||
| 'user_delex': 'i want to leave after [value_leave] .', | |||
| 'resp': | |||
| 'booking completed ! your taxi will be [value_car] contact number is [value_phone]', | |||
| 'sys': | |||
| 'booking completed ! your taxi will be blue honda contact number is 07218068540', | |||
| 'pointer': '0,0,0,0,0,0', | |||
| 'match': '', | |||
| 'constraint': | |||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||
| 'cons_delex': '[taxi] destination departure leave', | |||
| 'sys_act': '[taxi] [inform] car phone', | |||
| 'turn_num': 1, | |||
| 'turn_domain': '[taxi]' | |||
| }, { | |||
| 'user': 'thank you for all the help ! i appreciate it .', | |||
| 'user_delex': 'thank you for all the help ! i appreciate it .', | |||
| 'resp': | |||
| 'you are welcome . is there anything else i can help you with today ?', | |||
| 'sys': | |||
| 'you are welcome . is there anything else i can help you with today ?', | |||
| 'pointer': '0,0,0,0,0,0', | |||
| 'match': '', | |||
| 'constraint': | |||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||
| 'cons_delex': '[taxi] destination departure leave', | |||
| 'sys_act': '[general] [reqmore]', | |||
| 'turn_num': 2, | |||
| 'turn_domain': '[general]' | |||
| }, { | |||
| 'user': 'no , i am all set . have a nice day . bye .', | |||
| 'user_delex': 'no , i am all set . have a nice day . bye .', | |||
| 'resp': 'you too ! thank you', | |||
| 'sys': 'you too ! thank you', | |||
| 'pointer': '0,0,0,0,0,0', | |||
| 'match': '', | |||
| 'constraint': | |||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||
| 'cons_delex': '[taxi] destination departure leave', | |||
| 'sys_act': '[general] [bye]', | |||
| 'turn_num': 3, | |||
| 'turn_domain': '[general]' | |||
| }] | |||
| } | |||
| } | |||
| def test_run(self): | |||
| modeldir = '/Users/yangliu/Desktop/space-dialog-generation' | |||
| preprocessor = DialogGenerationPreprocessor(model_dir=modeldir) | |||
| model = DialogGenerationModel( | |||
| model_dir=modeldir, | |||
| text_field=preprocessor.text_field, | |||
| config=preprocessor.config) | |||
| print(model.forward(None)) | |||
| # pipeline = DialogGenerationPipeline( | |||
| # model=model, preprocessor=preprocessor) | |||
| # history_dialog_info = {} | |||
| # for step, item in enumerate(test_case['sng0073']['log']): | |||
| # user_question = item['user'] | |||
| # print('user: {}'.format(user_question)) | |||
| # | |||
| # # history_dialog_info = merge(history_dialog_info, | |||
| # # result) if step > 0 else {} | |||
| # result = pipeline(user_question, history=history_dialog_info) | |||
| # # | |||
| # # print('sys : {}'.format(result['pred_answer'])) | |||
| print('test') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||