Browse Source

generation ready

master
ly119399 3 years ago
parent
commit
3c687c9f37
7 changed files with 130 additions and 88 deletions
  1. +14
    -34
      modelscope/models/nlp/space/dialog_generation_model.py
  2. +1
    -1
      modelscope/pipelines/base.py
  3. +9
    -13
      modelscope/pipelines/nlp/space/dialog_generation_pipeline.py
  4. +5
    -4
      modelscope/preprocessors/space/dialog_generation_preprocessor.py
  5. +45
    -15
      modelscope/trainers/nlp/space/trainers/gen_trainer.py
  6. +6
    -0
      modelscope/utils/nlp/space/ontology.py
  7. +50
    -21
      tests/pipelines/nlp/test_dialog_generation.py

+ 14
- 34
modelscope/models/nlp/space/dialog_generation_model.py View File

@@ -1,6 +1,10 @@
import os
from typing import Any, Dict, Optional

from modelscope.preprocessors.space.fields.gen_field import \
MultiWOZBPETextField
from modelscope.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer
from modelscope.utils.config import Config
from modelscope.utils.constant import Tasks
from ...base import Model, Tensor
from ...builder import MODELS
@@ -25,8 +29,13 @@ class DialogGenerationModel(Model):

super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir
self.text_field = kwargs.pop('text_field')
self.config = kwargs.pop('config')
self.config = kwargs.pop(
'config',
Config.from_file(
os.path.join(self.model_dir, 'configuration.json')))
self.text_field = kwargs.pop(
'text_field',
MultiWOZBPETextField(self.model_dir, config=self.config))
self.generator = Generator.create(self.config, reader=self.text_field)
self.model = ModelBase.create(
model_dir=model_dir,
@@ -65,39 +74,10 @@ class DialogGenerationModel(Model):
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
}
"""
from numpy import array, float32
import torch

# turn_1 = {
# 'user': [
# 13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005,
# 1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7
# ]
# }
# old_pv_turn_1 = {}
turn = {'user': input['user']}
old_pv_turn = input['history']

turn_2 = {
'user':
[13, 1045, 2215, 2000, 2681, 2044, 2459, 1024, 2321, 1012, 7]
}
old_pv_turn_2 = {
'labels': [[
13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005,
1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7
]],
'resp': [
14, 1045, 2052, 2022, 3407, 2000, 2393, 2007, 2115, 5227, 1010,
2079, 2017, 2031, 1037, 2051, 2017, 2052, 2066, 2000, 2681,
2030, 7180, 2011, 1029, 8
],
'bspn': [
15, 43, 7688, 10733, 12570, 21713, 4487, 15474, 6712, 3002,
2198, 1005, 1055, 2267, 9
],
'db': [19, 24, 21, 20],
'aspn': [16, 43, 48, 2681, 7180, 10]
}

pv_turn = self.trainer.forward(turn=turn_2, old_pv_turn=old_pv_turn_2)
pv_turn = self.trainer.forward(turn=turn, old_pv_turn=old_pv_turn)

return pv_turn

+ 1
- 1
modelscope/pipelines/base.py View File

@@ -15,7 +15,7 @@ from modelscope.utils.logger import get_logger
from .util import is_model_name

Tensor = Union['torch.Tensor', 'tf.Tensor']
# Dict is accepted so structured dialog inputs
# ({'user_input': ..., 'history': ...}) can be passed alongside plain
# strings / images / arrays.
Input = Union[str, PyDataset, Dict, 'PIL.Image.Image', 'numpy.ndarray']
InputModel = Union[str, Model]

output_keys = [


+ 9
- 13
modelscope/pipelines/nlp/space/dialog_generation_pipeline.py View File

@@ -24,6 +24,7 @@ class DialogGenerationPipeline(Pipeline):

super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.model = model
self.preprocessor = preprocessor

def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
    """Decode the generated system response ids into readable tokens.

    Args:
        inputs (Dict[str, Tensor]): model output; must contain 'resp',
            the generated response token ids.

    Returns:
        Dict[str, str]: `inputs` with 'sys' added — the decoded response
        tokens with the first and last tokens stripped.
    """
    sys_rsp = self.preprocessor.text_field.tokenizer.convert_ids_to_tokens(
        inputs['resp'])
    # NOTE(review): first/last tokens are presumably the <sos_r>/<eos_r>
    # markers, so a real response needs more than two tokens — confirm.
    assert len(sys_rsp) > 2
    inputs['sys'] = sys_rsp[1:-1]

    return inputs

+ 5
- 4
modelscope/preprocessors/space/dialog_generation_preprocessor.py View File

@@ -32,8 +32,8 @@ class DialogGenerationPreprocessor(Preprocessor):
self.text_field = MultiWOZBPETextField(
self.model_dir, config=self.config)

@type_assert(object, Dict)
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Tokenize the raw user utterance into model input ids.

    Args:
        data (Dict[str, Any]): must contain 'user_input', the raw user
            utterance string; the dict is mutated in place.

    Returns:
        Dict[str, Any]: the same dict with 'user' added — the token ids
        of 'user_input' produced by the text field.
    """
    data['user'] = self.text_field.get_ids(data['user_input'])

    return data

+ 45
- 15
modelscope/trainers/nlp/space/trainers/gen_trainer.py View File

@@ -13,6 +13,7 @@ import torch
from tqdm import tqdm
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

import modelscope.utils.nlp.space.ontology as ontology
from ..metrics.metrics_tracker import MetricsTracker


@@ -668,10 +669,45 @@ class MultiWOZTrainer(Trainer):

return

def _get_turn_domain(self, old_pv_turn, bspn_gen_ids, first_turn):
    """Infer which dialog domains the current turn is about.

    Compares the belief-span slots generated for this turn against those
    of the previous turn; domains whose slot sets changed in either
    direction are the active turn domains.

    Args:
        old_pv_turn (dict): previous turn state; must contain 'bspn'
            (previous belief-span token ids) when not the first turn.
        bspn_gen_ids (list): generated belief-span token ids for this turn.
        first_turn (bool): whether this is the first turn of the dialog.

    Returns:
        list: domain names in bracket form, e.g. ['[restaurant]'].
    """

    def _get_slots(constraint):
        # Group belief-span tokens into {domain: set of slot tokens}.
        domain_name = ''
        slots = {}
        for token in constraint:
            if token in ontology.placeholder_tokens:
                continue
            if token in ontology.all_domains_with_bracket:
                domain_name = token
                slots[domain_name] = set()
            else:
                # Every slot token must follow a domain marker.
                assert domain_name in ontology.all_domains_with_bracket
                slots[domain_name].add(token)
        return slots

    if first_turn and len(bspn_gen_ids) == 0:
        # Nothing generated yet: fall back to the generic domain.
        return ['[general]']

    bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen_ids)
    turn_slots = _get_slots(bspn_token)
    if first_turn:
        return list(turn_slots.keys())

    assert 'bspn' in old_pv_turn
    pv_bspn_token = self.tokenizer.convert_ids_to_tokens(old_pv_turn['bspn'])
    pv_turn_slots = _get_slots(pv_bspn_token)

    turn_domain = []
    for domain, value in turn_slots.items():
        pv_value = pv_turn_slots.get(domain, set())
        # Active when the slot set changed in either direction (the original
        # mixed `len(...) > 0` with bare truthiness; `!=` covers both).
        if value != pv_value:
            turn_domain.append(domain)
    if not turn_domain:
        # No change detected: keep all domains mentioned this turn.
        turn_domain = list(turn_slots.keys())

    return turn_domain

def forward(self, turn, old_pv_turn):
with torch.no_grad():
@@ -692,14 +728,11 @@ class MultiWOZTrainer(Trainer):
generated_bs = outputs[0].cpu().numpy().tolist()
bspn_gen = self.decode_generated_bspn(generated_bs)

turn_domain = self._get_turn_doamin(old_pv_turn['constraint_ids'],
bspn_gen)
print(turn_domain)
turn_domain = self._get_turn_domain(old_pv_turn, bspn_gen,
first_turn)

db_result = self.reader.bspan_to_DBpointer(
self.tokenizer.decode(bspn_gen), turn_domain)
print(db_result)
assert len(turn['db']) == 3
assert isinstance(db_result, str)
db = \
[self.reader.sos_db_id] + \
@@ -718,14 +751,11 @@ class MultiWOZTrainer(Trainer):
generated_ar = outputs_db[0].cpu().numpy().tolist()
decoded = self.decode_generated_act_resp(generated_ar)
decoded['bspn'] = bspn_gen
print(decoded)
print(self.tokenizer.convert_ids_to_tokens(decoded['resp']))

pv_turn['labels'] = None
pv_turn['labels'] = inputs['labels']
pv_turn['resp'] = decoded['resp']
pv_turn['bspn'] = decoded['bspn']
pv_turn['db'] = None
pv_turn['aspn'] = None
pv_turn['constraint_ids'] = bspn_gen
pv_turn['db'] = db
pv_turn['aspn'] = decoded['aspn']

return pv_turn

+ 6
- 0
modelscope/utils/nlp/space/ontology.py View File

@@ -1,7 +1,13 @@
# Canonical MultiWOZ domain names.
all_domains = [
    'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital'
]
# Bracketed form used as domain markers inside belief spans.
all_domains_with_bracket = ['[{}]'.format(name) for name in all_domains]
# Domains backed by a database lookup.
db_domains = ['restaurant', 'hotel', 'attraction', 'train']
# Special start/end-of-segment tokens that carry no slot content:
# <go_*> decoding prompts, plus <eos_*>/<sos_*> segment delimiters for
# user (u), response (r), belief (b), act (a), db (d) and query (q).
placeholder_tokens = (['<go_{}>'.format(seg) for seg in 'rbad']
                      + ['<eos_{}>'.format(seg) for seg in 'urbadq']
                      + ['<sos_{}>'.format(seg) for seg in 'urbadq'])

normlize_slot_names = {
'car type': 'car',


+ 50
- 21
tests/pipelines/nlp/test_dialog_generation.py View File

@@ -4,16 +4,17 @@ import os.path as osp
import tempfile
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import DialogGenerationModel
from modelscope.pipelines import DialogGenerationPipeline, pipeline
from modelscope.preprocessors import DialogGenerationPreprocessor


def merge(info, result):
return info
from modelscope.utils.constant import Tasks


class DialogGenerationTest(unittest.TestCase):
model_id = 'damo/nlp_space_dialog-generation'
test_case = {
'sng0073': {
'goal': {
@@ -91,30 +92,58 @@ class DialogGenerationTest(unittest.TestCase):
}
}

@unittest.skip('test with snapshot_download')
def test_run(self):
    """Run the generation pipeline against a locally cached model."""
    # Hard-coded developer path; switch to
    # snapshot_download(self.model_id) once the model is published
    # (hence the skip above).
    cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-generation'

    preprocessor = DialogGenerationPreprocessor(model_dir=cache_path)
    model = DialogGenerationModel(
        model_dir=cache_path,
        text_field=preprocessor.text_field,
        config=preprocessor.config)
    # Exercise both construction styles: direct class and task factory.
    pipelines = [
        DialogGenerationPipeline(model=model, preprocessor=preprocessor),
        pipeline(
            task=Tasks.dialog_generation,
            model=model,
            preprocessor=preprocessor)
    ]

    # Feed the scripted dialog turn by turn, alternating pipelines;
    # `result` carries the dialog history between turns.
    result = {}
    for step, item in enumerate(self.test_case['sng0073']['log']):
        user = item['user']
        print('user: {}'.format(user))

        result = pipelines[step % 2]({
            'user_input': user,
            'history': result
        })
        print('sys : {}'.format(result['sys']))

def test_run_with_model_from_modelhub(self):
    """Run the generation pipeline with a model fetched from the hub."""
    model = Model.from_pretrained(self.model_id)
    preprocessor = DialogGenerationPreprocessor(model_dir=model.model_dir)

    # Exercise both construction styles: direct class and task factory.
    pipelines = [
        DialogGenerationPipeline(model=model, preprocessor=preprocessor),
        pipeline(
            task=Tasks.dialog_generation,
            model=model,
            preprocessor=preprocessor)
    ]

    # Same scripted dialog as test_run; `result` threads the history.
    result = {}
    for step, item in enumerate(self.test_case['sng0073']['log']):
        user = item['user']
        print('user: {}'.format(user))

        result = pipelines[step % 2]({
            'user_input': user,
            'history': result
        })
        print('sys : {}'.format(result['sys']))


if __name__ == '__main__':


Loading…
Cancel
Save