diff --git a/modelscope/models/nlp/space/dialog_generation_model.py b/modelscope/models/nlp/space/dialog_generation_model.py index db8c40e0..95a9ecfd 100644 --- a/modelscope/models/nlp/space/dialog_generation_model.py +++ b/modelscope/models/nlp/space/dialog_generation_model.py @@ -1,6 +1,10 @@ +import os from typing import Any, Dict, Optional +from modelscope.preprocessors.space.fields.gen_field import \ + MultiWOZBPETextField from modelscope.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer +from modelscope.utils.config import Config from modelscope.utils.constant import Tasks from ...base import Model, Tensor from ...builder import MODELS @@ -25,8 +29,13 @@ class DialogGenerationModel(Model): super().__init__(model_dir, *args, **kwargs) self.model_dir = model_dir - self.text_field = kwargs.pop('text_field') - self.config = kwargs.pop('config') + self.config = kwargs.pop( + 'config', + Config.from_file( + os.path.join(self.model_dir, 'configuration.json'))) + self.text_field = kwargs.pop( + 'text_field', + MultiWOZBPETextField(self.model_dir, config=self.config)) self.generator = Generator.create(self.config, reader=self.text_field) self.model = ModelBase.create( model_dir=model_dir, @@ -65,39 +74,10 @@ class DialogGenerationModel(Model): 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value } """ - from numpy import array, float32 - import torch - # turn_1 = { - # 'user': [ - # 13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005, - # 1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7 - # ] - # } - # old_pv_turn_1 = {} + turn = {'user': input['user']} + old_pv_turn = input['history'] - turn_2 = { - 'user': - [13, 1045, 2215, 2000, 2681, 2044, 2459, 1024, 2321, 1012, 7] - } - old_pv_turn_2 = { - 'labels': [[ - 13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005, - 1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7 - ]], - 'resp': [ - 14, 1045, 2052, 2022, 3407, 2000, 2393, 2007, 2115, 5227, 1010, - 
2079, 2017, 2031, 1037, 2051, 2017, 2052, 2066, 2000, 2681, - 2030, 7180, 2011, 1029, 8 - ], - 'bspn': [ - 15, 43, 7688, 10733, 12570, 21713, 4487, 15474, 6712, 3002, - 2198, 1005, 1055, 2267, 9 - ], - 'db': [19, 24, 21, 20], - 'aspn': [16, 43, 48, 2681, 7180, 10] - } - - pv_turn = self.trainer.forward(turn=turn_2, old_pv_turn=old_pv_turn_2) + pv_turn = self.trainer.forward(turn=turn, old_pv_turn=old_pv_turn) return pv_turn diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 41a80896..32297877 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -15,7 +15,7 @@ from modelscope.utils.logger import get_logger from .util import is_model_name Tensor = Union['torch.Tensor', 'tf.Tensor'] -Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] +Input = Union[str, PyDataset, Dict, 'PIL.Image.Image', 'numpy.ndarray'] InputModel = Union[str, Model] output_keys = [ diff --git a/modelscope/pipelines/nlp/space/dialog_generation_pipeline.py b/modelscope/pipelines/nlp/space/dialog_generation_pipeline.py index 949b20d0..1d93fdef 100644 --- a/modelscope/pipelines/nlp/space/dialog_generation_pipeline.py +++ b/modelscope/pipelines/nlp/space/dialog_generation_pipeline.py @@ -24,6 +24,7 @@ class DialogGenerationPipeline(Pipeline): super().__init__(model=model, preprocessor=preprocessor, **kwargs) self.model = model + self.preprocessor = preprocessor def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: """process the prediction results @@ -34,17 +35,12 @@ class DialogGenerationPipeline(Pipeline): Returns: Dict[str, str]: the prediction results """ + sys_rsp = self.preprocessor.text_field.tokenizer.convert_ids_to_tokens( + inputs['resp']) + assert len(sys_rsp) > 2 + sys_rsp = sys_rsp[1:len(sys_rsp) - 1] + # sys_rsp = self.preprocessor.text_field.tokenizer. 
- vocab_size = len(self.tokenizer.vocab) - pred_list = inputs['predictions'] - pred_ids = pred_list[0][0].cpu().numpy().tolist() - for j in range(len(pred_ids)): - if pred_ids[j] >= vocab_size: - pred_ids[j] = 100 - pred = self.tokenizer.convert_ids_to_tokens(pred_ids) - pred_string = ''.join(pred).replace( - '##', - '').split('[SEP]')[0].replace('[CLS]', - '').replace('[SEP]', - '').replace('[UNK]', '') - return {'pred_string': pred_string} + inputs['sys'] = sys_rsp + + return inputs diff --git a/modelscope/preprocessors/space/dialog_generation_preprocessor.py b/modelscope/preprocessors/space/dialog_generation_preprocessor.py index c6e2584d..9ce9e03b 100644 --- a/modelscope/preprocessors/space/dialog_generation_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_generation_preprocessor.py @@ -32,8 +32,8 @@ class DialogGenerationPreprocessor(Preprocessor): self.text_field = MultiWOZBPETextField( self.model_dir, config=self.config) - @type_assert(object, str) - def __call__(self, data: str) -> Dict[str, Any]: + @type_assert(object, Dict) + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: """process the raw input data Args: @@ -45,6 +45,7 @@ class DialogGenerationPreprocessor(Preprocessor): Dict[str, Any]: the preprocessed data """ - idx = self.text_field.get_ids(data) + user_ids = self.text_field.get_ids(data['user_input']) + data['user'] = user_ids - return {'user_idx': idx} + return data diff --git a/modelscope/trainers/nlp/space/trainers/gen_trainer.py b/modelscope/trainers/nlp/space/trainers/gen_trainer.py index 28494c83..a0cda25c 100644 --- a/modelscope/trainers/nlp/space/trainers/gen_trainer.py +++ b/modelscope/trainers/nlp/space/trainers/gen_trainer.py @@ -13,6 +13,7 @@ import torch from tqdm import tqdm from transformers.optimization import AdamW, get_linear_schedule_with_warmup +import modelscope.utils.nlp.space.ontology as ontology from ..metrics.metrics_tracker import MetricsTracker @@ -668,10 +669,45 @@ class MultiWOZTrainer(Trainer): 
return - def _get_turn_doamin(self, constraint_ids, bspn_gen_ids): - # constraint_token = self.tokenizer.convert_ids_to_tokens(constraint_ids) - # bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen_ids) - return [] + def _get_turn_domain(self, old_pv_turn, bspn_gen_ids, first_turn): + + def _get_slots(constraint): + domain_name = '' + slots = {} + for item in constraint: + if item in ontology.placeholder_tokens: + continue + if item in ontology.all_domains_with_bracket: + domain_name = item + slots[domain_name] = set() + else: + assert domain_name in ontology.all_domains_with_bracket + slots[domain_name].add(item) + return slots + + turn_domain = [] + if first_turn and len(bspn_gen_ids) == 0: + turn_domain = ['[general]'] + return turn_domain + + bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen_ids) + turn_slots = _get_slots(bspn_token) + if first_turn: + return list(turn_slots.keys()) + + assert 'bspn' in old_pv_turn + pv_bspn_token = self.tokenizer.convert_ids_to_tokens( + old_pv_turn['bspn']) + pv_turn_slots = _get_slots(pv_bspn_token) + for domain, value in turn_slots.items(): + pv_value = pv_turn_slots[ + domain] if domain in pv_turn_slots else set() + if len(value - pv_value) > 0 or len(pv_value - value): + turn_domain.append(domain) + if len(turn_domain) == 0: + turn_domain = list(turn_slots.keys()) + + return turn_domain def forward(self, turn, old_pv_turn): with torch.no_grad(): @@ -692,14 +728,11 @@ class MultiWOZTrainer(Trainer): generated_bs = outputs[0].cpu().numpy().tolist() bspn_gen = self.decode_generated_bspn(generated_bs) - turn_domain = self._get_turn_doamin(old_pv_turn['constraint_ids'], - bspn_gen) - print(turn_domain) + turn_domain = self._get_turn_domain(old_pv_turn, bspn_gen, + first_turn) db_result = self.reader.bspan_to_DBpointer( self.tokenizer.decode(bspn_gen), turn_domain) - print(db_result) - assert len(turn['db']) == 3 assert isinstance(db_result, str) db = \ [self.reader.sos_db_id] + \ @@ -718,14 +751,11 @@ class 
MultiWOZTrainer(Trainer): generated_ar = outputs_db[0].cpu().numpy().tolist() decoded = self.decode_generated_act_resp(generated_ar) decoded['bspn'] = bspn_gen - print(decoded) - print(self.tokenizer.convert_ids_to_tokens(decoded['resp'])) - pv_turn['labels'] = None + pv_turn['labels'] = inputs['labels'] pv_turn['resp'] = decoded['resp'] pv_turn['bspn'] = decoded['bspn'] - pv_turn['db'] = None - pv_turn['aspn'] = None - pv_turn['constraint_ids'] = bspn_gen + pv_turn['db'] = db + pv_turn['aspn'] = decoded['aspn'] return pv_turn diff --git a/modelscope/utils/nlp/space/ontology.py b/modelscope/utils/nlp/space/ontology.py index b22d3b3e..4f27168a 100644 --- a/modelscope/utils/nlp/space/ontology.py +++ b/modelscope/utils/nlp/space/ontology.py @@ -1,7 +1,13 @@ all_domains = [ 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital' ] +all_domains_with_bracket = ['[{}]'.format(item) for item in all_domains] db_domains = ['restaurant', 'hotel', 'attraction', 'train'] +placeholder_tokens = [ + '<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>', '<eos_b>', + '<eos_a>', '<eos_d>', '<sos_u>', '<sos_r>', '<sos_b>', '<sos_a>', + '<sos_d>', '<eos_db>', '<sos_db>' +] normlize_slot_names = { 'car type': 'car', diff --git a/tests/pipelines/nlp/test_dialog_generation.py b/tests/pipelines/nlp/test_dialog_generation.py index 8af102a0..23a6e5e9 100644 --- a/tests/pipelines/nlp/test_dialog_generation.py +++ b/tests/pipelines/nlp/test_dialog_generation.py @@ -4,16 +4,17 @@ import os.path as osp import tempfile import unittest +from maas_hub.snapshot_download import snapshot_download + +from modelscope.models import Model from modelscope.models.nlp import DialogGenerationModel from modelscope.pipelines import DialogGenerationPipeline, pipeline from modelscope.preprocessors import DialogGenerationPreprocessor - - -def merge(info, result): - return info +from modelscope.utils.constant import Tasks class DialogGenerationTest(unittest.TestCase): + model_id = 'damo/nlp_space_dialog-generation' test_case = { 'sng0073': { 'goal': { @@ -91,30 +92,58 @@ class 
DialogGenerationTest(unittest.TestCase): } } + @unittest.skip('test with snapshot_download') def test_run(self): - modeldir = '/Users/yangliu/Desktop/space-dialog-generation' + cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-generation' + # cache_path = snapshot_download(self.model_id) - preprocessor = DialogGenerationPreprocessor(model_dir=modeldir) + preprocessor = DialogGenerationPreprocessor(model_dir=cache_path) model = DialogGenerationModel( - model_dir=modeldir, + model_dir=cache_path, text_field=preprocessor.text_field, config=preprocessor.config) - print(model.forward(None)) - # pipeline = DialogGenerationPipeline( - # model=model, preprocessor=preprocessor) + pipelines = [ + DialogGenerationPipeline(model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.dialog_generation, + model=model, + preprocessor=preprocessor) + ] + + result = {} + for step, item in enumerate(self.test_case['sng0073']['log']): + user = item['user'] + print('user: {}'.format(user)) + + result = pipelines[step % 2]({ + 'user_input': user, + 'history': result + }) + print('sys : {}'.format(result['sys'])) + + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = DialogGenerationPreprocessor(model_dir=model.model_dir) + + pipelines = [ + DialogGenerationPipeline(model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.dialog_generation, + model=model, + preprocessor=preprocessor) + ] + + result = {} + for step, item in enumerate(self.test_case['sng0073']['log']): + user = item['user'] + print('user: {}'.format(user)) - # history_dialog_info = {} - # for step, item in enumerate(test_case['sng0073']['log']): - # user_question = item['user'] - # print('user: {}'.format(user_question)) - # - # # history_dialog_info = merge(history_dialog_info, - # # result) if step > 0 else {} - # result = pipeline(user_question, history=history_dialog_info) - # # - # # print('sys : {}'.format(result['pred_answer'])) - 
print('test') + result = pipelines[step % 2]({ + 'user_input': user, + 'history': result + }) + print('sys : {}'.format(result['sys'])) if __name__ == '__main__':