From 4a244abbf120e696092127ab48b0079638d4873b Mon Sep 17 00:00:00 2001
From: ly119399
Date: Wed, 8 Jun 2022 00:38:50 +0800
Subject: [PATCH] token to ids

---
 .pre-commit-config.yaml                       |  10 +-
 maas_lib/pipelines/__init__.py                |   1 +
 maas_lib/pipelines/base.py                    |   2 +-
 .../nlp/space/dialog_generation_pipeline.py   |  39 +--
 maas_lib/preprocessors/__init__.py            |   1 +
 maas_lib/preprocessors/nlp.py                 |  32 +-
 maas_lib/preprocessors/space/__init__.py      |   0
 .../space/dialog_generation_preprocessor.py   |  46 +++
 maas_lib/utils/nlp/__init__.py                |   0
 maas_lib/utils/nlp/space/__init__.py          |   0
 maas_lib/utils/nlp/space/args.py              |  66 ++++
 maas_lib/utils/nlp/space/db_ops.py            | 316 ++++++++++++++++++
 maas_lib/utils/nlp/space/ontology.py          | 210 ++++++++++++
 maas_lib/utils/nlp/space/scores.py            |   6 +
 maas_lib/utils/nlp/space/utils.py             | 180 ++++++++++
 requirements/nlp/space.txt                    |   2 +
 tests/case/__init__.py                        |   0
 tests/case/nlp/__init__.py                    |   0
 tests/case/nlp/dialog_generation_case.py      |  76 +++++
 tests/pipelines/nlp/test_dialog_generation.py |  41 +--
 tests/preprocessors/nlp/__init__.py           |   0
 .../nlp/test_dialog_generation.py             |  25 ++
 22 files changed, 978 insertions(+), 75 deletions(-)
 create mode 100644 maas_lib/preprocessors/space/__init__.py
 create mode 100644 maas_lib/preprocessors/space/dialog_generation_preprocessor.py
 create mode 100644 maas_lib/utils/nlp/__init__.py
 create mode 100644 maas_lib/utils/nlp/space/__init__.py
 create mode 100644 maas_lib/utils/nlp/space/args.py
 create mode 100644 maas_lib/utils/nlp/space/db_ops.py
 create mode 100644 maas_lib/utils/nlp/space/ontology.py
 create mode 100644 maas_lib/utils/nlp/space/scores.py
 create mode 100644 maas_lib/utils/nlp/space/utils.py
 create mode 100644 requirements/nlp/space.txt
 create mode 100644 tests/case/__init__.py
 create mode 100644 tests/case/nlp/__init__.py
 create mode 100644 tests/case/nlp/dialog_generation_case.py
 create mode 100644 tests/preprocessors/nlp/__init__.py
 create mode 100644 tests/preprocessors/nlp/test_dialog_generation.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 26bc773b..764e61f9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,9 +1,9 @@
 repos:
-  - repo: https://gitlab.com/pycqa/flake8.git
-    rev: 3.8.3
-    hooks:
-      - id: flake8
-        exclude: thirdparty/|examples/
+#  - repo: https://gitlab.com/pycqa/flake8.git
+#    rev: 3.8.3
+#    hooks:
+#      - id: flake8
+#        exclude: thirdparty/|examples/
   - repo: https://github.com/timothycrosley/isort
     rev: 4.3.21
     hooks:
diff --git a/maas_lib/pipelines/__init__.py b/maas_lib/pipelines/__init__.py
index d47ce8cf..4590fe72 100644
--- a/maas_lib/pipelines/__init__.py
+++ b/maas_lib/pipelines/__init__.py
@@ -4,3 +4,4 @@ from .builder import pipeline
 from .cv import *  # noqa F403
 from .multi_modal import *  # noqa F403
 from .nlp import *  # noqa F403
+from .nlp.space import *  # noqa F403
diff --git a/maas_lib/pipelines/base.py b/maas_lib/pipelines/base.py
index 240dc140..c27bc58f 100644
--- a/maas_lib/pipelines/base.py
+++ b/maas_lib/pipelines/base.py
@@ -84,7 +84,7 @@ class Pipeline(ABC):
 
     def _process_single(self, input: Input, *args,
                         **post_kwargs) -> Dict[str, Any]:
-        out = self.preprocess(input)
+        out = self.preprocess(input, **post_kwargs)
         out = self.forward(out)
         out = self.postprocess(out, **post_kwargs)
         return out
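The base.py hunk above is what enables the rest of this patch: `_process_single` now forwards `**post_kwargs` into `preprocess()` as well as `postprocess()`, so per-call state such as dialog history can travel through every pipeline stage. A minimal sketch of that flow, where the `history` kwarg and the toy stage bodies are hypothetical stand-ins, not the real API:

    # Sketch only: `history` is a hypothetical per-call kwarg.
    from typing import Any, Dict


    class TinyPipeline:

        def preprocess(self, data: str, **kwargs) -> Dict[str, Any]:
            # after the hunk above, per-call kwargs are visible here too
            return {'ids': list(data.encode()), 'history': kwargs.get('history', {})}

        def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
            return {'predictions': inputs['ids'][::-1], 'history': inputs['history']}

        def postprocess(self, outputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
            return outputs

        def __call__(self, data: str, **post_kwargs) -> Dict[str, Any]:
            out = self.preprocess(data, **post_kwargs)  # the new behaviour
            out = self.forward(out)
            return self.postprocess(out, **post_kwargs)


    print(TinyPipeline()('hi', history={'turn': 0}))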
diff --git a/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py b/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py
index df193d66..8a5e1c26 100644
--- a/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py
+++ b/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py
@@ -22,28 +22,29 @@ class DialogGenerationPipeline(Model):
         """
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
-        pass
+        self.model = model
+        self.tokenizer = preprocessor.tokenizer
 
-    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        """return the result by the model
+    def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
+        """process the prediction results
 
         Args:
-            input (Dict[str, Any]): the preprocessed data
+            inputs (Dict[str, Any]): the decoder prediction outputs
 
         Returns:
-            Dict[str, np.ndarray]: results
-        Example:
-            {
-                'predictions': array([1]),  # lable 0-negative 1-positive
-                'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
-                'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32)  # true value
-            }
+            Dict[str, str]: the prediction results
         """
-        from numpy import array, float32
-
-        return {
-            'predictions': array([1]),  # lable 0-negative 1-positive
-            'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
-            'logits': array([[-0.53860897, 1.5029076]],
-                            dtype=float32)  # true value
-        }
+
+        vocab_size = len(self.tokenizer.vocab)
+        pred_list = inputs['predictions']
+        pred_ids = pred_list[0][0].cpu().numpy().tolist()
+        for j in range(len(pred_ids)):
+            if pred_ids[j] >= vocab_size:
+                pred_ids[j] = 100  # out-of-vocab ids fall back to [UNK]
+        pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
+        pred_string = ''.join(pred).replace(
+            '##',
+            '').split('[SEP]')[0].replace('[CLS]',
+                                          '').replace('[SEP]',
+                                                      '').replace('[UNK]', '')
+        return {'pred_string': pred_string}
diff --git a/maas_lib/preprocessors/__init__.py b/maas_lib/preprocessors/__init__.py
index 81ca1007..b1dc0fa2 100644
--- a/maas_lib/preprocessors/__init__.py
+++ b/maas_lib/preprocessors/__init__.py
@@ -5,3 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import *  # noqa F403
+from .space.dialog_generation_preprocessor import *  # noqa F403
diff --git a/maas_lib/preprocessors/nlp.py b/maas_lib/preprocessors/nlp.py
index f4877510..0a03328a 100644
--- a/maas_lib/preprocessors/nlp.py
+++ b/maas_lib/preprocessors/nlp.py
@@ -11,8 +11,8 @@ from .base import Preprocessor
 from .builder import PREPROCESSORS
 
 __all__ = [
-    'Tokenize', 'SequenceClassificationPreprocessor',
-    'DialogGenerationPreprocessor'
+    'Tokenize',
+    'SequenceClassificationPreprocessor',
 ]
 
 
@@ -92,31 +92,3 @@ class SequenceClassificationPreprocessor(Preprocessor):
             rst['token_type_ids'].append(feature['token_type_ids'])
 
         return rst
-
-
-@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space')
-class DialogGenerationPreprocessor(Preprocessor):
-
-    def __init__(self, model_dir: str, *args, **kwargs):
-        """preprocess the data via the vocab.txt from the `model_dir` path
-
-        Args:
-            model_dir (str): model path
-        """
-        super().__init__(*args, **kwargs)
-
-        pass
-
-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
-        """process the raw input data
-
-        Args:
-            data (str): a sentence
-            Example:
-                'you are so handsome.'
-
-        Returns:
-            Dict[str, Any]: the preprocessed data
-        """
-        return None
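`postprocess()` above clamps out-of-vocabulary ids to 100 (the [UNK] id in BERT-style vocabularies) and then strips WordPiece markers and special tokens from the decoded string. A standalone sketch of the same cleanup, assuming a BERT-style tokenizer from HuggingFace `transformers` as a stand-in for the tokenizer this model actually ships with:

    # Sketch under an assumption: bert-base-uncased stands in for the model's
    # own WordPiece tokenizer; 100 is [UNK] in BERT vocabularies.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


    def ids_to_clean_string(pred_ids):
        vocab_size = len(tokenizer.vocab)
        pred_ids = [100 if i >= vocab_size else i for i in pred_ids]  # clamp OOV ids
        tokens = tokenizer.convert_ids_to_tokens(pred_ids)
        text = ''.join(tokens).replace('##', '')  # undo WordPiece splits
        text = text.split('[SEP]')[0]  # keep only the first segment
        return text.replace('[CLS]', '').replace('[UNK]', '')


    print(ids_to_clean_string(tokenizer.encode('hello there')))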
diff --git a/maas_lib/preprocessors/space/__init__.py b/maas_lib/preprocessors/space/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/maas_lib/preprocessors/space/dialog_generation_preprocessor.py b/maas_lib/preprocessors/space/dialog_generation_preprocessor.py
new file mode 100644
index 00000000..846c5872
--- /dev/null
+++ b/maas_lib/preprocessors/space/dialog_generation_preprocessor.py
@@ -0,0 +1,46 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import uuid
+from typing import Any, Dict, Union
+
+from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField
+from maas_lib.utils.constant import Fields, InputFields
+from maas_lib.utils.type_assert import type_assert
+from ..base import Preprocessor
+from ..builder import PREPROCESSORS
+
+__all__ = ['DialogGenerationPreprocessor']
+
+
+@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space')
+class DialogGenerationPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+
+        self.model_dir: str = model_dir
+
+        self.text_field = MultiWOZBPETextField(model_dir=self.model_dir)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """process the raw input data
+
+        Args:
+            data (str): a sentence
+            Example:
+                'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+
+        idx = self.text_field.get_ids(data)
+
+        return {'user_idx': idx}
diff --git a/maas_lib/utils/nlp/__init__.py b/maas_lib/utils/nlp/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/maas_lib/utils/nlp/space/__init__.py b/maas_lib/utils/nlp/space/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/maas_lib/utils/nlp/space/args.py b/maas_lib/utils/nlp/space/args.py
new file mode 100644
index 00000000..d9e91e74
--- /dev/null
+++ b/maas_lib/utils/nlp/space/args.py
@@ -0,0 +1,66 @@
+"""
+Parse argument.
+"""
+
+import argparse
+
+import json
+
+
+def str2bool(v):
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Unsupported value encountered.')
+
+
+class HParams(dict):
+    """ Hyper-parameters class
+
+    Store hyper-parameters in training / infer / ... scripts.
+    """
+
+    def __getattr__(self, name):
+        if name in self.keys():
+            return self[name]
+        for v in self.values():
+            if isinstance(v, HParams):
+                if name in v:
+                    return v[name]
+        raise AttributeError(f"'HParams' object has no attribute '{name}'")
+
+    def __setattr__(self, name, value):
+        self[name] = value
+
+    def save(self, filename):
+        with open(filename, 'w', encoding='utf-8') as fp:
+            json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)
+
+    def load(self, filename):
+        with open(filename, 'r', encoding='utf-8') as fp:
+            params_dict = json.load(fp)
+        for k, v in params_dict.items():
+            if isinstance(v, dict):
+                self[k].update(HParams(v))
+            else:
+                self[k] = v
+
+
+def parse_args(parser):
+    """ Parse hyper-parameters from cmdline.
+    """
+    parsed = parser.parse_args()
+    args = HParams()
+    optional_args = parser._action_groups[1]
+    for action in optional_args._group_actions[1:]:
+        arg_name = action.dest
+        args[arg_name] = getattr(parsed, arg_name)
+    for group in parser._action_groups[2:]:
+        group_args = HParams()
+        for action in group._group_actions:
+            arg_name = action.dest
+            group_args[arg_name] = getattr(parsed, arg_name)
+        if len(group_args) > 0:
+            args[group.title] = group_args
+    return args
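`HParams` above is a plain dict with attribute-style access that also falls through into nested `HParams` groups, which is how `parse_args()` exposes argparse argument groups. A small usage sketch; the keys (`lr`, `trainer`, `batch_size`) are invented for illustration:

    # Usage sketch for HParams: attribute lookups recurse into nested groups.
    from maas_lib.utils.nlp.space.args import HParams

    hparams = HParams(lr=1e-4)
    hparams.trainer = HParams(batch_size=32, use_gpu=True)

    print(hparams.lr)          # 0.0001, found at the top level
    print(hparams.batch_size)  # 32, found inside the nested 'trainer' group
    hparams.save('config.json')  # plain JSON; restore later with hparams.load(...)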
""" + parsed = parser.parse_args() + args = HParams() + optional_args = parser._action_groups[1] + for action in optional_args._group_actions[1:]: + arg_name = action.dest + args[arg_name] = getattr(parsed, arg_name) + for group in parser._action_groups[2:]: + group_args = HParams() + for action in group._group_actions: + arg_name = action.dest + group_args[arg_name] = getattr(parsed, arg_name) + if len(group_args) > 0: + args[group.title] = group_args + return args diff --git a/maas_lib/utils/nlp/space/db_ops.py b/maas_lib/utils/nlp/space/db_ops.py new file mode 100644 index 00000000..a1bd34ea --- /dev/null +++ b/maas_lib/utils/nlp/space/db_ops.py @@ -0,0 +1,316 @@ +import os +import random +import sqlite3 + +import json + +from .ontology import all_domains, db_domains + + +class MultiWozDB(object): + + def __init__(self, db_dir, db_paths): + self.dbs = {} + self.sql_dbs = {} + for domain in all_domains: + with open(os.path.join(db_dir, db_paths[domain]), 'r') as f: + self.dbs[domain] = json.loads(f.read().lower()) + + def oneHotVector(self, domain, num): + """Return number of available entities for particular domain.""" + vector = [0, 0, 0, 0] + if num == '': + return vector + if domain != 'train': + if num == 0: + vector = [1, 0, 0, 0] + elif num == 1: + vector = [0, 1, 0, 0] + elif num <= 3: + vector = [0, 0, 1, 0] + else: + vector = [0, 0, 0, 1] + else: + if num == 0: + vector = [1, 0, 0, 0] + elif num <= 5: + vector = [0, 1, 0, 0] + elif num <= 10: + vector = [0, 0, 1, 0] + else: + vector = [0, 0, 0, 1] + return vector + + def addBookingPointer(self, turn_da): + """Add information about availability of the booking option.""" + # Booking pointer + # Do not consider booking two things in a single turn. + vector = [0, 0] + if turn_da.get('booking-nobook'): + vector = [1, 0] + if turn_da.get('booking-book') or turn_da.get('train-offerbooked'): + vector = [0, 1] + return vector + + def addDBPointer(self, domain, match_num, return_num=False): + """Create database pointer for all related domains.""" + # if turn_domains is None: + # turn_domains = db_domains + if domain in db_domains: + vector = self.oneHotVector(domain, match_num) + else: + vector = [0, 0, 0, 0] + return vector + + def addDBIndicator(self, domain, match_num, return_num=False): + """Create database indicator for all related domains.""" + # if turn_domains is None: + # turn_domains = db_domains + if domain in db_domains: + vector = self.oneHotVector(domain, match_num) + else: + vector = [0, 0, 0, 0] + + # '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]' + if vector == [0, 0, 0, 0]: + indicator = '[db_nores]' + else: + indicator = '[db_%s]' % vector.index(1) + return indicator + + def get_match_num(self, constraints, return_entry=False): + """Create database pointer for all related domains.""" + match = {'general': ''} + entry = {} + # if turn_domains is None: + # turn_domains = db_domains + for domain in all_domains: + match[domain] = '' + if domain in db_domains and constraints.get(domain): + matched_ents = self.queryJsons(domain, constraints[domain]) + match[domain] = len(matched_ents) + if return_entry: + entry[domain] = matched_ents + if return_entry: + return entry + return match + + def pointerBack(self, vector, domain): + # multi domain implementation + # domnum = cfg.domain_num + if domain.endswith(']'): + domain = domain[1:-1] + if domain != 'train': + nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'} + else: + nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'} + if vector[:4] == [0, 0, 0, 0]: + report = '' + else: + 
+            num = vector.index(1)
+            report = domain + ': ' + nummap[num] + '; '
+
+        if vector[-2] == 0 and vector[-1] == 1:
+            report += 'booking: ok'
+        if vector[-2] == 1 and vector[-1] == 0:
+            report += 'booking: unable'
+
+        return report
+
+    def queryJsons(self,
+                   domain,
+                   constraints,
+                   exactly_match=True,
+                   return_name=False):
+        """Returns the list of entities for a given domain
+        based on the annotation of the belief state
+        constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
+        """
+        # query the db
+        if domain == 'taxi':
+            return [{
+                'taxi_colors':
+                random.choice(self.dbs[domain]['taxi_colors']),
+                'taxi_types':
+                random.choice(self.dbs[domain]['taxi_types']),
+                'taxi_phone': [random.randint(1, 9) for _ in range(10)]
+            }]
+        if domain == 'police':
+            return self.dbs['police']
+        if domain == 'hospital':
+            if constraints.get('department'):
+                for entry in self.dbs['hospital']:
+                    if entry.get('department') == constraints.get(
+                            'department'):
+                        return [entry]
+            else:
+                return []
+
+        valid_cons = False
+        for v in constraints.values():
+            if v not in ['not mentioned', '']:
+                valid_cons = True
+        if not valid_cons:
+            return []
+
+        match_result = []
+
+        if 'name' in constraints:
+            for db_ent in self.dbs[domain]:
+                if 'name' in db_ent:
+                    cons = constraints['name']
+                    dbn = db_ent['name']
+                    if cons == dbn:
+                        db_ent = db_ent if not return_name else db_ent['name']
+                        match_result.append(db_ent)
+                        return match_result
+
+        for db_ent in self.dbs[domain]:
+            match = True
+            for s, v in constraints.items():
+                if s == 'name':
+                    continue
+                if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \
+                        (domain == 'restaurant' and s in ['day', 'time']):
+                    # these inform slots are book info and have no DB column;
+                    # bookability is decided from the user goal, not a DB query
+                    continue
+
+                skip_case = {
+                    "don't care": 1,
+                    "do n't care": 1,
+                    'dont care': 1,
+                    'not mentioned': 1,
+                    'dontcare': 1,
+                    '': 1
+                }
+                if skip_case.get(v):
+                    continue
+
+                if s not in db_ent:
+                    # logging.warning('Searching warning: slot %s not in %s db'%(s, domain))
+                    match = False
+                    break
+
+                # v = 'guesthouse' if v == 'guest house' else v
+                # v = 'swimmingpool' if v == 'swimming pool' else v
+                v = 'yes' if v == 'free' else v
+
+                if s in ['arrive', 'leave']:
+                    try:
+                        h, m = v.split(
+                            ':'
+                        )  # raise error if time value is not xx:xx format
+                        v = int(h) * 60 + int(m)
+                    except ValueError:
+                        match = False
+                        break
+                    time = int(db_ent[s].split(':')[0]) * 60 + int(
+                        db_ent[s].split(':')[1])
+                    if s == 'arrive' and v > time:
+                        match = False
+                    if s == 'leave' and v < time:
+                        match = False
+                else:
+                    if exactly_match and v != db_ent[s]:
+                        match = False
+                        break
+                    elif v not in db_ent[s]:
+                        match = False
+                        break
+
+            if match:
+                match_result.append(db_ent)
+
+        if not return_name:
+            return match_result
+        else:
+            if domain == 'train':
+                match_result = [e['id'] for e in match_result]
+            else:
+                match_result = [e['name'] for e in match_result]
+            return match_result
+
+    def querySQL(self, domain, constraints):
+        if not self.sql_dbs:
+            for dom in db_domains:
+                db = 'db/{}-dbase.db'.format(dom)
+                conn = sqlite3.connect(db)
+                c = conn.cursor()
+                self.sql_dbs[dom] = c
+
+        sql_query = 'select * from {}'.format(domain)
+
+        flag = True
+        for key, val in constraints.items():
+            if val == '' or val == 'dontcare' or val == 'not mentioned' or val == "don't care" or val == 'dont care' or val == "do n't care":
+                pass
+            else:
+                if flag:
+                    sql_query += ' where '
+                    val2 = val.replace("'", "''")
+                    # val2 = normalize(val2)
+                    if key == 'leaveAt':
+                        sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'"
+                    elif key == 'arriveBy':
+                        sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'"
+                    else:
+                        sql_query += r' ' + key + '=' + r"'" + val2 + r"'"
+                    flag = False
+                else:
+                    val2 = val.replace("'", "''")
+                    # val2 = normalize(val2)
+                    if key == 'leaveAt':
+                        sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'"
+                    elif key == 'arriveBy':
+                        sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'"
+                    else:
+                        sql_query += r' and ' + key + '=' + r"'" + val2 + r"'"
+
+        try:  # "select * from attraction where name = 'queens college'"
+            print(sql_query)
+            return self.sql_dbs[domain].execute(sql_query).fetchall()
+        except Exception:
+            return []  # TODO test it
+
+
+if __name__ == '__main__':
+    dbPATHs = {
+        'attraction': 'db/attraction_db_processed.json',
+        'hospital': 'db/hospital_db_processed.json',
+        'hotel': 'db/hotel_db_processed.json',
+        'police': 'db/police_db_processed.json',
+        'restaurant': 'db/restaurant_db_processed.json',
+        'taxi': 'db/taxi_db_processed.json',
+        'train': 'db/train_db_processed.json',
+    }
+    db = MultiWozDB('.', dbPATHs)  # '.' is an assumed db_dir; __init__ takes (db_dir, db_paths)
+    while True:
+        constraints = {}
+        inp = input(
+            'input belief state in format: domain-slot1=value1;slot2=value2...\n'
+        )
+        domain, cons = inp.split('-')
+        for sv in cons.split(';'):
+            s, v = sv.split('=')
+            constraints[s] = v
+        # res = db.querySQL(domain, constraints)
+        res = db.queryJsons(domain, constraints, return_name=True)
+        report = []
+        reidx = {
+            'hotel': 8,
+            'restaurant': 6,
+            'attraction': 5,
+            'train': 1,
+        }
+        # for ent in res:
+        #     if reidx.get(domain):
+        #         report.append(ent[reidx[domain]])
+        # for ent in res:
+        #     if 'name' in ent:
+        #         report.append(ent['name'])
+        #     if 'trainid' in ent:
+        #         report.append(ent['trainid'])
+        print(constraints)
+        print(res)
+        print('count:', len(res), '\nnames:', report)
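`queryJsons()` above matches belief-state constraints against the lowercased JSON databases, with special-cased taxi/police/hospital handling and minute-level comparison for `arrive`/`leave`. A hedged usage sketch; the db directory layout and processed-JSON filenames follow the `__main__` demo above and are assumptions about the on-disk layout, not a guaranteed API:

    # Assumed layout: a db/ folder containing the processed MultiWOZ JSON files.
    from maas_lib.utils.nlp.space.db_ops import MultiWozDB

    db = MultiWozDB('.', {
        'restaurant': 'db/restaurant_db_processed.json',
        'hotel': 'db/hotel_db_processed.json',
        'attraction': 'db/attraction_db_processed.json',
        'train': 'db/train_db_processed.json',
        'taxi': 'db/taxi_db_processed.json',
        'police': 'db/police_db_processed.json',
        'hospital': 'db/hospital_db_processed.json',
    })
    matches = db.queryJsons('restaurant', {'pricerange': 'cheap', 'area': 'west'})
    # the match count is what gets verbalised into a [db_x] indicator token
    print(len(matches), db.addDBIndicator('restaurant', len(matches)))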
diff --git a/maas_lib/utils/nlp/space/ontology.py b/maas_lib/utils/nlp/space/ontology.py
new file mode 100644
index 00000000..22e48120
--- /dev/null
+++ b/maas_lib/utils/nlp/space/ontology.py
@@ -0,0 +1,210 @@
+all_domains = [
+    'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital'
+]
+db_domains = ['restaurant', 'hotel', 'attraction', 'train']
+
+normlize_slot_names = {
+    'car type': 'car',
+    'entrance fee': 'price',
+    'duration': 'time',
+    'leaveat': 'leave',
+    'arriveby': 'arrive',
+    'trainid': 'id'
+}
+
+requestable_slots = {
+    'taxi': ['car', 'phone'],
+    'police': ['postcode', 'address', 'phone'],
+    'hospital': ['address', 'phone', 'postcode'],
+    'hotel': [
+        'address', 'postcode', 'internet', 'phone', 'parking', 'type',
+        'pricerange', 'stars', 'area', 'reference'
+    ],
+    'attraction':
+    ['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'],
+    'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'],
+    'restaurant': [
+        'phone', 'postcode', 'address', 'pricerange', 'food', 'area',
+        'reference'
+    ]
+}
+all_reqslot = [
+    'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type',
+    'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave',
+    'price', 'arrive', 'id'
+]
+
+informable_slots = {
+    'taxi': ['leave', 'destination', 'departure', 'arrive'],
+    'police': [],
+    'hospital': ['department'],
+    'hotel': [
+        'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
+        'area', 'stars', 'name'
+    ],
+    'attraction': ['area', 'type', 'name'],
+    'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'],
+    'restaurant':
+    ['food', 'pricerange', 'area', 'name', 'time', 'day', 'people']
+}
+all_infslot = [
+    'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
+    'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive',
+    'department', 'food', 'time'
+]
+
+all_slots = all_reqslot + [
+    'stay', 'day', 'people', 'name', 'destination', 'departure', 'department'
+]
+get_slot = {}
+for s in all_slots:
+    get_slot[s] = 1
+
+# mapping slots in dialogue act to original goal slot names
+da_abbr_to_slot_name = {
+    'addr': 'address',
+    'fee': 'price',
+    'post': 'postcode',
+    'ref': 'reference',
+    'ticket': 'price',
+    'depart': 'departure',
+    'dest': 'destination',
+}
+
+dialog_acts = {
+    'restaurant': [
+        'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
+        'offerbooked', 'nobook'
+    ],
+    'hotel': [
+        'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
+        'offerbooked', 'nobook'
+    ],
+    'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'],
+    'train':
+    ['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'],
+    'taxi': ['inform', 'request'],
+    'police': ['inform', 'request'],
+    'hospital': ['inform', 'request'],
+    # 'booking': ['book', 'inform', 'nobook', 'request'],
+    'general': ['bye', 'greet', 'reqmore', 'welcome'],
+}
+all_acts = []
+for acts in dialog_acts.values():
+    for act in acts:
+        if act not in all_acts:
+            all_acts.append(act)
+
+dialog_act_params = {
+    'inform': all_slots + ['choice', 'open'],
+    'request': all_infslot + ['choice', 'price'],
+    'nooffer': all_slots + ['choice'],
+    'recommend': all_reqslot + ['choice', 'open'],
+    'select': all_slots + ['choice'],
+    # 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
+    'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
+    'offerbook': all_slots + ['choice'],
+    'offerbooked': all_slots + ['choice'],
+    'reqmore': [],
+    'welcome': [],
+    'bye': [],
+    'greet': [],
+}
+
+dialog_act_all_slots = all_slots + ['choice', 'open']
+
+# special slot tokens in belief span
+# no need of this, just convert slot to [slot], e.g. pricerange -> [pricerange]
+slot_name_to_slot_token = {}
+
+# special slot tokens in responses
+# not used at the moment
+slot_name_to_value_token = {
+    # 'entrance fee': '[value_price]',
+    # 'pricerange': '[value_price]',
+    # 'arriveby': '[value_time]',
+    # 'leaveat': '[value_time]',
+    # 'departure': '[value_place]',
+    # 'destination': '[value_place]',
+    # 'stay': 'count',
+    # 'people': 'count'
+}
+
+# eos tokens definition
+eos_tokens = {
+    'user': '<eos_u>',
+    'user_delex': '<eos_u>',
+    'resp': '<eos_r>',
+    'resp_gen': '<eos_r>',
+    'pv_resp': '<eos_r>',
+    'bspn': '<eos_b>',
+    'bspn_gen': '<eos_b>',
+    'pv_bspn': '<eos_b>',
+    'bsdx': '<eos_b>',
+    'bsdx_gen': '<eos_b>',
+    'pv_bsdx': '<eos_b>',
+    'qspn': '<eos_q>',
+    'qspn_gen': '<eos_q>',
+    'pv_qspn': '<eos_q>',
+    'aspn': '<eos_a>',
+    'aspn_gen': '<eos_a>',
+    'pv_aspn': '<eos_a>',
+    'dspn': '<eos_d>',
+    'dspn_gen': '<eos_d>',
+    'pv_dspn': '<eos_d>'
+}
+
+# sos tokens definition
+sos_tokens = {
+    'user': '<sos_u>',
+    'user_delex': '<sos_u>',
+    'resp': '<sos_r>',
+    'resp_gen': '<sos_r>',
+    'pv_resp': '<sos_r>',
+    'bspn': '<sos_b>',
+    'bspn_gen': '<sos_b>',
+    'pv_bspn': '<sos_b>',
+    'bsdx': '<sos_b>',
+    'bsdx_gen': '<sos_b>',
+    'pv_bsdx': '<sos_b>',
+    'qspn': '<sos_q>',
+    'qspn_gen': '<sos_q>',
+    'pv_qspn': '<sos_q>',
+    'aspn': '<sos_a>',
+    'aspn_gen': '<sos_a>',
+    'pv_aspn': '<sos_a>',
+    'dspn': '<sos_d>',
+    'dspn_gen': '<sos_d>',
+    'pv_dspn': '<sos_d>'
+}
+
+# db tokens definition
+db_tokens = [
+    '<sos_db>', '<eos_db>', '[book_nores]', '[book_fail]', '[book_success]',
+    '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
+]
+
+
+# understand tokens definition
+def get_understand_tokens(prompt_num_for_understand):
+    understand_tokens = []
+    for i in range(prompt_num_for_understand):
+        understand_tokens.append(f'<understand_{i}>')
+    return understand_tokens
+
+
+# policy tokens definition
+def get_policy_tokens(prompt_num_for_policy):
+    policy_tokens = []
+    for i in range(prompt_num_for_policy):
+        policy_tokens.append(f'<policy_{i}>')
+    return policy_tokens
+
+
+# all special tokens definition
+def get_special_tokens(other_tokens):
+    special_tokens = ['<go_r>', '<go_b>', '<go_a>', '<go_d>',
+                      '<eos_u>', '<eos_r>', '<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>',
+                      '<sos_u>', '<sos_r>', '<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>'] \
+        + db_tokens + other_tokens
+    return special_tokens
diff --git a/maas_lib/utils/nlp/space/scores.py b/maas_lib/utils/nlp/space/scores.py
new file mode 100644
index 00000000..fe0a8a17
--- /dev/null
+++ b/maas_lib/utils/nlp/space/scores.py
@@ -0,0 +1,6 @@
+def hierarchical_set_score(frame1, frame2):
+    # deal with empty frame
+    if not (frame1 and frame2):
+        return 0.
+    pass
+    return 0.
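The token helpers at the end of ontology.py above compose the model's special-token list: fixed go/eos/sos markers, then the db tokens, then however many prompt tokens the model config asks for. A quick sketch; the prompt counts of 2 are illustrative only:

    # Illustrative only: real prompt counts come from the model config.
    from maas_lib.utils.nlp.space.ontology import (
        get_policy_tokens, get_special_tokens, get_understand_tokens)

    prompts = get_understand_tokens(2) + get_policy_tokens(2)
    # ['<understand_0>', '<understand_1>', '<policy_0>', '<policy_1>']
    special = get_special_tokens(other_tokens=prompts)
    print(special[-4:])  # prompt tokens land last, after the db tokens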
diff --git a/maas_lib/utils/nlp/space/utils.py b/maas_lib/utils/nlp/space/utils.py
new file mode 100644
index 00000000..df0107a1
--- /dev/null
+++ b/maas_lib/utils/nlp/space/utils.py
@@ -0,0 +1,180 @@
+import logging
+from collections import OrderedDict
+
+import json
+import numpy as np
+
+from . import ontology
+
+
+def clean_replace(s, r, t, forward=True, backward=False):
+
+    def clean_replace_single(s, r, t, forward, backward, sidx=0):
+        # idx = s[sidx:].find(r)
+        idx = s.find(r)
+        if idx == -1:
+            return s, -1
+        idx_r = idx + len(r)
+        if backward:
+            while idx > 0 and s[idx - 1]:
+                idx -= 1
+        elif idx > 0 and s[idx - 1] != ' ':
+            return s, -1
+
+        if forward:
+            while idx_r < len(s) and (s[idx_r].isalpha()
+                                      or s[idx_r].isdigit()):
+                idx_r += 1
+        elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
+            return s, -1
+        return s[:idx] + t + s[idx_r:], idx_r
+
+    # source, replace, target = s, r, t
+    # count = 0
+    sidx = 0
+    while sidx != -1:
+        s, sidx = clean_replace_single(s, r, t, forward, backward, sidx)
+        # count += 1
+        # print(s, sidx)
+        # if count == 20:
+        #     print(source, '\n', replace, '\n', target)
+        #     quit()
+    return s
+
+
+def py2np(list):
+    return np.array(list)
+
+
+def write_dict(fn, dic):
+    with open(fn, 'w') as f:
+        json.dump(dic, f, indent=2)
+
+
+def f1_score(label_list, pred_list):
+    tp = len([t for t in pred_list if t in label_list])
+    fp = max(0, len(pred_list) - tp)
+    fn = max(0, len(label_list) - tp)
+    precision = tp / (tp + fp + 1e-10)
+    recall = tp / (tp + fn + 1e-10)
+    f1 = 2 * precision * recall / (precision + recall + 1e-10)
+    return f1
+
+
+class MultiWOZVocab(object):
+
+    def __init__(self, vocab_size=0):
+        """
+        vocab for multiwoz dataset
+        """
+        self.vocab_size = vocab_size
+        self.vocab_size_oov = 0  # get after construction
+        self._idx2word = {}  # word + oov
+        self._word2idx = {}  # word
+        self._freq_dict = {}  # word + oov
+        for w in [
+                '[PAD]', '<go_r>', '[UNK]', '<go_b>', '<go_a>', '<eos_u>',
+                '<eos_r>', '<eos_b>', '<eos_a>', '<go_d>', '<eos_d>'
+        ]:
+            self._absolute_add_word(w)
+
+    def _absolute_add_word(self, w):
+        idx = len(self._idx2word)
+        self._idx2word[idx] = w
+        self._word2idx[w] = idx
+
+    def add_word(self, word):
+        if word not in self._freq_dict:
+            self._freq_dict[word] = 0
+        self._freq_dict[word] += 1
+
+    def has_word(self, word):
+        return self._freq_dict.get(word)
+
+    def _add_to_vocab(self, word):
+        if word not in self._word2idx:
+            idx = len(self._idx2word)
+            self._idx2word[idx] = word
+            self._word2idx[word] = idx
+
+    def construct(self):
+        l = sorted(self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
+        print('Vocabulary size including oov: %d' %
+              (len(l) + len(self._idx2word)))
+        if len(l) + len(self._idx2word) < self.vocab_size:
+            logging.warning(
+                'actual label set smaller than that configured: {}/{}'.format(
+                    len(l) + len(self._idx2word), self.vocab_size))
+        for word in ontology.all_domains + ['general']:
+            word = '[' + word + ']'
+            self._add_to_vocab(word)
+        for word in ontology.all_acts:
+            word = '[' + word + ']'
+            self._add_to_vocab(word)
+        for word in ontology.all_slots:
+            self._add_to_vocab(word)
+        for word in l:
+            if word.startswith('[value_') and word.endswith(']'):
+                self._add_to_vocab(word)
+        for word in l:
+            self._add_to_vocab(word)
+        self.vocab_size_oov = len(self._idx2word)
+
+    def load_vocab(self, vocab_path):
+        self._freq_dict = json.loads(
+            open(vocab_path + '.freq.json', 'r').read())
+        self._word2idx = json.loads(
+            open(vocab_path + '.word2idx.json', 'r').read())
+        self._idx2word = {}
+        for w, idx in self._word2idx.items():
+            self._idx2word[idx] = w
+        self.vocab_size_oov = len(self._idx2word)
+        print('vocab file loaded from "' + vocab_path + '"')
+        print('Vocabulary size including oov: %d' % (self.vocab_size_oov))
+
+    def save_vocab(self, vocab_path):
+        _freq_dict = OrderedDict(
+            sorted(
+                self._freq_dict.items(), key=lambda kv: kv[1], reverse=True))
+        write_dict(vocab_path + '.word2idx.json', self._word2idx)
+        write_dict(vocab_path + '.freq.json', _freq_dict)
+
+    def encode(self, word, include_oov=True):
+        if include_oov:
+            if self._word2idx.get(word, None) is None:
+                raise ValueError(
+                    'Unknown word: %s. Vocabulary should include oovs here.' %
+                    word)
+            return self._word2idx[word]
+        else:
+            word = '<unk>' if word not in self._word2idx else word
+            return self._word2idx[word]
+
+    def sentence_encode(self, word_list):
+        return [self.encode(_) for _ in word_list]
+
+    def oov_idx_map(self, idx):
+        return 2 if idx > self.vocab_size else idx
+
+    def sentence_oov_map(self, index_list):
+        return [self.oov_idx_map(_) for _ in index_list]
+
+    def decode(self, idx, indicate_oov=False):
+        if not self._idx2word.get(idx):
+            raise ValueError(
+                'Error idx: %d. Vocabulary should include oovs here.' % idx)
+        if not indicate_oov or idx < self.vocab_size:
+            return self._idx2word[idx]
+        else:
+            return self._idx2word[idx] + '(o)'
+
+    def sentence_decode(self, index_list, eos=None, indicate_oov=False):
+        l = [self.decode(_, indicate_oov) for _ in index_list]
+        if not eos or eos not in l:
+            return ' '.join(l)
+        else:
+            idx = l.index(eos)
+            return ' '.join(l[:idx])
+
+    def nl_decode(self, l, eos=None):
+        return [self.sentence_decode(_, eos) + '\n' for _ in l]
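`f1_score()` in utils.py above computes set-style precision and recall with a small epsilon to dodge zero division. A quick worked check: with labels ['a', 'b', 'c'] and predictions ['a', 'd'], tp=1, fp=1, fn=2, so precision is about 0.5, recall about 0.333 and F1 about 0.4:

    from maas_lib.utils.nlp.space.utils import f1_score

    print(round(f1_score(['a', 'b', 'c'], ['a', 'd']), 3))  # -> 0.4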
diff --git a/requirements/nlp/space.txt b/requirements/nlp/space.txt
new file mode 100644
index 00000000..09a0f64e
--- /dev/null
+++ b/requirements/nlp/space.txt
@@ -0,0 +1,2 @@
+spacy==2.3.5
+# python -m spacy download en_core_web_sm
diff --git a/tests/case/__init__.py b/tests/case/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/case/nlp/__init__.py b/tests/case/nlp/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/case/nlp/dialog_generation_case.py b/tests/case/nlp/dialog_generation_case.py
new file mode 100644
index 00000000..6f5ea2fe
--- /dev/null
+++ b/tests/case/nlp/dialog_generation_case.py
@@ -0,0 +1,76 @@
+test_case = {
+    'sng0073': {
+        'goal': {
+            'taxi': {
+                'info': {
+                    'leaveat': '17:15',
+                    'destination': 'pizza hut fen ditton',
+                    'departure': "saint john's college"
+                },
+                'reqt': ['car', 'phone'],
+                'fail_info': {}
+            }
+        },
+        'log': [{
+            'user':
+            "i would like a taxi from saint john 's college to pizza hut fen ditton .",
+            'user_delex':
+            'i would like a taxi from [value_departure] to [value_destination] .',
+            'resp':
+            'what time do you want to leave and what time do you want to arrive by ?',
+            'sys':
+            'what time do you want to leave and what time do you want to arrive by ?',
+            'pointer': '0,0,0,0,0,0',
+            'match': '',
+            'constraint':
+            "[taxi] destination pizza hut fen ditton departure saint john 's college",
+            'cons_delex': '[taxi] destination departure',
+            'sys_act': '[taxi] [request] leave arrive',
+            'turn_num': 0,
+            'turn_domain': '[taxi]'
+        }, {
+            'user': 'i want to leave after 17:15 .',
+            'user_delex': 'i want to leave after [value_leave] .',
+            'resp':
+            'booking completed ! your taxi will be [value_car] contact number is [value_phone]',
+            'sys':
+            'booking completed ! your taxi will be blue honda contact number is 07218068540',
+            'pointer': '0,0,0,0,0,0',
+            'match': '',
+            'constraint':
+            "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
+            'cons_delex': '[taxi] destination departure leave',
+            'sys_act': '[taxi] [inform] car phone',
+            'turn_num': 1,
+            'turn_domain': '[taxi]'
+        }, {
+            'user': 'thank you for all the help ! i appreciate it .',
+            'user_delex': 'thank you for all the help ! i appreciate it .',
+            'resp':
+            'you are welcome . is there anything else i can help you with today ?',
+            'sys':
+            'you are welcome . is there anything else i can help you with today ?',
+            'pointer': '0,0,0,0,0,0',
+            'match': '',
+            'constraint':
+            "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
+            'cons_delex': '[taxi] destination departure leave',
+            'sys_act': '[general] [reqmore]',
+            'turn_num': 2,
+            'turn_domain': '[general]'
+        }, {
+            'user': 'no , i am all set . have a nice day . bye .',
+            'user_delex': 'no , i am all set . have a nice day . bye .',
+            'resp': 'you too ! thank you',
+            'sys': 'you too ! thank you',
+            'pointer': '0,0,0,0,0,0',
+            'match': '',
+            'constraint':
+            "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
+            'cons_delex': '[taxi] destination departure leave',
+            'sys_act': '[general] [bye]',
+            'turn_num': 3,
+            'turn_domain': '[general]'
+        }]
+    }
+}
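The sng0073 fixture above follows the MultiWOZ log format: one dict per turn carrying the lexicalised (`user`/`sys`) and delexicalised (`user_delex`/`resp`) sides plus belief-state (`constraint`) and system-act annotations. A sketch of walking it turn by turn, much as the tests below do for the `user` side:

    from tests.case.nlp.dialog_generation_case import test_case

    for turn in test_case['sng0073']['log']:
        print(turn['turn_num'], turn['turn_domain'])
        print('  user:', turn['user'])
        print('  gold:', turn['resp'])  # delexicalised system response
        print('  belief:', turn['constraint'])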
diff --git a/tests/pipelines/nlp/test_dialog_generation.py b/tests/pipelines/nlp/test_dialog_generation.py
index 68b82132..b3186de6 100644
--- a/tests/pipelines/nlp/test_dialog_generation.py
+++ b/tests/pipelines/nlp/test_dialog_generation.py
@@ -37,30 +37,31 @@ dialog_case = [{
 }]
 
 
+def merge(info, result):
+    return info
+
+
 class DialogGenerationTest(unittest.TestCase):
 
     def test_run(self):
-        for item in dialog_case:
-            q = item['user']
-            a = item['sys']
-            print('user:{}'.format(q))
-            print('sys:{}'.format(a))
-        # preprocessor = DialogGenerationPreprocessor()
-        # # data = DialogGenerationData()
-        # model = DialogGenerationModel(path, preprocessor.tokenizer)
-        # pipeline = DialogGenerationPipeline(model, preprocessor)
-        #
-        # history_dialog = []
-        # for item in dialog_case:
-        #     user_question = item['user']
-        #     print('user: {}'.format(user_question))
-        #
-        #     pipeline(user_question)
-        #
-        #     sys_answer, history_dialog = pipeline()
-        #
-        #     print('sys : {}'.format(sys_answer))
+        modeldir = '/Users/yangliu/Desktop/space-dialog-generation'
+
+        preprocessor = DialogGenerationPreprocessor(model_dir=modeldir)
+        model = DialogGenerationModel(modeldir,
+                                      preprocessor.tokenizer)
+        pipeline = DialogGenerationPipeline(model, preprocessor)
+
+        history_dialog_info = {}
+        for step in range(0, len(dialog_case)):
+            user_question = dialog_case[step]['user']
+            print('user: {}'.format(user_question))
+
+            history_dialog_info = merge(history_dialog_info,
+                                        result) if step > 0 else {}
+            result = pipeline(user_question, history=history_dialog_info)
+
+            print('sys : {}'.format(result['pred_string']))
 
 
 if __name__ == '__main__':
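`merge()` in the test above is still a stub that returns `info` unchanged. Presumably it should fold the previous turn's result into the running dialog state before the next `pipeline(...)` call; one hypothetical shape, not the committed behaviour:

    # Hypothetical merge(); the committed version is a stub returning info as-is.
    def merge(info, result):
        history = dict(info)
        # carry each system answer forward as context for the next turn
        history.setdefault('responses', []).append(result.get('pred_string', ''))
        return history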
diff --git a/tests/preprocessors/nlp/__init__.py b/tests/preprocessors/nlp/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/preprocessors/nlp/test_dialog_generation.py b/tests/preprocessors/nlp/test_dialog_generation.py
new file mode 100644
index 00000000..ca07922b
--- /dev/null
+++ b/tests/preprocessors/nlp/test_dialog_generation.py
@@ -0,0 +1,25 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+
+from tests.case.nlp.dialog_generation_case import test_case
+
+from maas_lib.preprocessors import DialogGenerationPreprocessor
+from maas_lib.utils.constant import Fields, InputFields
+from maas_lib.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class DialogGenerationPreprocessorTest(unittest.TestCase):
+
+    def test_tokenize(self):
+        modeldir = '/Users/yangliu/Desktop/space-dialog-generation'
+        processor = DialogGenerationPreprocessor(model_dir=modeldir)
+
+        for item in test_case['sng0073']['log']:
+            print(processor(item['user']))
+
+
+if __name__ == '__main__':
+    unittest.main()