| @@ -1,6 +1,10 @@ | |||||
| import os | |||||
| from typing import Any, Dict, Optional | from typing import Any, Dict, Optional | ||||
| from modelscope.preprocessors.space.fields.gen_field import \ | |||||
| MultiWOZBPETextField | |||||
| from modelscope.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer | from modelscope.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer | ||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from ...base import Model, Tensor | from ...base import Model, Tensor | ||||
| from ...builder import MODELS | from ...builder import MODELS | ||||
| @@ -25,8 +29,13 @@ class DialogGenerationModel(Model): | |||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| self.model_dir = model_dir | self.model_dir = model_dir | ||||
| self.text_field = kwargs.pop('text_field') | |||||
| self.config = kwargs.pop('config') | |||||
| self.config = kwargs.pop( | |||||
| 'config', | |||||
| Config.from_file( | |||||
| os.path.join(self.model_dir, 'configuration.json'))) | |||||
| self.text_field = kwargs.pop( | |||||
| 'text_field', | |||||
| MultiWOZBPETextField(self.model_dir, config=self.config)) | |||||
| self.generator = Generator.create(self.config, reader=self.text_field) | self.generator = Generator.create(self.config, reader=self.text_field) | ||||
| self.model = ModelBase.create( | self.model = ModelBase.create( | ||||
| model_dir=model_dir, | model_dir=model_dir, | ||||
| @@ -65,39 +74,10 @@ class DialogGenerationModel(Model): | |||||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | ||||
| } | } | ||||
| """ | """ | ||||
| from numpy import array, float32 | |||||
| import torch | |||||
| # turn_1 = { | |||||
| # 'user': [ | |||||
| # 13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005, | |||||
| # 1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7 | |||||
| # ] | |||||
| # } | |||||
| # old_pv_turn_1 = {} | |||||
| turn = {'user': input['user']} | |||||
| old_pv_turn = input['history'] | |||||
| turn_2 = { | |||||
| 'user': | |||||
| [13, 1045, 2215, 2000, 2681, 2044, 2459, 1024, 2321, 1012, 7] | |||||
| } | |||||
| old_pv_turn_2 = { | |||||
| 'labels': [[ | |||||
| 13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005, | |||||
| 1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7 | |||||
| ]], | |||||
| 'resp': [ | |||||
| 14, 1045, 2052, 2022, 3407, 2000, 2393, 2007, 2115, 5227, 1010, | |||||
| 2079, 2017, 2031, 1037, 2051, 2017, 2052, 2066, 2000, 2681, | |||||
| 2030, 7180, 2011, 1029, 8 | |||||
| ], | |||||
| 'bspn': [ | |||||
| 15, 43, 7688, 10733, 12570, 21713, 4487, 15474, 6712, 3002, | |||||
| 2198, 1005, 1055, 2267, 9 | |||||
| ], | |||||
| 'db': [19, 24, 21, 20], | |||||
| 'aspn': [16, 43, 48, 2681, 7180, 10] | |||||
| } | |||||
| pv_turn = self.trainer.forward(turn=turn_2, old_pv_turn=old_pv_turn_2) | |||||
| pv_turn = self.trainer.forward(turn=turn, old_pv_turn=old_pv_turn) | |||||
| return pv_turn | return pv_turn | ||||
| @@ -15,7 +15,7 @@ from modelscope.utils.logger import get_logger | |||||
| from .util import is_model_name | from .util import is_model_name | ||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | Tensor = Union['torch.Tensor', 'tf.Tensor'] | ||||
| Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | |||||
| Input = Union[str, PyDataset, Dict, 'PIL.Image.Image', 'numpy.ndarray'] | |||||
| InputModel = Union[str, Model] | InputModel = Union[str, Model] | ||||
| output_keys = [ | output_keys = [ | ||||
| @@ -24,6 +24,7 @@ class DialogGenerationPipeline(Pipeline): | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| self.model = model | self.model = model | ||||
| self.preprocessor = preprocessor | |||||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | ||||
| """process the prediction results | """process the prediction results | ||||
| @@ -34,17 +35,12 @@ class DialogGenerationPipeline(Pipeline): | |||||
| Returns: | Returns: | ||||
| Dict[str, str]: the prediction results | Dict[str, str]: the prediction results | ||||
| """ | """ | ||||
| sys_rsp = self.preprocessor.text_field.tokenizer.convert_ids_to_tokens( | |||||
| inputs['resp']) | |||||
| assert len(sys_rsp) > 2 | |||||
| sys_rsp = sys_rsp[1:len(sys_rsp) - 1] | |||||
| # sys_rsp = self.preprocessor.text_field.tokenizer. | |||||
| vocab_size = len(self.tokenizer.vocab) | |||||
| pred_list = inputs['predictions'] | |||||
| pred_ids = pred_list[0][0].cpu().numpy().tolist() | |||||
| for j in range(len(pred_ids)): | |||||
| if pred_ids[j] >= vocab_size: | |||||
| pred_ids[j] = 100 | |||||
| pred = self.tokenizer.convert_ids_to_tokens(pred_ids) | |||||
| pred_string = ''.join(pred).replace( | |||||
| '##', | |||||
| '').split('[SEP]')[0].replace('[CLS]', | |||||
| '').replace('[SEP]', | |||||
| '').replace('[UNK]', '') | |||||
| return {'pred_string': pred_string} | |||||
| inputs['sys'] = sys_rsp | |||||
| return inputs | |||||
| @@ -32,8 +32,8 @@ class DialogGenerationPreprocessor(Preprocessor): | |||||
| self.text_field = MultiWOZBPETextField( | self.text_field = MultiWOZBPETextField( | ||||
| self.model_dir, config=self.config) | self.model_dir, config=self.config) | ||||
| @type_assert(object, str) | |||||
| def __call__(self, data: str) -> Dict[str, Any]: | |||||
| @type_assert(object, Dict) | |||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """process the raw input data | """process the raw input data | ||||
| Args: | Args: | ||||
| @@ -45,6 +45,7 @@ class DialogGenerationPreprocessor(Preprocessor): | |||||
| Dict[str, Any]: the preprocessed data | Dict[str, Any]: the preprocessed data | ||||
| """ | """ | ||||
| idx = self.text_field.get_ids(data) | |||||
| user_ids = self.text_field.get_ids(data['user_input']) | |||||
| data['user'] = user_ids | |||||
| return {'user_idx': idx} | |||||
| return data | |||||
| @@ -13,6 +13,7 @@ import torch | |||||
| from tqdm import tqdm | from tqdm import tqdm | ||||
| from transformers.optimization import AdamW, get_linear_schedule_with_warmup | from transformers.optimization import AdamW, get_linear_schedule_with_warmup | ||||
| import modelscope.utils.nlp.space.ontology as ontology | |||||
| from ..metrics.metrics_tracker import MetricsTracker | from ..metrics.metrics_tracker import MetricsTracker | ||||
| @@ -668,10 +669,45 @@ class MultiWOZTrainer(Trainer): | |||||
| return | return | ||||
| def _get_turn_doamin(self, constraint_ids, bspn_gen_ids): | |||||
| # constraint_token = self.tokenizer.convert_ids_to_tokens(constraint_ids) | |||||
| # bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen_ids) | |||||
| return [] | |||||
| def _get_turn_domain(self, old_pv_turn, bspn_gen_ids, first_turn): | |||||
| def _get_slots(constraint): | |||||
| domain_name = '' | |||||
| slots = {} | |||||
| for item in constraint: | |||||
| if item in ontology.placeholder_tokens: | |||||
| continue | |||||
| if item in ontology.all_domains_with_bracket: | |||||
| domain_name = item | |||||
| slots[domain_name] = set() | |||||
| else: | |||||
| assert domain_name in ontology.all_domains_with_bracket | |||||
| slots[domain_name].add(item) | |||||
| return slots | |||||
| turn_domain = [] | |||||
| if first_turn and len(bspn_gen_ids) == 0: | |||||
| turn_domain = ['[general]'] | |||||
| return turn_domain | |||||
| bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen_ids) | |||||
| turn_slots = _get_slots(bspn_token) | |||||
| if first_turn: | |||||
| return list(turn_slots.keys()) | |||||
| assert 'bspn' in old_pv_turn | |||||
| pv_bspn_token = self.tokenizer.convert_ids_to_tokens( | |||||
| old_pv_turn['bspn']) | |||||
| pv_turn_slots = _get_slots(pv_bspn_token) | |||||
| for domain, value in turn_slots.items(): | |||||
| pv_value = pv_turn_slots[ | |||||
| domain] if domain in pv_turn_slots else set() | |||||
| if len(value - pv_value) > 0 or len(pv_value - value): | |||||
| turn_domain.append(domain) | |||||
| if len(turn_domain) == 0: | |||||
| turn_domain = list(turn_slots.keys()) | |||||
| return turn_domain | |||||
| def forward(self, turn, old_pv_turn): | def forward(self, turn, old_pv_turn): | ||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| @@ -692,14 +728,11 @@ class MultiWOZTrainer(Trainer): | |||||
| generated_bs = outputs[0].cpu().numpy().tolist() | generated_bs = outputs[0].cpu().numpy().tolist() | ||||
| bspn_gen = self.decode_generated_bspn(generated_bs) | bspn_gen = self.decode_generated_bspn(generated_bs) | ||||
| turn_domain = self._get_turn_doamin(old_pv_turn['constraint_ids'], | |||||
| bspn_gen) | |||||
| print(turn_domain) | |||||
| turn_domain = self._get_turn_domain(old_pv_turn, bspn_gen, | |||||
| first_turn) | |||||
| db_result = self.reader.bspan_to_DBpointer( | db_result = self.reader.bspan_to_DBpointer( | ||||
| self.tokenizer.decode(bspn_gen), turn_domain) | self.tokenizer.decode(bspn_gen), turn_domain) | ||||
| print(db_result) | |||||
| assert len(turn['db']) == 3 | |||||
| assert isinstance(db_result, str) | assert isinstance(db_result, str) | ||||
| db = \ | db = \ | ||||
| [self.reader.sos_db_id] + \ | [self.reader.sos_db_id] + \ | ||||
| @@ -718,14 +751,11 @@ class MultiWOZTrainer(Trainer): | |||||
| generated_ar = outputs_db[0].cpu().numpy().tolist() | generated_ar = outputs_db[0].cpu().numpy().tolist() | ||||
| decoded = self.decode_generated_act_resp(generated_ar) | decoded = self.decode_generated_act_resp(generated_ar) | ||||
| decoded['bspn'] = bspn_gen | decoded['bspn'] = bspn_gen | ||||
| print(decoded) | |||||
| print(self.tokenizer.convert_ids_to_tokens(decoded['resp'])) | |||||
| pv_turn['labels'] = None | |||||
| pv_turn['labels'] = inputs['labels'] | |||||
| pv_turn['resp'] = decoded['resp'] | pv_turn['resp'] = decoded['resp'] | ||||
| pv_turn['bspn'] = decoded['bspn'] | pv_turn['bspn'] = decoded['bspn'] | ||||
| pv_turn['db'] = None | |||||
| pv_turn['aspn'] = None | |||||
| pv_turn['constraint_ids'] = bspn_gen | |||||
| pv_turn['db'] = db | |||||
| pv_turn['aspn'] = decoded['aspn'] | |||||
| return pv_turn | return pv_turn | ||||
| @@ -1,7 +1,13 @@ | |||||
| all_domains = [ | all_domains = [ | ||||
| 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital' | 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital' | ||||
| ] | ] | ||||
| all_domains_with_bracket = ['[{}]'.format(item) for item in all_domains] | |||||
| db_domains = ['restaurant', 'hotel', 'attraction', 'train'] | db_domains = ['restaurant', 'hotel', 'attraction', 'train'] | ||||
| placeholder_tokens = [ | |||||
| '<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>', '<eos_b>', | |||||
| '<eos_a>', '<eos_d>', '<eos_q>', '<sos_u>', '<sos_r>', '<sos_b>', | |||||
| '<sos_a>', '<sos_d>', '<sos_q>' | |||||
| ] | |||||
| normlize_slot_names = { | normlize_slot_names = { | ||||
| 'car type': 'car', | 'car type': 'car', | ||||
| @@ -4,16 +4,17 @@ import os.path as osp | |||||
| import tempfile | import tempfile | ||||
| import unittest | import unittest | ||||
| from maas_hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import DialogGenerationModel | from modelscope.models.nlp import DialogGenerationModel | ||||
| from modelscope.pipelines import DialogGenerationPipeline, pipeline | from modelscope.pipelines import DialogGenerationPipeline, pipeline | ||||
| from modelscope.preprocessors import DialogGenerationPreprocessor | from modelscope.preprocessors import DialogGenerationPreprocessor | ||||
| def merge(info, result): | |||||
| return info | |||||
| from modelscope.utils.constant import Tasks | |||||
| class DialogGenerationTest(unittest.TestCase): | class DialogGenerationTest(unittest.TestCase): | ||||
| model_id = 'damo/nlp_space_dialog-generation' | |||||
| test_case = { | test_case = { | ||||
| 'sng0073': { | 'sng0073': { | ||||
| 'goal': { | 'goal': { | ||||
| @@ -91,30 +92,58 @@ class DialogGenerationTest(unittest.TestCase): | |||||
| } | } | ||||
| } | } | ||||
| @unittest.skip('test with snapshot_download') | |||||
| def test_run(self): | def test_run(self): | ||||
| modeldir = '/Users/yangliu/Desktop/space-dialog-generation' | |||||
| cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-generation' | |||||
| # cache_path = snapshot_download(self.model_id) | |||||
| preprocessor = DialogGenerationPreprocessor(model_dir=modeldir) | |||||
| preprocessor = DialogGenerationPreprocessor(model_dir=cache_path) | |||||
| model = DialogGenerationModel( | model = DialogGenerationModel( | ||||
| model_dir=modeldir, | |||||
| model_dir=cache_path, | |||||
| text_field=preprocessor.text_field, | text_field=preprocessor.text_field, | ||||
| config=preprocessor.config) | config=preprocessor.config) | ||||
| print(model.forward(None)) | |||||
| # pipeline = DialogGenerationPipeline( | |||||
| # model=model, preprocessor=preprocessor) | |||||
| pipelines = [ | |||||
| DialogGenerationPipeline(model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_generation, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| result = {} | |||||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||||
| user = item['user'] | |||||
| print('user: {}'.format(user)) | |||||
| result = pipelines[step % 2]({ | |||||
| 'user_input': user, | |||||
| 'history': result | |||||
| }) | |||||
| print('sys : {}'.format(result['sys'])) | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| preprocessor = DialogGenerationPreprocessor(model_dir=model.model_dir) | |||||
| pipelines = [ | |||||
| DialogGenerationPipeline(model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_generation, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| result = {} | |||||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||||
| user = item['user'] | |||||
| print('user: {}'.format(user)) | |||||
| # history_dialog_info = {} | |||||
| # for step, item in enumerate(test_case['sng0073']['log']): | |||||
| # user_question = item['user'] | |||||
| # print('user: {}'.format(user_question)) | |||||
| # | |||||
| # # history_dialog_info = merge(history_dialog_info, | |||||
| # # result) if step > 0 else {} | |||||
| # result = pipeline(user_question, history=history_dialog_info) | |||||
| # # | |||||
| # # print('sys : {}'.format(result['pred_answer'])) | |||||
| print('test') | |||||
| result = pipelines[step % 2]({ | |||||
| 'user_input': user, | |||||
| 'history': result | |||||
| }) | |||||
| print('sys : {}'.format(result['sys'])) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||