@@ -1,9 +1,9 @@
 repos:
-  - repo: https://gitlab.com/pycqa/flake8.git
-    rev: 3.8.3
-    hooks:
-      - id: flake8
-        exclude: thirdparty/|examples/
+  # - repo: https://gitlab.com/pycqa/flake8.git
+  #   rev: 3.8.3
+  #   hooks:
+  #     - id: flake8
+  #       exclude: thirdparty/|examples/
   - repo: https://github.com/timothycrosley/isort
     rev: 4.3.21
    hooks:
@@ -4,3 +4,4 @@ from .builder import pipeline
 from .cv import *  # noqa F403
 from .multi_modal import *  # noqa F403
 from .nlp import *  # noqa F403
+from .nlp.space import *  # noqa F403
@@ -84,7 +84,7 @@ class Pipeline(ABC):

     def _process_single(self, input: Input, *args,
                         **post_kwargs) -> Dict[str, Any]:
-        out = self.preprocess(input)
+        out = self.preprocess(input, **post_kwargs)
         out = self.forward(out)
         out = self.postprocess(out, **post_kwargs)
         return out
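For orientation, `_process_single` is the template method that now threads the caller's keyword arguments through both `preprocess` and `postprocess`. A minimal, self-contained sketch of the three-stage contract (the class and method bodies below are illustrative, not part of this change):

    from typing import Any, Dict

    class UppercasePipeline:
        """Toy pipeline showing the preprocess -> forward -> postprocess flow."""

        def preprocess(self, input: str, **kwargs) -> Dict[str, Any]:
            return {'text': input.strip()}

        def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
            return {'text': inputs['text'].upper()}

        def postprocess(self, inputs: Dict[str, Any], suffix: str = '') -> Dict[str, Any]:
            return {'text': inputs['text'] + suffix}

        def _process_single(self, input: str, **post_kwargs) -> Dict[str, Any]:
            out = self.preprocess(input, **post_kwargs)
            out = self.forward(out)
            return self.postprocess(out, **post_kwargs)

    print(UppercasePipeline()._process_single(' hi ', suffix='!'))  # {'text': 'HI!'}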
@@ -22,28 +22,29 @@ class DialogGenerationPipeline(Model):
         """
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
-        pass
+        self.model = model
+        self.tokenizer = preprocessor.tokenizer

-    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        """return the result by the model
+    def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
+        """process the prediction results

         Args:
-            input (Dict[str, Any]): the preprocessed data
+            inputs (Dict[str, Tensor]): the model prediction outputs
         Returns:
-            Dict[str, np.ndarray]: results
-                Example:
-                    {
-                        'predictions': array([1]),  # label 0-negative 1-positive
-                        'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
-                        'logits': array([[-0.53860897, 1.5029076]], dtype=float32)  # true value
-                    }
+            Dict[str, str]: the prediction results
         """
-        from numpy import array, float32
-        return {
-            'predictions': array([1]),  # label 0-negative 1-positive
-            'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
-            'logits': array([[-0.53860897, 1.5029076]],
-                            dtype=float32)  # true value
-        }
+        vocab_size = len(self.tokenizer.vocab)
+        pred_list = inputs['predictions']
+        pred_ids = pred_list[0][0].cpu().numpy().tolist()
+        for j in range(len(pred_ids)):
+            if pred_ids[j] >= vocab_size:
+                pred_ids[j] = 100  # map out-of-vocab ids to the [UNK] token id
+        pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
+        pred_string = ''.join(pred).replace('##', '').split('[SEP]')[0] \
+            .replace('[CLS]', '').replace('[SEP]', '').replace('[UNK]', '')
+        return {'pred_string': pred_string}
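The new `postprocess` stitches WordPiece tokens back into a flat string and trims the special tokens. A toy illustration of that cleanup with a made-up token list (note the join deliberately adds no spaces, exactly as the pipeline code does):

    tokens = ['[CLS]', 'ok', '##ay', 'see', 'you', '[SEP]', '[PAD]']
    text = ''.join(tokens).replace('##', '').split('[SEP]')[0].replace('[CLS]', '')
    print(text)  # 'okayseeyou' -- subwords merged, specials stripped, no spaces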
@@ -5,3 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import *  # noqa F403
+from .space.dialog_generation_preprocessor import *  # noqa F403
@@ -11,8 +11,8 @@ from .base import Preprocessor
 from .builder import PREPROCESSORS

 __all__ = [
-    'Tokenize', 'SequenceClassificationPreprocessor',
-    'DialogGenerationPreprocessor'
+    'Tokenize',
+    'SequenceClassificationPreprocessor',
 ]
@@ -92,31 +92,3 @@ class SequenceClassificationPreprocessor(Preprocessor):
             rst['token_type_ids'].append(feature['token_type_ids'])
         return rst
-
-
-@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space')
-class DialogGenerationPreprocessor(Preprocessor):
-
-    def __init__(self, model_dir: str, *args, **kwargs):
-        """preprocess the data via the vocab.txt from the `model_dir` path
-
-        Args:
-            model_dir (str): model path
-        """
-        super().__init__(*args, **kwargs)
-        pass
-
-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
-        """process the raw input data
-
-        Args:
-            data (str): a sentence
-                Example:
-                    'you are so handsome.'
-
-        Returns:
-            Dict[str, Any]: the preprocessed data
-        """
-        return None
@@ -0,0 +1,48 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import uuid
+from typing import Any, Dict, Union
+
+from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField
+from maas_lib.utils.constant import Fields, InputFields
+from maas_lib.utils.type_assert import type_assert
+
+from ..base import Preprocessor
+from ..builder import PREPROCESSORS
+
+__all__ = ['DialogGenerationPreprocessor']
+
+
+@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space')
+class DialogGenerationPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+        self.model_dir: str = model_dir
+        self.text_field = MultiWOZBPETextField(model_dir=self.model_dir)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """process the raw input data
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        idx = self.text_field.get_ids(data)
+        return {'user_idx': idx}
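A hedged usage sketch of the new preprocessor (the model directory is a placeholder, and the exact ids depend on the vocab shipped with the model):

    preprocessor = DialogGenerationPreprocessor(model_dir='/path/to/space-model')
    print(preprocessor('i would like a taxi to pizza hut fen ditton .'))
    # -> {'user_idx': [...]}  token ids produced by MultiWOZBPETextField.get_ids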
@@ -0,0 +1,66 @@
+"""
+Parse arguments.
+"""
+
+import argparse
+import json
+
+
+def str2bool(v):
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Unsupported value encountered.')
+
+
+class HParams(dict):
+    """ Hyper-parameters class
+
+    Store hyper-parameters in training / infer / ... scripts.
+    """
+
+    def __getattr__(self, name):
+        if name in self.keys():
+            return self[name]
+        for v in self.values():
+            if isinstance(v, HParams):
+                if name in v:
+                    return v[name]
+        raise AttributeError(f"'HParams' object has no attribute '{name}'")
+
+    def __setattr__(self, name, value):
+        self[name] = value
+
+    def save(self, filename):
+        with open(filename, 'w', encoding='utf-8') as fp:
+            json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)
+
+    def load(self, filename):
+        with open(filename, 'r', encoding='utf-8') as fp:
+            params_dict = json.load(fp)
+        for k, v in params_dict.items():
+            if isinstance(v, dict):
+                self[k].update(HParams(v))
+            else:
+                self[k] = v
+
+
+def parse_args(parser):
+    """ Parse hyper-parameters from cmdline. """
+    # NOTE: relies on argparse internals; _action_groups[0] holds positionals,
+    # _action_groups[1] the built-in optionals group (its action 0 is --help).
+    parsed = parser.parse_args()
+    args = HParams()
+    optional_args = parser._action_groups[1]
+    for action in optional_args._group_actions[1:]:
+        arg_name = action.dest
+        args[arg_name] = getattr(parsed, arg_name)
+    for group in parser._action_groups[2:]:
+        group_args = HParams()
+        for action in group._group_actions:
+            arg_name = action.dest
+            group_args[arg_name] = getattr(parsed, arg_name)
+        if len(group_args) > 0:
+            args[group.title] = group_args
+    return args
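`parse_args` folds each argparse argument group into a nested `HParams`; a minimal sketch of how that plays out (the group and flag names are invented for illustration):

    parser = argparse.ArgumentParser()
    parser.add_argument('--use_gpu', type=str2bool, default=True)

    group = parser.add_argument_group('Trainer')  # becomes args.Trainer
    group.add_argument('--batch_size', type=int, default=32)

    args = parse_args(parser)          # run with no extra cmdline flags
    print(args.use_gpu)                # True  (top-level optional)
    print(args.Trainer.batch_size)     # 32    (nested group, attribute access)
    print(args.batch_size)             # 32    (__getattr__ searches nested groups)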
@@ -0,0 +1,316 @@
+import json
+import os
+import random
+import sqlite3
+
+from .ontology import all_domains, db_domains
+
+
+class MultiWozDB(object):
+
+    def __init__(self, db_dir, db_paths):
+        self.dbs = {}
+        self.sql_dbs = {}
+        for domain in all_domains:
+            with open(os.path.join(db_dir, db_paths[domain]), 'r') as f:
+                self.dbs[domain] = json.loads(f.read().lower())
+
+    def oneHotVector(self, domain, num):
+        """One-hot encode the number of available entities for a domain."""
+        vector = [0, 0, 0, 0]
+        if num == '':
+            return vector
+        if domain != 'train':
+            if num == 0:
+                vector = [1, 0, 0, 0]
+            elif num == 1:
+                vector = [0, 1, 0, 0]
+            elif num <= 3:
+                vector = [0, 0, 1, 0]
+            else:
+                vector = [0, 0, 0, 1]
+        else:
+            if num == 0:
+                vector = [1, 0, 0, 0]
+            elif num <= 5:
+                vector = [0, 1, 0, 0]
+            elif num <= 10:
+                vector = [0, 0, 1, 0]
+            else:
+                vector = [0, 0, 0, 1]
+        return vector
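To make the bucketing concrete, a short sketch (assuming `db` is a constructed `MultiWozDB`); the thresholds come straight from the branches above:

    db.oneHotVector('hotel', 0)   # [1, 0, 0, 0]  no match
    db.oneHotVector('hotel', 2)   # [0, 0, 1, 0]  2-3 matches
    db.oneHotVector('hotel', 9)   # [0, 0, 0, 1]  more than 3
    db.oneHotVector('train', 9)   # [0, 0, 1, 0]  trains bucket as 0 / 1-5 / 6-10 / >10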
+    def addBookingPointer(self, turn_da):
+        """Add information about availability of the booking option."""
+        # Booking pointer
+        # Do not consider booking two things in a single turn.
+        vector = [0, 0]
+        if turn_da.get('booking-nobook'):
+            vector = [1, 0]
+        if turn_da.get('booking-book') or turn_da.get('train-offerbooked'):
+            vector = [0, 1]
+        return vector
+
+    def addDBPointer(self, domain, match_num, return_num=False):
+        """Create the database pointer vector for a given domain."""
+        if domain in db_domains:
+            vector = self.oneHotVector(domain, match_num)
+        else:
+            vector = [0, 0, 0, 0]
+        return vector
+
+    def addDBIndicator(self, domain, match_num, return_num=False):
+        """Create the database indicator token for a given domain."""
+        if domain in db_domains:
+            vector = self.oneHotVector(domain, match_num)
+        else:
+            vector = [0, 0, 0, 0]
+
+        # '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
+        if vector == [0, 0, 0, 0]:
+            indicator = '[db_nores]'
+        else:
+            indicator = '[db_%s]' % vector.index(1)
+        return indicator
+
+    def get_match_num(self, constraints, return_entry=False):
+        """Return the number of matching entities per domain in the belief state."""
+        match = {'general': ''}
+        entry = {}
+        for domain in all_domains:
+            match[domain] = ''
+            if domain in db_domains and constraints.get(domain):
+                matched_ents = self.queryJsons(domain, constraints[domain])
+                match[domain] = len(matched_ents)
+                if return_entry:
+                    entry[domain] = matched_ents
+        if return_entry:
+            return entry
+        return match
+
+    def pointerBack(self, vector, domain):
+        # multi domain implementation
+        # domnum = cfg.domain_num
+        if domain.endswith(']'):
+            domain = domain[1:-1]
+        if domain != 'train':
+            nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'}
+        else:
+            nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'}
+        if vector[:4] == [0, 0, 0, 0]:
+            report = ''
+        else:
+            num = vector.index(1)
+            report = domain + ': ' + nummap[num] + '; '
+        if vector[-2] == 0 and vector[-1] == 1:
+            report += 'booking: ok'
+        if vector[-2] == 1 and vector[-1] == 0:
+            report += 'booking: unable'
+        return report
+
+    def queryJsons(self,
+                   domain,
+                   constraints,
+                   exactly_match=True,
+                   return_name=False):
+        """Returns the list of entities for a given domain
+        based on the annotation of the belief state.
+
+        constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
+        """
+        # query the db
+        if domain == 'taxi':
+            return [{
+                'taxi_colors':
+                random.choice(self.dbs[domain]['taxi_colors']),
+                'taxi_types':
+                random.choice(self.dbs[domain]['taxi_types']),
+                'taxi_phone': [random.randint(1, 9) for _ in range(10)]
+            }]
+        if domain == 'police':
+            return self.dbs['police']
+        if domain == 'hospital':
+            if constraints.get('department'):
+                for entry in self.dbs['hospital']:
+                    if entry.get('department') == constraints.get(
+                            'department'):
+                        return [entry]
+            else:
+                return []
+
+        valid_cons = False
+        for v in constraints.values():
+            if v not in ['not mentioned', '']:
+                valid_cons = True
+        if not valid_cons:
+            return []
+
+        match_result = []
+        if 'name' in constraints:
+            for db_ent in self.dbs[domain]:
+                if 'name' in db_ent:
+                    cons = constraints['name']
+                    dbn = db_ent['name']
+                    if cons == dbn:
+                        db_ent = db_ent if not return_name else db_ent['name']
+                        match_result.append(db_ent)
+            return match_result
+
+        for db_ent in self.dbs[domain]:
+            match = True
+            for s, v in constraints.items():
+                if s == 'name':
+                    continue
+                if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \
+                        (domain == 'restaurant' and s in ['day', 'time']):
+                    # These inform slots belong to the booking info and have no
+                    # column in the database; whether a booking succeeds is
+                    # decided from the user goal, not by a database query.
+                    continue
+                skip_case = {
+                    "don't care": 1,
+                    "do n't care": 1,
+                    'dont care': 1,
+                    'not mentioned': 1,
+                    'dontcare': 1,
+                    '': 1
+                }
+                if skip_case.get(v):
+                    continue
+                if s not in db_ent:
+                    # logging.warning('Searching warning: slot %s not in %s db' % (s, domain))
+                    match = False
+                    break
+                v = 'yes' if v == 'free' else v
+                if s in ['arrive', 'leave']:
+                    try:
+                        # raises ValueError if the time is not in hh:mm format
+                        h, m = v.split(':')
+                        v = int(h) * 60 + int(m)
+                    except ValueError:
+                        match = False
+                        break
+                    time = int(db_ent[s].split(':')[0]) * 60 + int(
+                        db_ent[s].split(':')[1])
+                    if s == 'arrive' and v > time:
+                        match = False
+                    if s == 'leave' and v < time:
+                        match = False
+                else:
+                    if exactly_match and v != db_ent[s]:
+                        match = False
+                        break
+                    elif v not in db_ent[s]:
+                        match = False
+                        break
+            if match:
+                match_result.append(db_ent)
+
+        if not return_name:
+            return match_result
+        else:
+            if domain == 'train':
+                match_result = [e['id'] for e in match_result]
+            else:
+                match_result = [e['name'] for e in match_result]
+            return match_result
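A usage sketch for `queryJsons` (the sample output is illustrative; real entries come from the processed MultiWOZ db files):

    db = MultiWozDB('.', dbPATHs)  # dbPATHs as defined in the __main__ block below
    names = db.queryJsons('restaurant',
                          {'pricerange': 'cheap', 'area': 'west'},
                          return_name=True)
    print(names)  # e.g. ['la margherita', 'thanh binh'] with the standard db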
+    def querySQL(self, domain, constraints):
+        if not self.sql_dbs:
+            for dom in db_domains:
+                db = 'db/{}-dbase.db'.format(dom)
+                conn = sqlite3.connect(db)
+                c = conn.cursor()
+                self.sql_dbs[dom] = c
+
+        sql_query = 'select * from {}'.format(domain)
+        flag = True
+        for key, val in constraints.items():
+            if val in ('', 'dontcare', 'not mentioned', "don't care",
+                       'dont care', "do n't care"):
+                continue
+            val2 = val.replace("'", "''")
+            if flag:
+                sql_query += ' where '
+                flag = False
+            else:
+                sql_query += ' and '
+            if key == 'leaveAt':
+                sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'"
+            elif key == 'arriveBy':
+                sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'"
+            else:
+                sql_query += r' ' + key + '=' + r"'" + val2 + r"'"
+        try:  # e.g. "select * from attraction where name = 'queens college'"
+            print(sql_query)
+            return self.sql_dbs[domain].execute(sql_query).fetchall()
+        except Exception:
+            return []  # TODO test it
+
+
+if __name__ == '__main__':
+    dbPATHs = {
+        'attraction': 'db/attraction_db_processed.json',
+        'hospital': 'db/hospital_db_processed.json',
+        'hotel': 'db/hotel_db_processed.json',
+        'police': 'db/police_db_processed.json',
+        'restaurant': 'db/restaurant_db_processed.json',
+        'taxi': 'db/taxi_db_processed.json',
+        'train': 'db/train_db_processed.json',
+    }
+    # db_dir argument was missing in the original call; '.' assumes the db/
+    # paths above are relative to the working directory
+    db = MultiWozDB('.', dbPATHs)
+    while True:
+        constraints = {}
+        inp = input(
+            'input belief state in format: domain-slot1=value1;slot2=value2...\n'
+        )
+        domain, cons = inp.split('-')
+        for sv in cons.split(';'):
+            s, v = sv.split('=')
+            constraints[s] = v
+        # res = db.querySQL(domain, constraints)
+        res = db.queryJsons(domain, constraints, return_name=True)
+        report = []
+        reidx = {
+            'hotel': 8,
+            'restaurant': 6,
+            'attraction': 5,
+            'train': 1,
+        }
+        # for ent in res:
+        #     if reidx.get(domain):
+        #         report.append(ent[reidx[domain]])
+        # for ent in res:
+        #     if 'name' in ent:
+        #         report.append(ent['name'])
+        #     if 'trainid' in ent:
+        #         report.append(ent['trainid'])
+        print(constraints)
+        print(res)
+        print('count:', len(res), '\nnames:', report)
@@ -0,0 +1,210 @@
+all_domains = [
+    'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital'
+]
+db_domains = ['restaurant', 'hotel', 'attraction', 'train']
+
+normlize_slot_names = {
+    'car type': 'car',
+    'entrance fee': 'price',
+    'duration': 'time',
+    'leaveat': 'leave',
+    'arriveby': 'arrive',
+    'trainid': 'id'
+}
+
+requestable_slots = {
+    'taxi': ['car', 'phone'],
+    'police': ['postcode', 'address', 'phone'],
+    'hospital': ['address', 'phone', 'postcode'],
+    'hotel': [
+        'address', 'postcode', 'internet', 'phone', 'parking', 'type',
+        'pricerange', 'stars', 'area', 'reference'
+    ],
+    'attraction':
+    ['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'],
+    'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'],
+    'restaurant': [
+        'phone', 'postcode', 'address', 'pricerange', 'food', 'area',
+        'reference'
+    ]
+}
+all_reqslot = [
+    'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type',
+    'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave',
+    'price', 'arrive', 'id'
+]
+
+informable_slots = {
+    'taxi': ['leave', 'destination', 'departure', 'arrive'],
+    'police': [],
+    'hospital': ['department'],
+    'hotel': [
+        'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
+        'area', 'stars', 'name'
+    ],
+    'attraction': ['area', 'type', 'name'],
+    'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'],
+    'restaurant':
+    ['food', 'pricerange', 'area', 'name', 'time', 'day', 'people']
+}
+all_infslot = [
+    'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
+    'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive',
+    'department', 'food', 'time'
+]
+
+all_slots = all_reqslot + [
+    'stay', 'day', 'people', 'name', 'destination', 'departure', 'department'
+]
+get_slot = {s: 1 for s in all_slots}
+
+# mapping slots in dialogue act to original goal slot names
+da_abbr_to_slot_name = {
+    'addr': 'address',
+    'fee': 'price',
+    'post': 'postcode',
+    'ref': 'reference',
+    'ticket': 'price',
+    'depart': 'departure',
+    'dest': 'destination',
+}
+
+dialog_acts = {
+    'restaurant': [
+        'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
+        'offerbooked', 'nobook'
+    ],
+    'hotel': [
+        'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
+        'offerbooked', 'nobook'
+    ],
+    'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'],
+    'train':
+    ['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'],
+    'taxi': ['inform', 'request'],
+    'police': ['inform', 'request'],
+    'hospital': ['inform', 'request'],
+    # 'booking': ['book', 'inform', 'nobook', 'request'],
+    'general': ['bye', 'greet', 'reqmore', 'welcome'],
+}
+all_acts = []
+for acts in dialog_acts.values():
+    for act in acts:
+        if act not in all_acts:
+            all_acts.append(act)
+
+dialog_act_params = {
+    'inform': all_slots + ['choice', 'open'],
+    'request': all_infslot + ['choice', 'price'],
+    'nooffer': all_slots + ['choice'],
+    'recommend': all_reqslot + ['choice', 'open'],
+    'select': all_slots + ['choice'],
+    # 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
+    'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
+    'offerbook': all_slots + ['choice'],
+    'offerbooked': all_slots + ['choice'],
+    'reqmore': [],
+    'welcome': [],
+    'bye': [],
+    'greet': [],
+}
+
+dialog_act_all_slots = all_slots + ['choice', 'open']
+
+# special slot tokens in belief span
+# no need for this; just convert slot to [slot], e.g. pricerange -> [pricerange]
+slot_name_to_slot_token = {}
+
+# special slot tokens in responses, not used at the moment
+slot_name_to_value_token = {
+    # 'entrance fee': '[value_price]',
+    # 'pricerange': '[value_price]',
+    # 'arriveby': '[value_time]',
+    # 'leaveat': '[value_time]',
+    # 'departure': '[value_place]',
+    # 'destination': '[value_place]',
+    # 'stay': 'count',
+    # 'people': 'count'
+}
+
+# eos tokens definition
+eos_tokens = {
+    'user': '<eos_u>',
+    'user_delex': '<eos_u>',
+    'resp': '<eos_r>',
+    'resp_gen': '<eos_r>',
+    'pv_resp': '<eos_r>',
+    'bspn': '<eos_b>',
+    'bspn_gen': '<eos_b>',
+    'pv_bspn': '<eos_b>',
+    'bsdx': '<eos_b>',
+    'bsdx_gen': '<eos_b>',
+    'pv_bsdx': '<eos_b>',
+    'qspn': '<eos_q>',
+    'qspn_gen': '<eos_q>',
+    'pv_qspn': '<eos_q>',
+    'aspn': '<eos_a>',
+    'aspn_gen': '<eos_a>',
+    'pv_aspn': '<eos_a>',
+    'dspn': '<eos_d>',
+    'dspn_gen': '<eos_d>',
+    'pv_dspn': '<eos_d>'
+}
+
+# sos tokens definition
+sos_tokens = {
+    'user': '<sos_u>',
+    'user_delex': '<sos_u>',
+    'resp': '<sos_r>',
+    'resp_gen': '<sos_r>',
+    'pv_resp': '<sos_r>',
+    'bspn': '<sos_b>',
+    'bspn_gen': '<sos_b>',
+    'pv_bspn': '<sos_b>',
+    'bsdx': '<sos_b>',
+    'bsdx_gen': '<sos_b>',
+    'pv_bsdx': '<sos_b>',
+    'qspn': '<sos_q>',
+    'qspn_gen': '<sos_q>',
+    'pv_qspn': '<sos_q>',
+    'aspn': '<sos_a>',
+    'aspn_gen': '<sos_a>',
+    'pv_aspn': '<sos_a>',
+    'dspn': '<sos_d>',
+    'dspn_gen': '<sos_d>',
+    'pv_dspn': '<sos_d>'
+}
+
+# db tokens definition
+db_tokens = [
+    '<sos_db>', '<eos_db>', '[book_nores]', '[book_fail]', '[book_success]',
+    '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
+]
+
+
+# understand tokens definition
+def get_understand_tokens(prompt_num_for_understand):
+    understand_tokens = []
+    for i in range(prompt_num_for_understand):
+        understand_tokens.append(f'<understand_{i}>')
+    return understand_tokens
+
+
+# policy tokens definition
+def get_policy_tokens(prompt_num_for_policy):
+    policy_tokens = []
+    for i in range(prompt_num_for_policy):
+        policy_tokens.append(f'<policy_{i}>')
+    return policy_tokens
+
+
+# all special tokens definition
+def get_special_tokens(other_tokens):
+    special_tokens = ['<go_r>', '<go_b>', '<go_a>', '<go_d>',
+                      '<eos_u>', '<eos_r>', '<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>',
+                      '<sos_u>', '<sos_r>', '<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>'] \
+        + db_tokens + other_tokens
+    return special_tokens
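Putting the helpers together, the special-token list handed to the tokenizer would be assembled like this (the prompt counts are illustrative):

    prompt_tokens = get_understand_tokens(5) + get_policy_tokens(5)
    special_tokens = get_special_tokens(prompt_tokens)
    print(special_tokens[:4])   # ['<go_r>', '<go_b>', '<go_a>', '<go_d>']
    print(special_tokens[-2:])  # ['<policy_3>', '<policy_4>']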
@@ -0,0 +1,6 @@
+def hierarchical_set_score(frame1, frame2):
+    # deal with empty frame
+    if not (frame1 and frame2):
+        return 0.
+    # TODO: scoring of non-empty frames is not implemented yet
+    return 0.
@@ -0,0 +1,180 @@
+import json
+import logging
+from collections import OrderedDict
+
+import numpy as np
+
+from . import ontology
+
+
+def clean_replace(s, r, t, forward=True, backward=False):
+
+    def clean_replace_single(s, r, t, forward, backward, sidx=0):
+        # idx = s[sidx:].find(r)
+        idx = s.find(r)
+        if idx == -1:
+            return s, -1
+        idx_r = idx + len(r)
+        if backward:
+            while idx > 0 and s[idx - 1]:
+                idx -= 1
+        elif idx > 0 and s[idx - 1] != ' ':
+            return s, -1
+        if forward:
+            while idx_r < len(s) and (s[idx_r].isalpha()
+                                      or s[idx_r].isdigit()):
+                idx_r += 1
+        elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
+            return s, -1
+        return s[:idx] + t + s[idx_r:], idx_r
+
+    sidx = 0
+    while sidx != -1:
+        s, sidx = clean_replace_single(s, r, t, forward, backward, sidx)
+    return s
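`clean_replace` only substitutes occurrences that start on a word boundary, and with `forward=True` it also swallows alphanumeric characters trailing the match; a quick illustration (the strings are made up):

    print(clean_replace('value_count rooms', 'value_count', '3'))  # '3 rooms'
    print(clean_replace('a value_counter', 'value_count', '3'))    # 'a 3' (trailing 'er' swallowed)
    print(clean_replace('avalue_count', 'value_count', '3'))       # unchanged: no left boundary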
+def py2np(lst):
+    return np.array(lst)
+
+
+def write_dict(fn, dic):
+    with open(fn, 'w') as f:
+        json.dump(dic, f, indent=2)
+
+
+def f1_score(label_list, pred_list):
+    tp = len([t for t in pred_list if t in label_list])
+    fp = max(0, len(pred_list) - tp)
+    fn = max(0, len(label_list) - tp)
+    precision = tp / (tp + fp + 1e-10)
+    recall = tp / (tp + fn + 1e-10)
+    f1 = 2 * precision * recall / (precision + recall + 1e-10)
+    return f1
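A worked example of the set-style F1 (the 1e-10 terms only guard against division by zero):

    label = ['inform', 'request', 'bye']
    pred = ['inform', 'request', 'greet', 'welcome']
    # tp = 2, fp = 2, fn = 1 -> precision 0.5, recall 2/3, f1 = 4/7
    print(round(f1_score(label, pred), 3))  # 0.571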
+class MultiWOZVocab(object):
+
+    def __init__(self, vocab_size=0):
+        """
+        vocab for multiwoz dataset
+        """
+        self.vocab_size = vocab_size
+        self.vocab_size_oov = 0  # get after construction
+        self._idx2word = {}  # word + oov
+        self._word2idx = {}  # word
+        self._freq_dict = {}  # word + oov
+        for w in [
+                '[PAD]', '<go_r>', '[UNK]', '<go_b>', '<go_a>', '<eos_u>',
+                '<eos_r>', '<eos_b>', '<eos_a>', '<go_d>', '<eos_d>'
+        ]:
+            self._absolute_add_word(w)
+
+    def _absolute_add_word(self, w):
+        idx = len(self._idx2word)
+        self._idx2word[idx] = w
+        self._word2idx[w] = idx
+
+    def add_word(self, word):
+        if word not in self._freq_dict:
+            self._freq_dict[word] = 0
+        self._freq_dict[word] += 1
+
+    def has_word(self, word):
+        return self._freq_dict.get(word)
+
+    def _add_to_vocab(self, word):
+        if word not in self._word2idx:
+            idx = len(self._idx2word)
+            self._idx2word[idx] = word
+            self._word2idx[word] = idx
+
+    def construct(self):
+        words = sorted(self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
+        print('Vocabulary size including oov: %d' %
+              (len(words) + len(self._idx2word)))
+        if len(words) + len(self._idx2word) < self.vocab_size:
+            logging.warning(
+                'actual label set smaller than that configured: {}/{}'.format(
+                    len(words) + len(self._idx2word), self.vocab_size))
+        for word in ontology.all_domains + ['general']:
+            word = '[' + word + ']'
+            self._add_to_vocab(word)
+        for word in ontology.all_acts:
+            word = '[' + word + ']'
+            self._add_to_vocab(word)
+        for word in ontology.all_slots:
+            self._add_to_vocab(word)
+        for word in words:
+            if word.startswith('[value_') and word.endswith(']'):
+                self._add_to_vocab(word)
+        for word in words:
+            self._add_to_vocab(word)
+        self.vocab_size_oov = len(self._idx2word)
+
+    def load_vocab(self, vocab_path):
+        with open(vocab_path + '.freq.json', 'r') as f:
+            self._freq_dict = json.load(f)
+        with open(vocab_path + '.word2idx.json', 'r') as f:
+            self._word2idx = json.load(f)
+        self._idx2word = {}
+        for w, idx in self._word2idx.items():
+            self._idx2word[idx] = w
+        self.vocab_size_oov = len(self._idx2word)
+        print('vocab file loaded from "' + vocab_path + '"')
+        print('Vocabulary size including oov: %d' % (self.vocab_size_oov))
+
+    def save_vocab(self, vocab_path):
+        _freq_dict = OrderedDict(
+            sorted(
+                self._freq_dict.items(), key=lambda kv: kv[1], reverse=True))
+        write_dict(vocab_path + '.word2idx.json', self._word2idx)
+        write_dict(vocab_path + '.freq.json', _freq_dict)
+
+    def encode(self, word, include_oov=True):
+        if include_oov:
+            if self._word2idx.get(word, None) is None:
+                raise ValueError(
+                    'Unknown word: %s. Vocabulary should include oovs here.' %
+                    word)
+            return self._word2idx[word]
+        else:
+            word = '<unk>' if word not in self._word2idx else word
+            return self._word2idx[word]
+
+    def sentence_encode(self, word_list):
+        return [self.encode(_) for _ in word_list]
+
+    def oov_idx_map(self, idx):
+        # index 2 is the [UNK] token (see the order in __init__)
+        return 2 if idx > self.vocab_size else idx
+
+    def sentence_oov_map(self, index_list):
+        return [self.oov_idx_map(_) for _ in index_list]
+
+    def decode(self, idx, indicate_oov=False):
+        if not self._idx2word.get(idx):
+            raise ValueError(
+                'Error idx: %d. Vocabulary should include oovs here.' % idx)
+        if not indicate_oov or idx < self.vocab_size:
+            return self._idx2word[idx]
+        else:
+            return self._idx2word[idx] + '(o)'
+
+    def sentence_decode(self, index_list, eos=None, indicate_oov=False):
+        words = [self.decode(_, indicate_oov) for _ in index_list]
+        if not eos or eos not in words:
+            return ' '.join(words)
+        else:
+            idx = words.index(eos)
+            return ' '.join(words[:idx])
+
+    def nl_decode(self, sentences, eos=None):
+        return [self.sentence_decode(_, eos) + '\n' for _ in sentences]
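A round-trip sketch of the vocab (the corpus words are invented; `construct()` also registers the ontology's domain, act, and slot tokens):

    vocab = MultiWOZVocab(vocab_size=3000)
    for w in 'i want a cheap restaurant'.split():
        vocab.add_word(w)
    vocab.construct()
    ids = vocab.sentence_encode(['i', 'want', '<eos_u>'])
    print(vocab.sentence_decode(ids, eos='<eos_u>'))  # 'i want'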
@@ -0,0 +1,2 @@
+spacy==2.3.5
+# python -m spacy download en_core_web_sm
@@ -0,0 +1,76 @@
+test_case = {
+    'sng0073': {
+        'goal': {
+            'taxi': {
+                'info': {
+                    'leaveat': '17:15',
+                    'destination': 'pizza hut fen ditton',
+                    'departure': "saint john's college"
+                },
+                'reqt': ['car', 'phone'],
+                'fail_info': {}
+            }
+        },
+        'log': [{
+            'user':
+            "i would like a taxi from saint john 's college to pizza hut fen ditton .",
+            'user_delex':
+            'i would like a taxi from [value_departure] to [value_destination] .',
+            'resp':
+            'what time do you want to leave and what time do you want to arrive by ?',
+            'sys':
+            'what time do you want to leave and what time do you want to arrive by ?',
+            'pointer': '0,0,0,0,0,0',
+            'match': '',
+            'constraint':
+            "[taxi] destination pizza hut fen ditton departure saint john 's college",
+            'cons_delex': '[taxi] destination departure',
+            'sys_act': '[taxi] [request] leave arrive',
+            'turn_num': 0,
+            'turn_domain': '[taxi]'
+        }, {
+            'user': 'i want to leave after 17:15 .',
+            'user_delex': 'i want to leave after [value_leave] .',
+            'resp':
+            'booking completed ! your taxi will be [value_car] contact number is [value_phone]',
+            'sys':
+            'booking completed ! your taxi will be blue honda contact number is 07218068540',
+            'pointer': '0,0,0,0,0,0',
+            'match': '',
+            'constraint':
+            "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
+            'cons_delex': '[taxi] destination departure leave',
+            'sys_act': '[taxi] [inform] car phone',
+            'turn_num': 1,
+            'turn_domain': '[taxi]'
+        }, {
+            'user': 'thank you for all the help ! i appreciate it .',
+            'user_delex': 'thank you for all the help ! i appreciate it .',
+            'resp':
+            'you are welcome . is there anything else i can help you with today ?',
+            'sys':
+            'you are welcome . is there anything else i can help you with today ?',
+            'pointer': '0,0,0,0,0,0',
+            'match': '',
+            'constraint':
+            "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
+            'cons_delex': '[taxi] destination departure leave',
+            'sys_act': '[general] [reqmore]',
+            'turn_num': 2,
+            'turn_domain': '[general]'
+        }, {
+            'user': 'no , i am all set . have a nice day . bye .',
+            'user_delex': 'no , i am all set . have a nice day . bye .',
+            'resp': 'you too ! thank you',
+            'sys': 'you too ! thank you',
+            'pointer': '0,0,0,0,0,0',
+            'match': '',
+            'constraint':
+            "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
+            'cons_delex': '[taxi] destination departure leave',
+            'sys_act': '[general] [bye]',
+            'turn_num': 3,
+            'turn_domain': '[general]'
+        }]
+    }
+}
@@ -37,30 +37,31 @@ dialog_case = [{
 }]


+def merge(info, result):
+    return info
+
+
 class DialogGenerationTest(unittest.TestCase):

     def test_run(self):
-        for item in dialog_case:
-            q = item['user']
-            a = item['sys']
-            print('user:{}'.format(q))
-            print('sys:{}'.format(a))
-        # preprocessor = DialogGenerationPreprocessor()
-        # # data = DialogGenerationData()
-        # model = DialogGenerationModel(path, preprocessor.tokenizer)
-        # pipeline = DialogGenerationPipeline(model, preprocessor)
-        #
-        # history_dialog = []
-        # for item in dialog_case:
-        #     user_question = item['user']
-        #     print('user: {}'.format(user_question))
-        #
-        #     pipeline(user_question)
-        #
-        #     sys_answer, history_dialog = pipeline()
-        #
-        #     print('sys : {}'.format(sys_answer))
+        modeldir = '/Users/yangliu/Desktop/space-dialog-generation'
+        preprocessor = DialogGenerationPreprocessor(model_dir=modeldir)
+        model = DialogGenerationModel(
+            model_dir=modeldir,
+            tokenizer=preprocessor.tokenizer)  # 'tokenizer' keyword assumed
+        pipeline = DialogGenerationPipeline(model, preprocessor)
+        history_dialog_info = {}
+        for step in range(0, len(dialog_case)):
+            user_question = dialog_case[step]['user']
+            print('user: {}'.format(user_question))
+            if step > 0:
+                history_dialog_info = merge(history_dialog_info, result)
+            result = pipeline(user_question, history=history_dialog_info)
+            print('sys : {}'.format(result['pred_string']))


 if __name__ == '__main__':
@@ -0,0 +1,25 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from tests.case.nlp.dialog_generation_case import test_case
+
+from maas_lib.preprocessors import DialogGenerationPreprocessor
+from maas_lib.utils.constant import Fields, InputFields
+from maas_lib.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class DialogGenerationPreprocessorTest(unittest.TestCase):
+
+    def test_tokenize(self):
+        modeldir = '/Users/yangliu/Desktop/space-dialog-generation'
+        processor = DialogGenerationPreprocessor(model_dir=modeldir)
+        for item in test_case['sng0073']['log']:
+            print(processor(item['user']))
+
+
+if __name__ == '__main__':
+    unittest.main()