Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10061562 (master)
| @@ -241,6 +241,7 @@ class Trainers(object): | |||
| # nlp trainers | |||
| bert_sentiment_analysis = 'bert-sentiment-analysis' | |||
| dialog_modeling_trainer = 'dialog-modeling-trainer' | |||
| dialog_intent_trainer = 'dialog-intent-trainer' | |||
| nlp_base_trainer = 'nlp-base-trainer' | |||
| nlp_veco_trainer = 'nlp-veco-trainer' | |||
| @@ -1,6 +1,6 @@ | |||
| from .configuration_space import SpaceConfig | |||
| from .gen_unified_transformer import GenUnifiedTransformer | |||
| from .generator import Generator as SpaceGenerator | |||
| from .generator import SpaceGenerator | |||
| from .intent_unified_transformer import IntentUnifiedTransformer | |||
| from .model_base import SpaceModelBase | |||
| from .modeling_space import (SpaceForDST, SpaceForMaskedLM, | |||
| @@ -38,24 +38,24 @@ def gather(var, idx): | |||
| return var | |||
| class Generator(object): | |||
| class SpaceGenerator(object): | |||
| """ Genrator class. """ | |||
| _registry = dict() | |||
| @classmethod | |||
| def register(cls, name): | |||
| Generator._registry[name] = cls | |||
| SpaceGenerator._registry[name] = cls | |||
| return | |||
| @staticmethod | |||
| def by_name(name): | |||
| return Generator._registry[name] | |||
| return SpaceGenerator._registry[name] | |||
| @staticmethod | |||
| def create(config, *args, **kwargs): | |||
| """ Create generator. """ | |||
| generator_cls = Generator.by_name(config.Generator.generator) | |||
| generator_cls = SpaceGenerator.by_name(config.Generator.generator) | |||
| return generator_cls(config, *args, **kwargs) | |||
| def __init__(self, config, reader): | |||
| @@ -83,7 +83,7 @@ class Generator(object): | |||
| raise NotImplementedError | |||
| class BeamSearch(Generator): | |||
| class BeamSearch(SpaceGenerator): | |||
| """ BeamSearch generator. """ | |||
| def __init__(self, config, reader): | |||
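The rename above keeps the name-based registry pattern intact. A minimal runnable sketch of how that pattern resolves a generator class at runtime (the `'BeamSearch'` key string is an assumption for illustration, not taken from this diff):

```python
# Minimal sketch of the registry pattern above (key string assumed).
class SpaceGenerator(object):
    _registry = dict()

    @classmethod
    def register(cls, name):
        SpaceGenerator._registry[name] = cls

    @staticmethod
    def by_name(name):
        return SpaceGenerator._registry[name]


class BeamSearch(SpaceGenerator):
    pass


BeamSearch.register('BeamSearch')  # typically done at import time
assert SpaceGenerator.by_name('BeamSearch') is BeamSearch
```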
| @@ -41,7 +41,7 @@ class SpaceForDialogModeling(TorchModel): | |||
| self.text_field = kwargs.pop( | |||
| 'text_field', | |||
| MultiWOZBPETextField(self.model_dir, config=self.config)) | |||
| MultiWOZBPETextField(config=self.config, model_dir=self.model_dir)) | |||
| self.generator = SpaceGenerator.create( | |||
| self.config, reader=self.text_field) | |||
| self.model = SpaceModelBase.create( | |||
| @@ -35,7 +35,7 @@ class DialogModelingPreprocessor(Preprocessor): | |||
| self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available() | |||
| self.text_field = MultiWOZBPETextField( | |||
| self.model_dir, config=self.config) | |||
| config=self.config, model_dir=self.model_dir) | |||
| @type_assert(object, Dict) | |||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
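Since `MultiWOZBPETextField.__init__` now takes `(config, **kwargs)` instead of `(model_dir, config)`, call sites have to pass `model_dir` by keyword. A small illustrative stub (not the real class) of why the old positional form would break:

```python
# Illustrative stub: model_dir moved into **kwargs, so it must be passed
# by keyword or it would be bound to the config parameter instead.
class MultiWOZBPETextField:
    def __init__(self, config, **kwargs):
        self.config = config
        self.model_dir = kwargs['model_dir']


field = MultiWOZBPETextField(config={'dummy': True}, model_dir='/path/to/model')
assert field.model_dir == '/path/to/model'
```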
| @@ -2,9 +2,11 @@ | |||
| import os | |||
| import random | |||
| from asyncio import constants | |||
| from collections import OrderedDict | |||
| from itertools import chain | |||
| import json | |||
| import numpy as np | |||
| from modelscope.preprocessors.space.tokenizer import Tokenizer | |||
| @@ -117,7 +119,8 @@ class BPETextField(object): | |||
| return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0] | |||
| def __init__(self, config): | |||
| self.gpu = 0 | |||
| self.train, self.dev, self.test = [], [], [] | |||
| self.gpu = config.Trainer.gpu | |||
| self.tokenizer = None | |||
| self.vocab = None | |||
| self.db = None | |||
| @@ -249,13 +252,9 @@ class BPETextField(object): | |||
| for dial in data: | |||
| batch.append(dial) | |||
| if len(batch) == self.batch_size: | |||
| # print('batch size: %d, batch num +1'%(len(batch))) | |||
| all_batches.append(batch) | |||
| batch = [] | |||
| # if the remainder is larger than 1/2 batch_size, keep it as its own batch; otherwise merge it into the previous batch | |||
| # print('last batch size: %d, batch num +1'%(len(batch))) | |||
| # if (len(batch) % len(cfg.cuda_device)) != 0: | |||
| # batch = batch[:-(len(batch) % len(cfg.cuda_device))] | |||
| # TODO deal with deleted data | |||
| if self.gpu <= 1: | |||
| if len(batch) > 0.5 * self.batch_size: | |||
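A runnable sketch of the remainder rule being kept here, under the assumption (from the surrounding code) that a trailing partial batch survives on its own only when it holds more than half a batch; the standalone function name is hypothetical:

```python
def build_batches(data, batch_size):
    # Hypothetical standalone version of the batching loop above.
    all_batches, batch = [], []
    for dial in data:
        batch.append(dial)
        if len(batch) == batch_size:
            all_batches.append(batch)
            batch = []
    if len(batch) > 0.5 * batch_size:
        all_batches.append(batch)          # large remainder forms its own batch
    elif all_batches and batch:
        all_batches[-1].extend(batch)      # small remainder joins the previous batch
    elif batch:
        all_batches.append(batch)          # no previous batch to join
    return all_batches


assert [len(b) for b in build_batches(list(range(10)), 4)] == [4, 6]
```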
| @@ -308,7 +307,7 @@ class BPETextField(object): | |||
| class MultiWOZBPETextField(BPETextField): | |||
| def __init__(self, model_dir, config): | |||
| def __init__(self, config, **kwargs): | |||
| super(MultiWOZBPETextField, self).__init__(config) | |||
| import spacy | |||
| @@ -327,8 +326,12 @@ class MultiWOZBPETextField(BPETextField): | |||
| ) | |||
| self.nlp = spacy.load('en_core_web_sm') | |||
| if config.do_train: | |||
| db_dir = kwargs['data_dir'] | |||
| else: | |||
| db_dir = kwargs['model_dir'] | |||
| self.db = MultiWozDB( | |||
| model_dir, { | |||
| db_dir, { | |||
| 'attraction': 'db/attraction_db_processed.json', | |||
| 'hospital': 'db/hospital_db_processed.json', | |||
| 'hotel': 'db/hotel_db_processed.json', | |||
| @@ -337,14 +340,14 @@ class MultiWOZBPETextField(BPETextField): | |||
| 'taxi': 'db/taxi_db_processed.json', | |||
| 'train': 'db/train_db_processed.json', | |||
| }) | |||
| self._build_vocab(model_dir) | |||
| self._build_vocab(db_dir) | |||
| special_tokens = [ | |||
| self.pad_token, self.bos_token, self.eos_token, self.unk_token | |||
| ] | |||
| special_tokens.extend(self.add_sepcial_tokens()) | |||
| self.tokenizer = Tokenizer( | |||
| vocab_path=os.path.join(model_dir, ModelFile.VOCAB_FILE), | |||
| vocab_path=os.path.join(kwargs['model_dir'], ModelFile.VOCAB_FILE), | |||
| special_tokens=special_tokens, | |||
| tokenizer_type=config.BPETextField.tokenizer_type) | |||
| self.understand_ids = self.tokenizer.convert_tokens_to_ids( | |||
| @@ -352,6 +355,26 @@ class MultiWOZBPETextField(BPETextField): | |||
| self.policy_ids = self.tokenizer.convert_tokens_to_ids( | |||
| self.policy_tokens) | |||
| if config.do_train: | |||
| test_list = [ | |||
| line.strip().lower() for line in open( | |||
| os.path.join(kwargs['data_dir'], 'testListFile.json'), | |||
| 'r').readlines() | |||
| ] | |||
| dev_list = [ | |||
| line.strip().lower() for line in open( | |||
| os.path.join(kwargs['data_dir'], 'valListFile.json'), | |||
| 'r').readlines() | |||
| ] | |||
| self.dev_files, self.test_files = {}, {} | |||
| for fn in test_list: | |||
| self.test_files[fn.replace('.json', '')] = 1 | |||
| for fn in dev_list: | |||
| self.dev_files[fn.replace('.json', '')] = 1 | |||
| self._load_data(kwargs['data_dir']) | |||
| return | |||
| def get_ids(self, data: str): | |||
| @@ -414,7 +437,6 @@ class MultiWOZBPETextField(BPETextField): | |||
| name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} | |||
| dial = name_to_set[set_name] | |||
| turn_bucket = self._bucket_by_turn(dial) | |||
| # self._shuffle_turn_bucket(turn_bucket) | |||
| all_batches = [] | |||
| if set_name not in self.set_stats: | |||
| @@ -433,19 +455,13 @@ class MultiWOZBPETextField(BPETextField): | |||
| except Exception: | |||
| log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( | |||
| k, len(turn_bucket[k]), len(batches), 0.0) | |||
| # print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches))) | |||
| num_training_steps += k * len(batches) | |||
| num_turns += k * len(turn_bucket[k]) | |||
| num_dials += len(turn_bucket[k]) | |||
| all_batches += batches | |||
| log_str += 'total batch num: %d\n' % len(all_batches) | |||
| # print('total batch num: %d'%len(all_batches)) | |||
| # print('dialog count: %d'%dia_count) | |||
| # return all_batches | |||
| # log stats | |||
| # logging.info(log_str) | |||
| # cfg.num_training_steps = num_training_steps * cfg.epoch_num | |||
| self.set_stats[set_name][ | |||
| 'num_training_steps_per_epoch'] = num_training_steps # turn-level steps | |||
| self.set_stats[set_name]['num_turns'] = num_turns | |||
| @@ -484,6 +500,71 @@ class MultiWOZBPETextField(BPETextField): | |||
| self.vocab.load_vocab(vp) | |||
| return self.vocab.vocab_size | |||
| def _load_data(self, data_dir, save_temp=True): | |||
| """ | |||
| Load the processed data and encode it, or load already-encoded data. | |||
| """ | |||
| def load_data_from_resource(data_resource): | |||
| data = json.loads( | |||
| open( | |||
| os.path.join(data_dir, data_resource), | |||
| 'r', | |||
| encoding='utf-8').read().lower()) | |||
| train, dev, test = [], [], [] | |||
| for fn, dial in data.items(): | |||
| if '.json' in fn: | |||
| fn = fn.replace('.json', '') | |||
| if self.dev_files.get(fn): | |||
| dev.append(self._get_encoded_data(fn, dial)) | |||
| elif self.test_files.get(fn): | |||
| test.append(self._get_encoded_data(fn, dial)) | |||
| else: | |||
| train.append(self._get_encoded_data(fn, dial)) | |||
| return train, dev, test | |||
| data_processed = 'new_db_se_blank_encoded_domain.data.json' | |||
| data_resource = 'data_for_damd.json' | |||
| if save_temp: # save encoded data | |||
| # encoded: no sos, se_encoded: sos and eos | |||
| encoded_file = os.path.join(data_dir, data_processed) | |||
| if os.path.exists(encoded_file): | |||
| logger.info( | |||
| 'Reading encoded data from {}'.format(encoded_file)) | |||
| self.data = json.loads( | |||
| open( | |||
| os.path.join(data_dir, data_resource), | |||
| 'r', | |||
| encoding='utf-8').read().lower()) | |||
| encoded_data = json.loads( | |||
| open(encoded_file, 'r', encoding='utf-8').read()) | |||
| self.train = encoded_data['train'] | |||
| self.dev = encoded_data['dev'] | |||
| self.test = encoded_data['test'] | |||
| else: | |||
| logger.info( | |||
| 'Encoding data now and save the encoded data in {}'.format( | |||
| encoded_file)) | |||
| # not exists, encode data and save | |||
| self.train, self.dev, self.test = load_data_from_resource( | |||
| data_resource) | |||
| # save encoded data | |||
| encoded_data = { | |||
| 'train': self.train, | |||
| 'dev': self.dev, | |||
| 'test': self.test | |||
| } | |||
| json.dump(encoded_data, open(encoded_file, 'w'), indent=2) | |||
| else: # directly read processed data and encode | |||
| self.train, self.dev, self.test = load_data_from_resource( | |||
| data_resource) | |||
| random.seed(10) | |||
| random.shuffle(self.train) | |||
| logger.info('train size:{}, dev size:{}, test size:{}'.format( | |||
| len(self.train), len(self.dev), len(self.test))) | |||
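The `save_temp` branch above implements a simple encode-once cache. A condensed sketch of the same flow, assuming an `encode_fn` callback that produces the `{'train': …, 'dev': …, 'test': …}` dict (the cache file name mirrors the code above):

```python
import json
import os


def load_or_encode(data_dir, encode_fn,
                   cache_name='new_db_se_blank_encoded_domain.data.json'):
    # Reuse the encoded split file if it exists, otherwise encode and save it.
    cache_file = os.path.join(data_dir, cache_name)
    if os.path.exists(cache_file):
        with open(cache_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    encoded = encode_fn()  # expected: {'train': [...], 'dev': [...], 'test': [...]}
    with open(cache_file, 'w') as f:
        json.dump(encoded, f, indent=2)
    return encoded
```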
| def _get_convert_str(self, sent): | |||
| assert isinstance(sent, str) | |||
| return ' '.join([ | |||
| @@ -491,14 +572,65 @@ class MultiWOZBPETextField(BPETextField): | |||
| for tok in sent.split() | |||
| ]) | |||
| def _get_encoded_data(self, fn, dial): | |||
| encoded_dial = [] | |||
| for idx, t in enumerate(dial['log']): # tokenize to list of ids | |||
| enc = {} | |||
| enc['dial_id'] = fn | |||
| enc_info_list = [ | |||
| ('user', self.sos_u_id, 'user', self.eos_u_id), | |||
| ('usdx', self.sos_u_id, 'user', self.eos_u_id), | |||
| ('resp', self.sos_r_id, 'resp', self.eos_r_id), | |||
| ('bspn', self.sos_b_id, 'constraint', self.eos_b_id), | |||
| ('bsdx', self.sos_b_id, 'cons_delex', self.eos_b_id), | |||
| ('aspn', self.sos_a_id, 'sys_act', self.eos_a_id) | |||
| ] | |||
| for enc_key, start_token, item_key, end_token in enc_info_list: | |||
| enc[enc_key] = [ | |||
| start_token | |||
| ] + self.tokenizer.convert_tokens_to_ids( | |||
| self.tokenizer.tokenize( | |||
| self._get_convert_str(t[item_key]))) + [end_token] | |||
| enc['turn_num'] = t['turn_num'] | |||
| if idx > 0 and t['turn_domain'] == '[general]': | |||
| enc['dspn'] = encoded_dial[idx - 1]['dspn'] | |||
| enc['pointer'] = encoded_dial[idx - 1]['pointer'][:4] + [ | |||
| int(i) for i in t['pointer'].split(',') | |||
| ][-2:] | |||
| enc['turn_domain'] = encoded_dial[idx - 1]['turn_domain'] | |||
| enc['db'] = encoded_dial[idx - 1]['db'] | |||
| else: | |||
| if t['turn_domain'] == '[general]': | |||
| assert not t['constraint'], f'{fn}-{idx}' | |||
| enc['dspn'] = [ | |||
| self.sos_d_id | |||
| ] + self.tokenizer.convert_tokens_to_ids( | |||
| self.tokenizer.tokenize( | |||
| self._get_convert_str( | |||
| t['turn_domain']))) + [self.eos_d_id] | |||
| enc['pointer'] = [int(i) for i in t['pointer'].split(',')] | |||
| enc['turn_domain'] = t['turn_domain'].split() | |||
| db_pointer = self.bspan_to_DBpointer(t['constraint'], | |||
| t['turn_domain'].split()) | |||
| enc['db'] = [ | |||
| self.sos_db_id | |||
| ] + self.tokenizer.convert_tokens_to_ids( | |||
| self.tokenizer.tokenize( | |||
| self._get_convert_str(db_pointer))) + [self.eos_db_id] | |||
| encoded_dial.append(enc) | |||
| return encoded_dial | |||
| def bspan_to_DBpointer(self, bspan, turn_domain): | |||
| constraint_dict = self.bspan_to_constraint_dict(bspan) | |||
| # print(constraint_dict) | |||
| matnums = self.db.get_match_num(constraint_dict) | |||
| match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] | |||
| match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom | |||
| match = matnums[match_dom] | |||
| # vector = self.db.addDBPointer(match_dom, match) | |||
| vector = self.db.addDBIndicator(match_dom, match) | |||
| return vector | |||
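A small worked example of the domain selection inside `bspan_to_DBpointer`: a single-domain turn uses its only entry, a multi-domain turn uses the second entry, and surrounding brackets are stripped:

```python
def pick_match_domain(turn_domain):
    # Same selection logic as in bspan_to_DBpointer above.
    match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
    return match_dom[1:-1] if match_dom.startswith('[') else match_dom


assert pick_match_domain(['[hotel]']) == 'hotel'
assert pick_match_domain(['[general]', '[restaurant]']) == 'restaurant'
```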
| @@ -691,3 +823,67 @@ class MultiWOZBPETextField(BPETextField): | |||
| inputs['labels'] = [context] # use previous turn | |||
| return inputs, prompt_id | |||
| def restore(self, resp, domain, constraint_dict, mat_ents): | |||
| restored = resp | |||
| restored = restored.replace('[value_reference]', '53022') | |||
| restored = restored.replace('[value_car]', 'BMW') | |||
| for d in domain: | |||
| constraint = constraint_dict.get(d, None) | |||
| if constraint: | |||
| replace_res_list = [('stay', '[value_stay]'), | |||
| ('day', '[value_day]'), | |||
| ('people', '[value_people]'), | |||
| ('time', '[value_time]'), | |||
| ('type', '[value_type]')] | |||
| for key, value_key in replace_res_list: | |||
| if key in constraint: | |||
| restored = restored.replace(value_key, constraint[key]) | |||
| if d in mat_ents and len(mat_ents[d]) == 0: | |||
| for s in constraint: | |||
| if s == 'pricerange' and d in [ | |||
| 'hotel', 'restaurant' | |||
| ] and 'price]' in restored: | |||
| restored = restored.replace( | |||
| '[value_price]', constraint['pricerange']) | |||
| if s + ']' in restored: | |||
| restored = restored.replace( | |||
| '[value_%s]' % s, constraint[s]) | |||
| if '[value_choice' in restored and mat_ents.get(d): | |||
| restored = restored.replace('[value_choice]', | |||
| str(len(mat_ents[d]))) | |||
| if '[value_choice' in restored: | |||
| restored = restored.replace('[value_choice]', '3') | |||
| try: | |||
| ent = mat_ents.get(domain[-1], []) | |||
| if ent: | |||
| ent = ent[0] | |||
| for t in restored.split(): | |||
| if '[value' in t: | |||
| slot = t[7:-1] | |||
| if ent.get(slot): | |||
| if domain[-1] == 'hotel' and slot == 'price': | |||
| slot = 'pricerange' | |||
| restored = restored.replace(t, ent[slot]) | |||
| elif slot == 'price': | |||
| if ent.get('pricerange'): | |||
| restored = restored.replace( | |||
| t, ent['pricerange']) | |||
| else: | |||
| logger.info(restored, domain) | |||
| except Exception: | |||
| logger.error(resp) | |||
| logger.error(restored) | |||
| quit() | |||
| restored = restored.replace('[value_phone]', '62781111') | |||
| restored = restored.replace('[value_postcode]', 'CG9566') | |||
| restored = restored.replace('[value_address]', 'Parkside, Cambridge') | |||
| return restored | |||
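A tiny worked example in the spirit of `restore` above, with an invented response and constraint: slot placeholders are filled from the constraint dict, and fixed placeholders such as the phone number get hard-coded values. This is a simplification, not the full replacement logic:

```python
# Invented example mirroring the replacement style of restore() above.
resp = 'the [value_type] is [value_pricerange] , phone [value_phone]'
constraint = {'type': 'guest house', 'pricerange': 'cheap'}
for slot, value in constraint.items():
    resp = resp.replace('[value_%s]' % slot, value)
resp = resp.replace('[value_phone]', '62781111')
assert resp == 'the guest house is cheap , phone 62781111'
```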
| @@ -0,0 +1,130 @@ | |||
| import os | |||
| import time | |||
| from typing import Callable, Dict, Optional, Tuple, Union | |||
| import numpy as np | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.nlp.space.model.generator import SpaceGenerator | |||
| from modelscope.models.nlp.space.model.model_base import SpaceModelBase | |||
| from modelscope.preprocessors.space.fields.gen_field import \ | |||
| MultiWOZBPETextField | |||
| from modelscope.trainers.base import BaseTrainer | |||
| from modelscope.trainers.builder import TRAINERS | |||
| from modelscope.trainers.nlp.space.eval import MultiWOZEvaluator | |||
| from modelscope.trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer | |||
| from modelscope.utils.config import Config, ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| def setup_seed(seed: int): | |||
| import random | |||
| import torch | |||
| torch.manual_seed(seed) | |||
| torch.cuda.manual_seed_all(seed) | |||
| np.random.seed(seed) | |||
| random.seed(seed) | |||
| torch.backends.cudnn.deterministic = True | |||
| @TRAINERS.register_module(module_name=Trainers.dialog_modeling_trainer) | |||
| class DialogModelingTrainer(BaseTrainer): | |||
| def __init__(self, | |||
| cfg_file: Optional[str] = None, | |||
| cfg_modify_fn: Optional[Callable] = None, | |||
| *args, | |||
| **kwargs): | |||
| super().__init__(os.path.join(kwargs['model_dir'], kwargs['cfg_name'])) | |||
| self.cfg_modify_fn = cfg_modify_fn | |||
| self.cfg = self.rebuild_config(self.cfg) | |||
| setup_seed(self.cfg.Trainer.seed) | |||
| # set reader and evaluator | |||
| self.bpe = MultiWOZBPETextField(self.cfg, **kwargs) | |||
| self.cfg.Model.num_token_embeddings = self.bpe.vocab_size | |||
| self.cfg.Model.num_turn_embeddings = self.bpe.max_ctx_turn + 1 | |||
| if 'work_dir' in kwargs: | |||
| self.cfg.Trainer.save_dir = kwargs['work_dir'] | |||
| else: | |||
| self.cfg.Trainer.save_dir = './default_save_dir' | |||
| # set data and data status | |||
| self.train_data = self.bpe.get_batches('train') | |||
| self.dev_data = self.bpe.get_batches('dev') | |||
| self.evaluator = MultiWOZEvaluator(reader=self.bpe, **kwargs) | |||
| # set generator | |||
| self.generator = SpaceGenerator.create(self.cfg, reader=self.bpe) | |||
| self._load_model(**kwargs) | |||
| def _load_model(self, **kwargs): | |||
| def to_tensor(array): | |||
| """ | |||
| numpy array -> tensor | |||
| """ | |||
| import torch | |||
| array = torch.tensor(array) | |||
| return array.cuda( | |||
| ) if self.cfg.use_gpu and torch.cuda.is_available() else array | |||
| # construct model | |||
| if 'model' in kwargs: | |||
| self.model = kwargs['model'] | |||
| else: | |||
| self.model = SpaceModelBase.create( | |||
| kwargs['model_dir'], | |||
| self.cfg, | |||
| reader=self.bpe, | |||
| generator=self.generator) | |||
| import torch | |||
| # multi-gpu | |||
| if self.cfg.Trainer.gpu > 1 and torch.cuda.device_count() > 1: | |||
| self.model = torch.nn.DataParallel(self.model) | |||
| # construct trainer | |||
| self.trainer = MultiWOZTrainer( | |||
| self.model, | |||
| to_tensor, | |||
| self.cfg, | |||
| reader=self.bpe, | |||
| evaluator=self.evaluator) | |||
| self.trainer.set_optimizers() | |||
| # load model, optimizer and lr_scheduler | |||
| self.trainer.load() | |||
| def rebuild_config(self, cfg: Config): | |||
| if self.cfg_modify_fn is not None: | |||
| return self.cfg_modify_fn(cfg) | |||
| return cfg | |||
| def train(self, *args, **kwargs): | |||
| logger.info('Train') | |||
| self.trainer.train(train_data=self.train_data, dev_data=self.dev_data) | |||
| def evaluate(self, | |||
| checkpoint_path: Optional[str] = None, | |||
| *args, | |||
| **kwargs) -> Dict[str, float]: | |||
| logger.info('Evaluate') | |||
| self.cfg.do_infer = True | |||
| # get best checkpoint path | |||
| pos = checkpoint_path.rfind('/') | |||
| checkpoint_name = checkpoint_path[pos + 1:] | |||
| checkpoint_dir = checkpoint_path[:pos] | |||
| assert checkpoint_name == ModelFile.TORCH_MODEL_BIN_FILE | |||
| kwargs['model_dir'] = checkpoint_dir | |||
| self._load_model(**kwargs) | |||
| self.trainer.infer(data_type='test') | |||
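A hypothetical end-to-end usage of this trainer; the directory names and `cfg_name` below are placeholders, and `evaluate` asserts the checkpoint file name equals `ModelFile.TORCH_MODEL_BIN_FILE` (assumed here to be `pytorch_model.bin`):

```python
# Hypothetical usage sketch; all paths are placeholders, not from the diff.
trainer = DialogModelingTrainer(
    model_dir='./space_dialog_modeling',   # contains the config and vocab
    cfg_name='configuration.json',         # assumed config file name
    data_dir='./multiwoz2.0',              # raw MultiWOZ data for training
    work_dir='./work_dir')                 # becomes cfg.Trainer.save_dir
trainer.train()
trainer.evaluate(checkpoint_path='./work_dir/pytorch_model.bin')
```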
| @@ -0,0 +1,952 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
| # Copyright from https://github.com/thu-spmi/LABES | |||
| # Copyright from https://github.com/TonyNemo/UBAR-MultiWOZ | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import math | |||
| from collections import Counter | |||
| import json | |||
| import numpy as np | |||
| from nltk.util import ngrams | |||
| from sklearn.metrics import f1_score | |||
| from modelscope.utils.nlp.space import ontology, utils | |||
| from modelscope.utils.nlp.space.clean_dataset import clean_slot_values | |||
| def similar(a, b): | |||
| return a == b or a in b or b in a or a.split()[0] == b.split( | |||
| )[0] or a.split()[-1] == b.split()[-1] | |||
| def setsub(a, b): | |||
| junks_a = [] | |||
| useless_constraint = [ | |||
| 'temperature', 'week', 'est ', 'quick', 'reminder', 'near' | |||
| ] | |||
| for i in a: | |||
| flg = False | |||
| for j in b: | |||
| if similar(i, j): | |||
| flg = True | |||
| if not flg: | |||
| junks_a.append(i) | |||
| for junk in junks_a: | |||
| flg = False | |||
| for item in useless_constraint: | |||
| if item in junk: | |||
| flg = True | |||
| if not flg: | |||
| return False | |||
| return True | |||
| def setsim(a, b): | |||
| a, b = set(a), set(b) | |||
| return setsub(a, b) and setsub(b, a) | |||
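A worked example of the fuzzy matching above, assuming the `similar`/`setsub`/`setsim` helpers as defined: two constraint values count as similar if one contains the other or they share a first or last token, so the two sets below are judged equivalent:

```python
assert similar('cheap', 'cheap pricerange')   # substring match
assert similar('north side', 'river side')    # shared last token
assert setsim(['cheap', 'centre'], ['cheap pricerange', 'centre'])
```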
| def DA_evaluate(preds, labels): | |||
| preds = np.array(preds) | |||
| labels = np.array(labels) | |||
| results = {} | |||
| for avg_name in ['micro']: | |||
| my_f1_score = f1_score(y_true=labels, y_pred=preds, average=avg_name) | |||
| results['f1_{}'.format(avg_name)] = my_f1_score | |||
| return results | |||
| class BLEUScorer(object): | |||
| # BLEU score calculator via GentScorer interface | |||
| # it calculates BLEU-4 over the entire corpus at once, | |||
| # based on multiple candidates against multiple references | |||
| def __init__(self): | |||
| pass | |||
| def score(self, parallel_corpus): | |||
| # containers | |||
| count = [0, 0, 0, 0] | |||
| clip_count = [0, 0, 0, 0] | |||
| r = 0 | |||
| c = 0 | |||
| weights = [0.25, 0.25, 0.25, 0.25] | |||
| # accumulate ngram statistics | |||
| for hyps, refs in parallel_corpus: | |||
| hyps = [hyp.split() for hyp in hyps] | |||
| refs = [ref.split() for ref in refs] | |||
| for hyp in hyps: | |||
| for i in range(4): | |||
| # accumulate ngram counts | |||
| hypcnts = Counter(ngrams(hyp, i + 1)) | |||
| cnt = sum(hypcnts.values()) | |||
| count[i] += cnt | |||
| # compute clipped counts | |||
| max_counts = {} | |||
| for ref in refs: | |||
| refcnts = Counter(ngrams(ref, i + 1)) | |||
| for ng in hypcnts: | |||
| max_counts[ng] = max( | |||
| max_counts.get(ng, 0), refcnts[ng]) | |||
| clipcnt = \ | |||
| dict((ng, min(count, max_counts[ng])) for ng, count in hypcnts.items()) | |||
| clip_count[i] += sum(clipcnt.values()) | |||
| # accumulate r & c | |||
| bestmatch = [1000, 1000] | |||
| for ref in refs: | |||
| if bestmatch[0] == 0: | |||
| break | |||
| diff = abs(len(ref) - len(hyp)) | |||
| if diff < bestmatch[0]: | |||
| bestmatch[0] = diff | |||
| bestmatch[1] = len(ref) | |||
| r += bestmatch[1] | |||
| c += len(hyp) | |||
| # computing bleu score | |||
| p0 = 1e-7 | |||
| bp = \ | |||
| 1 if c > r else math.exp(1 - float(r) / float(c)) | |||
| p_ns = \ | |||
| [float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)] | |||
| s = \ | |||
| math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n) | |||
| bleu = bp * math.exp(s) | |||
| return bleu * 100 | |||
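A usage sketch for the scorer above: `parallel_corpus` is an iterable of `(hypotheses, references)` pairs, each a list of pre-tokenized (whitespace-splittable) strings, and the result is corpus-level BLEU-4 scaled to 0-100:

```python
scorer = BLEUScorer()
corpus = [
    (['the hotel is in the centre of town'],
     ['the hotel is located in the centre of town']),
]
bleu = scorer.score(corpus)
print('BLEU-4: %.2f' % bleu)
```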
| """" | |||
| For the data preparation and evaluation on MultiWOZ2.0/2.1, | |||
| we refer to the code of UBAR (https://github.com/TonyNemo/UBAR-MultiWOZ) | |||
| """ | |||
| class MultiWOZEvaluator(object): | |||
| def __init__(self, reader, **kwargs): | |||
| self.reader = reader | |||
| self.domains = ontology.all_domains | |||
| self.all_data = self.reader.data | |||
| self.test_data = self.reader.test | |||
| self.bleu_scorer = BLEUScorer() | |||
| self.all_info_slot = [] | |||
| for d, s_list in ontology.informable_slots.items(): | |||
| for s in s_list: | |||
| self.all_info_slot.append(d + '-' + s) | |||
| # only evaluate these slots for dialog success | |||
| self.requestables = ['phone', 'address', 'postcode', 'reference', 'id'] | |||
| self.db_dir = kwargs['data_dir'] | |||
| def pack_dial(self, data): | |||
| dials = {} | |||
| for turn in data: | |||
| dial_id = turn['dial_id'] | |||
| if dial_id not in dials: | |||
| dials[dial_id] = [] | |||
| dials[dial_id].append(turn) | |||
| return dials | |||
| def validation_metric(self, data, fout=None): | |||
| bleu = self.bleu_metric(data) | |||
| # accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data) | |||
| success, match, req_offer_counts, dial_num = \ | |||
| self.context_to_response_eval(data, same_eval_as_cambridge=True, fout=fout) | |||
| return bleu, success, match | |||
| def bleu_metric(self, data, eval_dial_list=None): | |||
| gen, truth = [], [] | |||
| for row in data: | |||
| if eval_dial_list and row[ | |||
| 'dial_id'] + '.json' not in eval_dial_list: | |||
| continue | |||
| gen.append(row['resp_gen']) | |||
| truth.append(row['resp']) | |||
| wrap_generated = [[_] for _ in gen] | |||
| wrap_truth = [[_] for _ in truth] | |||
| if gen and truth: | |||
| try: | |||
| sc = self.bleu_scorer.score(zip(wrap_generated, wrap_truth)) | |||
| except Exception: | |||
| sc = 0.0 | |||
| else: | |||
| sc = 0.0 | |||
| return sc | |||
| def context_to_response_eval(self, | |||
| data, | |||
| eval_dial_list=None, | |||
| same_eval_as_cambridge=False, | |||
| fout=None): | |||
| dials = self.pack_dial(data) | |||
| counts = {} | |||
| for req in self.requestables: | |||
| counts[req + '_total'] = 0 | |||
| counts[req + '_offer'] = 0 | |||
| dial_num, successes, matches = 0, 0, 0 | |||
| for dial_id in dials: | |||
| if eval_dial_list and dial_id + '.json' not in eval_dial_list: | |||
| continue | |||
| dial = dials[dial_id] | |||
| reqs = {} | |||
| goal = {} | |||
| if '.json' not in dial_id and '.json' in list( | |||
| self.all_data.keys())[0]: | |||
| dial_id = dial_id + '.json' | |||
| for domain in ontology.all_domains: | |||
| if self.all_data[dial_id]['goal'].get(domain): | |||
| true_goal = self.all_data[dial_id]['goal'] | |||
| goal = self._parseGoal(goal, true_goal, domain) | |||
| for domain in goal.keys(): | |||
| reqs[domain] = goal[domain]['requestable'] | |||
| success, match, stats, counts = \ | |||
| self._evaluateGeneratedDialogue(dial, goal, reqs, counts, | |||
| same_eval_as_cambridge=same_eval_as_cambridge, fout=fout) | |||
| successes += success | |||
| matches += match | |||
| dial_num += 1 | |||
| succ_rate = successes / (float(dial_num) + 1e-10) * 100 | |||
| match_rate = matches / (float(dial_num) + 1e-10) * 100 | |||
| return succ_rate, match_rate, counts, dial_num | |||
| def _evaluateGeneratedDialogue(self, | |||
| dialog, | |||
| goal, | |||
| real_requestables, | |||
| counts, | |||
| soft_acc=False, | |||
| same_eval_as_cambridge=False, | |||
| fout=None): | |||
| """Evaluates the dialogue created by the model. | |||
| First we load the user goal of the dialogue, then for each turn | |||
| generated by the system we look for key-words. | |||
| For the Inform rate we check whether the right entity was proposed. | |||
| For the Success rate we look for requestable slots.""" | |||
| # for computing corpus success | |||
| requestables = self.requestables | |||
| # CHECK IF MATCH HAPPENED | |||
| provided_requestables = {} | |||
| venue_offered = {} | |||
| domains_in_goal = [] | |||
| log = [] | |||
| bspans = {} | |||
| for domain in goal.keys(): | |||
| venue_offered[domain] = [] | |||
| provided_requestables[domain] = [] | |||
| domains_in_goal.append(domain) | |||
| for t, turn in enumerate(dialog): | |||
| if t == 0: | |||
| continue | |||
| if fout is not None: | |||
| log.append({ | |||
| 'turn_num': turn['turn_num'], | |||
| 'turn_domain': turn['dspn'], | |||
| 'user': turn['user'], | |||
| 'aspn': turn['aspn'], | |||
| 'aspn_gen': turn['aspn_gen'], | |||
| 'resp': turn['resp'], | |||
| 'resp_gen': turn['resp_gen'], | |||
| 'pointer': turn['pointer'], | |||
| }) | |||
| sent_t = turn['resp_gen'] | |||
| for domain in goal.keys(): | |||
| # for computing success | |||
| if same_eval_as_cambridge: | |||
| # [restaurant_name], [hotel_name] instead of [value_name] | |||
| if self.reader.use_true_domain_for_ctr_eval: | |||
| dom_pred = [d[1:-1] for d in turn['dspn'].split()] | |||
| else: | |||
| dom_pred = [d[1:-1] for d in turn['dspn_gen'].split()] | |||
| if domain not in dom_pred: # fail | |||
| continue | |||
| if '[value_name]' in sent_t or '[value_id]' in sent_t: | |||
| if domain in [ | |||
| 'restaurant', 'hotel', 'attraction', 'train' | |||
| ]: | |||
| # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION | |||
| if not self.reader.use_true_curr_bspn and not self.reader.use_true_bspn_for_ctr_eval: | |||
| bspn = turn['bspn_gen'] | |||
| else: | |||
| bspn = turn['bspn'] | |||
| constraint_dict = self.reader.bspan_to_constraint_dict( | |||
| bspn) | |||
| if constraint_dict.get(domain): | |||
| venues = self.reader.db.queryJsons( | |||
| domain, | |||
| constraint_dict[domain], | |||
| return_name=True) | |||
| else: | |||
| venues = [] | |||
| if len(venue_offered[domain]) == 0 and venues: | |||
| venue_offered[domain] = venues | |||
| bspans[domain] = constraint_dict[domain] | |||
| else: | |||
| flag = False | |||
| for ven in venues: | |||
| if ven not in venue_offered[domain]: | |||
| flag = True | |||
| break | |||
| if flag and venues: # sometimes there are no results so sample won't work | |||
| venue_offered[domain] = venues | |||
| bspans[domain] = constraint_dict[domain] | |||
| else: # not limited so we can provide one | |||
| venue_offered[domain] = '[value_name]' | |||
| # ATTENTION: assumption here - we don't provide the phone or address twice, etc. | |||
| for requestable in requestables: | |||
| if requestable == 'reference': | |||
| if '[value_reference]' in sent_t: | |||
| if domain in ['restaurant', 'hotel', 'train']: | |||
| if 'booked' in turn['pointer'] or 'ok' in turn[ | |||
| 'pointer'] or '[value_reference]' in turn[ | |||
| 'resp']: | |||
| # if pointer was allowing for that? | |||
| provided_requestables[domain].append( | |||
| 'reference') | |||
| else: | |||
| provided_requestables[domain].append( | |||
| 'reference') | |||
| else: | |||
| if '[value_' + requestable + ']' in sent_t: | |||
| provided_requestables[domain].append(requestable) | |||
| # if name was given in the task | |||
| for domain in goal.keys(): | |||
| # if name was provided for the user, the match is being done automatically | |||
| if 'name' in goal[domain]['informable']: | |||
| venue_offered[domain] = '[value_name]' | |||
| # special domains - entity does not need to be provided | |||
| if domain in ['taxi', 'police', 'hospital']: | |||
| venue_offered[domain] = '[value_name]' | |||
| if domain == 'train': | |||
| if not venue_offered[domain] and 'id' not in goal[domain][ | |||
| 'requestable']: | |||
| venue_offered[domain] = '[value_name]' | |||
| """ | |||
| Given all inform and requestable slots | |||
| we go through each domain from the user goal | |||
| and check whether the right entity was provided and | |||
| all requestable slots were given to the user. | |||
| The dialogue is successful if that's the case for all domains. | |||
| """ | |||
| # HARD EVAL | |||
| stats = { | |||
| 'restaurant': [0, 0, 0], | |||
| 'hotel': [0, 0, 0], | |||
| 'attraction': [0, 0, 0], | |||
| 'train': [0, 0, 0], | |||
| 'taxi': [0, 0, 0], | |||
| 'hospital': [0, 0, 0], | |||
| 'police': [0, 0, 0] | |||
| } | |||
| match = 0 | |||
| success = 0 | |||
| # MATCH | |||
| for domain in goal.keys(): | |||
| match_stat = 0 | |||
| if domain in ['restaurant', 'hotel', 'attraction', 'train']: | |||
| goal_venues = self.reader.db.queryJsons( | |||
| domain, goal[domain]['informable'], return_name=True) | |||
| if type(venue_offered[domain] | |||
| ) is str and '_name' in venue_offered[domain]: | |||
| match += 1 | |||
| match_stat = 1 | |||
| elif len(venue_offered[domain]) > 0 and len( | |||
| set(venue_offered[domain]) & set(goal_venues)) > 0: | |||
| match += 1 | |||
| match_stat = 1 | |||
| else: | |||
| if '_name]' in venue_offered[domain]: | |||
| match += 1 | |||
| match_stat = 1 | |||
| stats[domain][0] = match_stat | |||
| stats[domain][2] = 1 | |||
| if soft_acc: | |||
| match = float(match) / len(goal.keys()) | |||
| else: | |||
| if match == len(goal.keys()): | |||
| match = 1.0 | |||
| else: | |||
| match = 0.0 | |||
| for domain in domains_in_goal: | |||
| for request in real_requestables[domain]: | |||
| counts[request + '_total'] += 1 | |||
| if request in provided_requestables[domain]: | |||
| counts[request + '_offer'] += 1 | |||
| # SUCCESS | |||
| if fout is not None: | |||
| for domain in domains_in_goal: | |||
| success_stat = 0 | |||
| domain_success = 0 | |||
| if len(real_requestables[domain]) == 0: | |||
| success += 1 | |||
| success_stat = 1 | |||
| stats[domain][1] = success_stat | |||
| continue | |||
| # if the values in the responses are a superset of the requestables | |||
| for request in real_requestables[domain]: | |||
| if request in provided_requestables[domain]: | |||
| domain_success += 1 | |||
| if domain_success == len(real_requestables[domain]): | |||
| success += 1 | |||
| success_stat = 1 | |||
| stats[domain][1] = success_stat | |||
| # final eval | |||
| if soft_acc: | |||
| success = float(success) / len(real_requestables) | |||
| else: | |||
| if success >= len(real_requestables): | |||
| success = 1 | |||
| else: | |||
| success = 0 | |||
| else: | |||
| if match == 1.0: | |||
| for domain in domains_in_goal: | |||
| success_stat = 0 | |||
| domain_success = 0 | |||
| if len(real_requestables[domain]) == 0: | |||
| success += 1 | |||
| success_stat = 1 | |||
| stats[domain][1] = success_stat | |||
| continue | |||
| # if the values in the responses are a superset of the requestables | |||
| for request in real_requestables[domain]: | |||
| if request in provided_requestables[domain]: | |||
| domain_success += 1 | |||
| if domain_success == len(real_requestables[domain]): | |||
| success += 1 | |||
| success_stat = 1 | |||
| stats[domain][1] = success_stat | |||
| # final eval | |||
| if soft_acc: | |||
| success = float(success) / len(real_requestables) | |||
| else: | |||
| if success >= len(real_requestables): | |||
| success = 1 | |||
| else: | |||
| success = 0 | |||
| if fout is not None and success == 0: | |||
| sample = { | |||
| dialog[0]['dial_id']: { | |||
| 'log': log, | |||
| 'real_requestables': real_requestables, | |||
| 'provided_requestables': provided_requestables | |||
| } | |||
| } | |||
| line = json.dumps(sample) | |||
| fout.write(line) | |||
| fout.write('\n') | |||
| return success, match, stats, counts | |||
| def _parseGoal(self, goal, true_goal, domain): | |||
| """Parses user goal into dictionary format.""" | |||
| goal[domain] = {} | |||
| goal[domain] = {'informable': {}, 'requestable': [], 'booking': []} | |||
| if 'info' in true_goal[domain]: | |||
| if domain == 'train': | |||
| # we consider dialogues only where train had to be booked! | |||
| if 'book' in true_goal[domain]: | |||
| goal[domain]['requestable'].append('reference') | |||
| if 'reqt' in true_goal[domain]: | |||
| if 'id' in true_goal[domain]['reqt']: | |||
| goal[domain]['requestable'].append('id') | |||
| else: | |||
| if 'reqt' in true_goal[domain]: | |||
| for s in true_goal[domain]['reqt']: # additional requests: | |||
| if s in [ | |||
| 'phone', 'address', 'postcode', 'reference', | |||
| 'id' | |||
| ]: | |||
| # ones that can be easily delexicalized | |||
| goal[domain]['requestable'].append(s) | |||
| if 'book' in true_goal[domain]: | |||
| goal[domain]['requestable'].append('reference') | |||
| for s, v in true_goal[domain]['info'].items(): | |||
| s_, v_ = clean_slot_values(self.db_dir, domain, s, v) | |||
| if len(v_.split()) > 1: | |||
| v_ = ' '.join( | |||
| [token.text for token in self.reader.nlp(v_)]).strip() | |||
| goal[domain]['informable'][s_] = v_ | |||
| if 'book' in true_goal[domain]: | |||
| goal[domain]['booking'] = true_goal[domain]['book'] | |||
| return goal | |||
| class GenericEvaluator: | |||
| def __init__(self, reader): | |||
| self.reader = reader | |||
| self.metric_dict = {} | |||
| def pack_dial(self, data): | |||
| dials = {} | |||
| for turn in data: | |||
| dial_id = turn['dial_id'] | |||
| if dial_id not in dials: | |||
| dials[dial_id] = [] | |||
| dials[dial_id].append(turn) | |||
| return dials | |||
| def run_metrics(self, results): | |||
| raise ValueError('Please specify the evaluator first') | |||
| def bleu_metric(self, data, type='bleu'): | |||
| gen, truth = [], [] | |||
| for row in data: | |||
| gen.append(self.clean(row['resp_gen'])) | |||
| # gen.append(self.clean(row['resp'])) | |||
| truth.append(self.clean(row['resp'])) | |||
| wrap_generated = [[_] for _ in gen] | |||
| wrap_truth = [[_] for _ in truth] | |||
| sc = BLEUScorer().score(zip(wrap_generated, wrap_truth)) | |||
| return sc | |||
| def _normalize_constraint(self, | |||
| constraint, | |||
| ignore_dontcare=False, | |||
| intersection=True): | |||
| """ | |||
| Normalize belief span, e.g. delete repeated words | |||
| :param constraint - {'food': 'asian oriental', 'pricerange': 'cheap'} | |||
| :param intersection: if true, only keep the values that appear in the ontology; | |||
| we set intersection=True as in previous works | |||
| :returns: normalized constraint dict | |||
| e.g. - {'food': 'asian oriental', 'pricerange': 'cheap', 'area': ''} | |||
| """ | |||
| normalized = {} | |||
| for s in self.informable_slots: | |||
| normalized[s] = '' | |||
| for s, v in constraint.items(): | |||
| if ignore_dontcare and v == 'dontcare': | |||
| continue | |||
| if intersection and v != 'dontcare' and v not in self.entities_flat: | |||
| continue | |||
| normalized[s] = v | |||
| return normalized | |||
| def _normalize_act(self, aspn, intersection=False): | |||
| aspn_list = aspn.split('|') | |||
| normalized = {} | |||
| for i, v in enumerate(aspn_list): | |||
| seq = v.strip() | |||
| word_set = set() | |||
| for w in seq.split(): | |||
| if intersection: | |||
| if self.reader.act_order[i] == 'av': | |||
| if '[value' in w: | |||
| word_set.add(w) | |||
| else: | |||
| if w in self.requestable_slots: | |||
| word_set.add(w) | |||
| else: | |||
| word_set.add(w) | |||
| normalized[self.reader.act_order[i]] = word_set | |||
| return normalized | |||
| def tracker_metric(self, data, normalize=True): | |||
| # turn level metric | |||
| tp, fp, fn, db_correct = 0, 0, 0, 0 | |||
| goal_accr, slot_accr, total = 0, {}, 1e-8 | |||
| for s in self.informable_slots: | |||
| slot_accr[s] = 0 | |||
| for row in data: | |||
| if normalize: | |||
| gen = self._normalize_constraint(row['bspn_gen']) | |||
| truth = self._normalize_constraint(row['bspn']) | |||
| else: | |||
| gen = self._normalize_constraint( | |||
| row['bspn_gen'], intersection=False) | |||
| truth = self._normalize_constraint( | |||
| row['bspn'], intersection=False) | |||
| valid = 'thank' not in row['user'] and 'bye' not in row['user'] | |||
| if valid: | |||
| for slot, value in gen.items(): | |||
| if value in truth[slot]: | |||
| tp += 1 | |||
| else: | |||
| fp += 1 | |||
| for slot, value in truth.items(): | |||
| if value not in gen[slot]: | |||
| fn += 1 | |||
| if truth and valid: | |||
| total += 1 | |||
| for s in self.informable_slots: | |||
| if gen[s] == truth[s]: | |||
| slot_accr[s] += 1 | |||
| if gen == truth: | |||
| goal_accr += 1 | |||
| if row.get('db_gen') and row.get('db_match'): | |||
| if row['db_gen'] == row['db_match']: | |||
| db_correct += 1 | |||
| precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) | |||
| f1 = 2 * precision * recall / (precision + recall + 1e-8) | |||
| goal_accr /= total | |||
| db_correct /= total | |||
| for s in slot_accr: | |||
| slot_accr[s] /= total | |||
| return precision, recall, f1, goal_accr, slot_accr, db_correct | |||
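The joint goal accuracy computed above is all-or-nothing per turn. A tiny illustration with invented slot values:

```python
# A turn counts toward joint goal accuracy only if every informable slot
# matches exactly (values invented for illustration).
gen = {'food': 'chinese', 'area': 'centre', 'pricerange': ''}
truth = {'food': 'chinese', 'area': 'centre', 'pricerange': ''}
assert gen == truth        # goal_accr += 1 for this turn
truth['area'] = 'north'
assert gen != truth        # one wrong slot fails the whole turn
```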
| def request_metric(self, data): | |||
| # dialog level metric | |||
| dials = self.pack_dial(data) | |||
| tp, fp, fn = 0, 0, 0 | |||
| for dial_id in dials: | |||
| truth_req, gen_req = set(), set() | |||
| dial = dials[dial_id] | |||
| for turn_num, turn in enumerate(dial): | |||
| resp_gen_token = self.clean(turn['resp_gen']).split() | |||
| resp_token = self.clean(turn['resp']).split() | |||
| for w in resp_gen_token: | |||
| if '[value_' in w and w.endswith( | |||
| ']') and w != '[value_name]': | |||
| gen_req.add(w[1:-1].split('_')[1]) | |||
| for w in resp_token: | |||
| if '[value_' in w and w.endswith( | |||
| ']') and w != '[value_name]': | |||
| truth_req.add(w[1:-1].split('_')[1]) | |||
| for req in gen_req: | |||
| if req in truth_req: | |||
| tp += 1 | |||
| else: | |||
| fp += 1 | |||
| for req in truth_req: | |||
| if req not in gen_req: | |||
| fn += 1 | |||
| precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) | |||
| f1 = 2 * precision * recall / (precision + recall + 1e-8) | |||
| return f1, precision, recall | |||
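A small check of the token filter used in `request_metric`: delexicalized value tokens are parsed back to slot names, with `[value_name]` deliberately excluded:

```python
def extract_req_slot(w):
    # Same condition as in request_metric above.
    if '[value_' in w and w.endswith(']') and w != '[value_name]':
        return w[1:-1].split('_')[1]
    return None


assert extract_req_slot('[value_phone]') == 'phone'
assert extract_req_slot('[value_name]') is None
assert extract_req_slot('hello') is None
```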
| def act_metric(self, data): | |||
| # turn level metric | |||
| tp, fp, fn = { | |||
| 'all_s': 0, | |||
| 'all_v': 0 | |||
| }, { | |||
| 'all_s': 0, | |||
| 'all_v': 0 | |||
| }, { | |||
| 'all_s': 0, | |||
| 'all_v': 0 | |||
| } | |||
| for s in self.requestable_slots: | |||
| tp[s], fp[s], fn[s] = 0, 0, 0 | |||
| tp['[value_%s]' % s], fp['[value_%s]' % s], fn['[value_%s]' | |||
| % s] = 0, 0, 0 | |||
| for row in data: | |||
| gen = self._normalize_act(row['aspn_gen']) | |||
| truth = self._normalize_act(row['aspn']) | |||
| valid = 'thank' not in row['user'] and 'bye' not in row['user'] | |||
| if valid: | |||
| # how well the act decoder captures user's requests | |||
| for value in gen['av']: | |||
| if value in truth['av']: | |||
| tp['all_v'] += 1 | |||
| if tp.get(value): | |||
| tp[value] += 1 | |||
| else: | |||
| fp['all_v'] += 1 | |||
| if fp.get(value): | |||
| fp[value] += 1 | |||
| for value in truth['av']: | |||
| if value not in gen['av']: | |||
| fn['all_v'] += 1 | |||
| if fn.get(value): | |||
| fn[value] += 1 | |||
| # how accurately the act decoder predicts system's question | |||
| if 'as' not in gen: | |||
| continue | |||
| for slot in gen['as']: | |||
| if slot in truth['as']: | |||
| tp['all_s'] += 1 | |||
| if tp.get(slot): | |||
| tp[slot] += 1 | |||
| else: | |||
| fp['all_s'] += 1 | |||
| if fp.get(slot): | |||
| fp[slot] += 1 | |||
| for slot in truth['as']: | |||
| if slot not in gen['as']: | |||
| fn['all_s'] += 1 | |||
| if fn.get(slot): | |||
| fn[slot] += 1 | |||
| result = {} | |||
| for k, v in tp.items(): | |||
| precision, recall = tp[k] / (tp[k] + fp[k] + 1e-8), tp[k] / ( | |||
| tp[k] + fn[k] + 1e-8) | |||
| f1 = 2 * precision * recall / (precision + recall + 1e-8) | |||
| result[k] = [f1, precision, recall] | |||
| return result | |||
| """ | |||
| For the data preparation and evaluation on In-Car Assistant/CamRest, | |||
| we refer to the code of LABES (https://github.com/thu-spmi/LABES) | |||
| """ | |||
| class CamRestEvaluator(GenericEvaluator): | |||
| def __init__(self, reader): | |||
| super().__init__(reader) | |||
| self.entities_flat, self.entitiy_to_slot_dict = self.get_entities( | |||
| self.reader.ontology_path) | |||
| self.informable_slots = self.reader.otlg.informable_slots | |||
| self.requestable_slots = self.reader.otlg.requestable_slots | |||
| def run_metrics(self, results): | |||
| metrics = {} | |||
| bleu = self.bleu_metric(results) | |||
| p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(results) | |||
| match = self.match_metric(results) | |||
| req_f1, req_p, req_r = self.request_metric(results) | |||
| metrics['bleu'] = bleu | |||
| metrics['match'] = match | |||
| metrics['req_f1'] = req_f1 | |||
| metrics['joint_goal'] = goal_acc | |||
| metrics['slot_accu'] = slot_acc | |||
| metrics['slot-p/r/f1'] = (p, r, f1) | |||
| metrics['db_acc'] = db_acc | |||
| return metrics | |||
| def get_entities(self, entity_path): | |||
| entities_flat = [] | |||
| entitiy_to_slot_dict = {} | |||
| raw_entities = json.loads(open(entity_path).read().lower()) | |||
| for s in raw_entities['informable']: | |||
| entities_flat.extend(raw_entities['informable'][s]) | |||
| for v in raw_entities['informable'][s]: | |||
| entitiy_to_slot_dict[v] = s | |||
| return entities_flat, entitiy_to_slot_dict | |||
| def constraint_same(self, truth_cons, gen_cons): | |||
| if not truth_cons and not gen_cons: | |||
| return True | |||
| if not truth_cons or not gen_cons: | |||
| return False | |||
| return setsim(gen_cons, truth_cons) | |||
| def match_metric(self, data): | |||
| dials = self.pack_dial(data) | |||
| match, total = 0, 1e-8 | |||
| for dial_id in dials: | |||
| dial = dials[dial_id] | |||
| truth_cons, gen_cons = {'1': '', '2': '', '3': ''}, None | |||
| for turn_num, turn in enumerate(dial): | |||
| # find the last turn in which the system provides an entity | |||
| if '[value' in turn['resp_gen']: | |||
| gen_cons = self._normalize_constraint( | |||
| turn['bspn_gen'], ignore_dontcare=True) | |||
| if '[value' in turn['resp']: | |||
| truth_cons = self._normalize_constraint( | |||
| turn['bspn'], ignore_dontcare=True) | |||
| if not gen_cons: | |||
| # if no entity is provided, choose the state of the last dialog turn | |||
| gen_cons = self._normalize_constraint( | |||
| dial[-1]['bspn_gen'], ignore_dontcare=True) | |||
| if list(truth_cons.values()) != ['', '', '']: | |||
| if gen_cons == truth_cons: | |||
| match += 1 | |||
| total += 1 | |||
| return match / total | |||
| def clean(self, resp): | |||
| # we use the same clean process as in Sequicity, SEDST, FSDM | |||
| # to ensure comparable results | |||
| resp = resp.replace(f'{self.reader.sos_r_token} ', '') | |||
| resp = resp.replace(f' {self.reader.eos_r_token}', '') | |||
| resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}' | |||
| for value, slot in self.entitiy_to_slot_dict.items(): | |||
| resp = utils.clean_replace(resp, value, '[value_%s]' % slot) | |||
| return resp | |||
| class KvretEvaluator(GenericEvaluator): | |||
| def __init__(self, reader): | |||
| super().__init__(reader) | |||
| self.entities_flat, self.entitiy_to_slot_dict = self.get_entities( | |||
| self.reader.ontology_path) | |||
| self.informable_slots = self.reader.otlg.informable_slots | |||
| self.requestable_slots = self.reader.otlg.requestable_slots | |||
| def run_metrics(self, results): | |||
| metrics = {} | |||
| bleu = self.bleu_metric(results) | |||
| p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric( | |||
| results, normalize=True) | |||
| match = self.match_metric(results) | |||
| req_f1, req_p, req_r = self.request_metric(results) | |||
| metrics['bleu'] = bleu | |||
| metrics['match'] = match | |||
| metrics['req_f1'] = req_f1 | |||
| metrics['joint_goal'] = goal_acc | |||
| metrics['slot_accu'] = slot_acc | |||
| metrics['slot-p/r/f1'] = (p, r, f1) | |||
| metrics['db_acc'] = db_acc | |||
| return metrics | |||
| def _normalize_constraint(self, | |||
| constraint, | |||
| ignore_dontcare=False, | |||
| intersection=True): | |||
| """ | |||
| Normalize belief span, e.g. delete repeated words | |||
| :param constraint - {'food': 'asian oriental', 'pricerange': 'cheap'} | |||
| :param intersection: if true, only keep the values that appear in the ontology; | |||
| we set intersection=True as in previous works | |||
| :returns: normalized constraint dict | |||
| e.g. - {'food': 'asian oriental', 'pricerange': 'cheap', 'area': ''} | |||
| """ | |||
| junk = [ | |||
| 'good', 'great', 'quickest', 'shortest', 'route', 'week', | |||
| 'fastest', 'nearest', 'next', 'closest', 'way', 'mile', 'activity', | |||
| 'restaurant', 'appointment' | |||
| ] | |||
| normalized = {} | |||
| for s in self.informable_slots: | |||
| normalized[s] = '' | |||
| for s, v in constraint.items(): | |||
| for j in junk: | |||
| v = ' '.join(v.replace(j, '').split()) | |||
| if intersection and v not in self.entities_flat: | |||
| continue | |||
| if s in self.informable_slots: | |||
| normalized[s] = v | |||
| else: | |||
| # TODO only use slot (not domain) in s for matching !!! | |||
| pass | |||
| return normalized | |||
| def get_entities(self, entity_path): | |||
| entities_flat = [] | |||
| entitiy_to_slot_dict = {} | |||
| entitiy_to_slot_dict = self.reader.entity_dict | |||
| for s in entitiy_to_slot_dict: | |||
| if s not in entities_flat: | |||
| entities_flat.append(s) | |||
| return entities_flat, entitiy_to_slot_dict | |||
| def constraint_same(self, truth_cons, gen_cons): | |||
| if not truth_cons and not gen_cons: | |||
| return True | |||
| if not truth_cons or not gen_cons: | |||
| return False | |||
| return setsim(gen_cons, truth_cons) | |||
| def match_metric(self, data): | |||
| dials = self.pack_dial(data) | |||
| match, total = 0, 1e-8 | |||
| for dial_id in dials: | |||
| dial = dials[dial_id] | |||
| truth_cons, gen_cons = { | |||
| '1': '', | |||
| '2': '', | |||
| '3': '', | |||
| '4': '', | |||
| '5': '', | |||
| '6': '', | |||
| '7': '', | |||
| '8': '', | |||
| '9': '', | |||
| '10': '', | |||
| '11': '' | |||
| }, None | |||
| for turn_num, turn in enumerate(dial): | |||
| # find the last turn in which the system provides an entity | |||
| if '[value' in turn['resp_gen']: | |||
| gen_cons = self._normalize_constraint( | |||
| turn['bspn_gen'], ignore_dontcare=True) | |||
| if '[value' in turn['resp']: | |||
| truth_cons = self._normalize_constraint( | |||
| turn['bspn'], ignore_dontcare=True) | |||
| if not gen_cons: | |||
| # if no entity is provided, choose the state of the last dialog turn | |||
| gen_cons = self._normalize_constraint( | |||
| dial[-1]['bspn_gen'], ignore_dontcare=True) | |||
| if list(truth_cons.values()) != [''] * 11: | |||
| gen_cons = [x for x in gen_cons.values() if x] | |||
| truth_cons = [x for x in truth_cons.values() if x] | |||
| if self.constraint_same(gen_cons, truth_cons): | |||
| match += 1 | |||
| total += 1 | |||
| return match / total | |||
| def clean(self, resp): | |||
| # we use the same clean process as in Sequicity, SEDST, FSDM | |||
| # to ensure comparable results | |||
| resp = resp.replace(f'{self.reader.sos_r_token} ', '') | |||
| resp = resp.replace(f' {self.reader.eos_r_token}', '') | |||
| resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}' | |||
| for value, slot in self.entitiy_to_slot_dict.items(): | |||
| resp = utils.clean_replace(resp, value, '[value_%s]' % slot) | |||
| return resp | |||
| @@ -15,27 +15,11 @@ from transformers.optimization import AdamW, get_linear_schedule_with_warmup | |||
| from modelscope.trainers.nlp.space.metrics.metrics_tracker import \ | |||
| MetricsTracker | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.nlp.space import ontology | |||
| def get_logger(log_path, name='default'): | |||
| logger = logging.getLogger(name) | |||
| logger.propagate = False | |||
| logger.setLevel(logging.DEBUG) | |||
| formatter = logging.Formatter('%(message)s') | |||
| sh = logging.StreamHandler(sys.stdout) | |||
| sh.setFormatter(formatter) | |||
| logger.addHandler(sh) | |||
| fh = logging.FileHandler(log_path, mode='w') | |||
| fh.setFormatter(formatter) | |||
| logger.addHandler(fh) | |||
| return logger | |||
| class Trainer(object): | |||
| def __init__(self, | |||
| @@ -51,15 +35,16 @@ class Trainer(object): | |||
| self.do_train = config.do_train | |||
| self.do_infer = config.do_infer | |||
| self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ | |||
| 0] == '-' | |||
| self.valid_metric_name = config.Trainer.valid_metric_name[1:] | |||
| self.num_epochs = config.Trainer.num_epochs | |||
| # self.save_dir = config.Trainer.save_dir | |||
| self.log_steps = config.Trainer.log_steps | |||
| self.valid_steps = config.Trainer.valid_steps | |||
| self.save_checkpoint = config.Trainer.save_checkpoint | |||
| self.save_summary = config.Trainer.save_summary | |||
| if self.do_train: | |||
| self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ | |||
| 0] == '-' | |||
| self.valid_metric_name = config.Trainer.valid_metric_name[1:] | |||
| self.num_epochs = config.Trainer.num_epochs | |||
| self.save_dir = config.Trainer.save_dir | |||
| self.log_steps = config.Trainer.log_steps | |||
| self.valid_steps = config.Trainer.valid_steps | |||
| self.save_checkpoint = config.Trainer.save_checkpoint | |||
| self.save_summary = config.Trainer.save_summary | |||
| self.lr = config.Model.lr | |||
| self.weight_decay = config.Model.weight_decay | |||
| self.batch_size = config.Trainer.batch_size | |||
| @@ -71,22 +56,21 @@ class Trainer(object): | |||
| self.optimizer = optimizer | |||
| self.model = model | |||
| self.func_model = self.model.module if self.gpu > 1 else self.model | |||
| self.func_model = self.model.module if self.gpu > 1 and config.use_gpu else self.model | |||
| self.reader = reader | |||
| self.evaluator = evaluator | |||
| self.tokenizer = reader.tokenizer | |||
| # if not os.path.exists(self.save_dir): | |||
| # os.makedirs(self.save_dir) | |||
| # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") | |||
| self.logger = logger or get_logger('trainer.log', 'trainer') | |||
| self.logger = get_logger() | |||
| self.batch_metrics_tracker = MetricsTracker() | |||
| self.token_metrics_tracker = MetricsTracker() | |||
| self.best_valid_metric = float( | |||
| 'inf' if self.is_decreased_valid_metric else '-inf') | |||
| if self.do_train: | |||
| if not os.path.exists(self.save_dir): | |||
| os.makedirs(self.save_dir) | |||
| self.best_valid_metric = float( | |||
| 'inf' if self.is_decreased_valid_metric else '-inf') | |||
| self.epoch = 0 | |||
| def decode_generated_bspn_resp(self, generated): | |||
| @@ -248,9 +232,12 @@ class Trainer(object): | |||
| # Save current best model | |||
| if is_best: | |||
| best_model_file = os.path.join(self.save_dir, 'best.model') | |||
| best_model_file = os.path.join(self.save_dir, | |||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||
| torch.save(self.model.state_dict(), best_model_file) | |||
| best_train_file = os.path.join(self.save_dir, 'best.train') | |||
| best_train_file = os.path.join( | |||
| self.save_dir, | |||
| '{}.train'.format(ModelFile.TORCH_MODEL_BIN_FILE)) | |||
| torch.save(train_state, best_train_file) | |||
| self.logger.info( | |||
| f"Saved best model state to '{best_model_file}' with new best valid metric " | |||
| @@ -324,8 +311,7 @@ class Trainer(object): | |||
| self.func_model.load_state_dict(model_state_dict) | |||
| self.logger.info( | |||
| f"Loaded model state from '{self.func_model.init_checkpoint}.model'" | |||
| ) | |||
| f"Loaded model state from '{self.func_model.init_checkpoint}'") | |||
| def _load_train_state(): | |||
| train_file = f'{self.func_model.init_checkpoint}.train' | |||
| @@ -558,19 +544,17 @@ class MultiWOZTrainer(Trainer): | |||
| generated_bs = outputs[0].cpu().numpy().tolist() | |||
| bspn_gen = self.decode_generated_bspn(generated_bs) | |||
| # check DB result | |||
| if self.reader.use_true_db_pointer: # To control whether current db is ground truth | |||
| if self.reader.use_true_db_pointer: | |||
| db = turn['db'] | |||
| else: | |||
| db_result = self.reader.bspan_to_DBpointer( | |||
| self.tokenizer.decode(bspn_gen), | |||
| turn['turn_domain']) | |||
| assert len(turn['db']) == 4 | |||
| book_result = turn['db'][2] | |||
| assert len(turn['db']) == 3 | |||
| assert isinstance(db_result, str) | |||
| db = \ | |||
| [self.reader.sos_db_id] + \ | |||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||
| [book_result] + \ | |||
| [self.reader.eos_db_id] | |||
| prompt_id = self.reader.sos_a_id | |||
| @@ -636,7 +620,7 @@ class MultiWOZTrainer(Trainer): | |||
| score = 0.5 * (success + match) + bleu | |||
| # log results | |||
| metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\ | |||
| metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' % \ | |||
| (match, success, bleu, score) | |||
| message_prefix = f'[Infer][{self.epoch}]' | |||
| time_cost = f'TIME-{time.time() - begin_time:.3f}' | |||
| @@ -0,0 +1,333 @@ | |||
| import os | |||
| import re | |||
| from . import ontology | |||
| def clean_text_split_dot(text): | |||
| text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2', | |||
| text) # 'abc.xyz' -> 'abc . xyz' | |||
| text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text) # if 'abc. ' -> 'abc . ' | |||
| return text | |||
| def clean_text(data_dir, text): | |||
| text = text.strip() | |||
| text = text.lower() | |||
| text = text.replace(u'’', "'") | |||
| text = text.replace(u'‘', "'") | |||
| text = text.replace(';', ',') | |||
| text = text.replace('"', ' ') | |||
| text = text.replace('/', ' and ') | |||
| text = text.replace("don't", "do n't") | |||
| text = clean_time(text) | |||
| baddata = { | |||
| r'c\.b (\d), (\d) ([a-z])\.([a-z])': r'cb\1\2\3\4', | |||
| 'c.b. 1 7 d.y': 'cb17dy', | |||
| 'c.b.1 7 d.y': 'cb17dy', | |||
| 'c.b 25, 9 a.q': 'cb259aq', | |||
| 'isc.b 25, 9 a.q': 'is cb259aq', | |||
| 'c.b2, 1 u.f': 'cb21uf', | |||
| 'c.b 1,2 q.a': 'cb12qa', | |||
| '0-122-336-5664': '01223365664', | |||
| 'postcodecb21rs': 'postcode cb21rs', | |||
| r'i\.d': 'id', | |||
| ' i d ': 'id', | |||
| 'Telephone:01223358966': 'Telephone: 01223358966', | |||
| 'depature': 'departure', | |||
| 'depearting': 'departing', | |||
| '-type': ' type', | |||
| r'b[\s]?&[\s]?b': 'bed and breakfast', | |||
| 'b and b': 'bed and breakfast', | |||
| r'guesthouse[s]?': 'guest house', | |||
| r'swimmingpool[s]?': 'swimming pool', | |||
| "wo n\'t": 'will not', | |||
| " \'d ": ' would ', | |||
| " \'m ": ' am ', | |||
| " \'re' ": ' are ', | |||
| " \'ll' ": ' will ', | |||
| " \'ve ": ' have ', | |||
| r'^\'': '', | |||
| r'\'$': '', | |||
| } | |||
| for tmpl, good in baddata.items(): | |||
| text = re.sub(tmpl, good, text) | |||
| text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2', | |||
| text) # 'abc.xyz' -> 'abc . xyz' | |||
| text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text) # if 'abc. ' -> 'abc . ' | |||
| with open(os.path.join(data_dir, 'mapping.pair'), 'r') as fin: | |||
| for line in fin.readlines(): | |||
| fromx, tox = line.replace('\n', '').split('\t') | |||
| text = ' ' + text + ' ' | |||
| text = text.replace(' ' + fromx + ' ', ' ' + tox + ' ')[1:-1] | |||
| return text | |||
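`mapping.pair` is read as one tab-separated `from<TAB>to` replacement per line, applied with space padding so only whole-token spans match. A hypothetical entry and its effect (the actual file ships with the dataset):

```python
# Hypothetical mapping.pair line: "centre of town\tcentre"
fromx, tox = 'centre of town', 'centre'
text = ' ' + 'in the centre of town' + ' '
text = text.replace(' ' + fromx + ' ', ' ' + tox + ' ')[1:-1]
assert text == 'in the centre'
```

Note that the file is re-read on every `clean_text` call; caching the pairs at module load would avoid the repeated I/O.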
| def clean_time(utter): | |||
| utter = re.sub(r'(\d+) ([ap]\.?m)', lambda x: x.group(1) + x.group(2), | |||
| utter) # 9 am -> 9am | |||
| utter = re.sub(r'((?<!\d)\d:\d+)(am)?', r'0\1', utter) | |||
| utter = re.sub(r'((?<!\d)\d)am', r'0\1:00', utter) | |||
| utter = re.sub(r'((?<!\d)\d)pm', | |||
| lambda x: str(int(x.group(1)) + 12) + ':00', utter) | |||
| utter = re.sub(r'(\d+)(:\d+)pm', | |||
| lambda x: str(int(x.group(1)) + 12) + x.group(2), utter) | |||
| utter = re.sub(r'(\d+)a\.?m', r'\1', utter) | |||
| return utter | |||
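The substitution order matters: spaces before am/pm are removed first, then times are zero-padded and pm times shifted by 12 hours. Tracing a few inputs through the function:

```python
>>> clean_time('9 am')
'09:00'
>>> clean_time('3pm')
'15:00'
>>> clean_time('4:30pm')
'16:30'
```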
| def clean_slot_values(data_dir, domain, slot, value): | |||
| value = clean_text(data_dir, value) | |||
| if not value: | |||
| value = '' | |||
| elif value == 'not mentioned': | |||
| value = '' | |||
| # value = 'not mentioned' # if in DST setting | |||
| elif domain == 'attraction': | |||
| if slot == 'name': | |||
| if value == 't': | |||
| value = '' | |||
| if value == 'trinity': | |||
| value = 'trinity college' | |||
| elif slot == 'area': | |||
| if value in ['town centre', 'cent', 'center', 'ce']: | |||
| value = 'centre' | |||
| elif value in [ | |||
| 'ely', 'in town', 'museum', 'norwich', 'same area as hotel' | |||
| ]: | |||
| value = '' | |||
| elif value in ['we']: | |||
| value = 'west' | |||
| elif slot == 'type': | |||
| if value in ['m', 'mus', 'musuem']: | |||
| value = 'museum' | |||
| elif value in ['art', 'architectural']: | |||
| value = 'architecture' | |||
| elif value in ['churches']: | |||
| value = 'church' | |||
| elif value in ['coll']: | |||
| value = 'college' | |||
| elif value in ['concert', 'concerthall']: | |||
| value = 'concert hall' | |||
| elif value in ['night club']: | |||
| value = 'nightclub' | |||
| elif value in [ | |||
| 'mutiple sports', 'mutliple sports', 'sports', 'galleria' | |||
| ]: | |||
| value = 'multiple sports' | |||
| elif value in ['ol', 'science', 'gastropub', 'la raza']: | |||
| value = '' | |||
| elif value in ['swimmingpool', 'pool']: | |||
| value = 'swimming pool' | |||
| elif value in ['fun']: | |||
| value = 'entertainment' | |||
| elif domain == 'hotel': | |||
| if slot == 'area': | |||
| if value in [ | |||
| 'cen', 'centre of town', 'near city center', 'center' | |||
| ]: | |||
| value = 'centre' | |||
| elif value in ['east area', 'east side']: | |||
| value = 'east' | |||
| elif value in ['in the north', 'north part of town']: | |||
| value = 'north' | |||
| elif value in ['we']: | |||
| value = 'west' | |||
| elif slot == 'day': | |||
| if value == 'monda': | |||
| value = 'monday' | |||
| elif value == 't': | |||
| value = 'tuesday' | |||
| elif slot == 'name': | |||
| if value == 'uni': | |||
| value = 'university arms hotel' | |||
| elif value == 'university arms': | |||
| value = 'university arms hotel' | |||
| elif value == 'acron': | |||
| value = 'acorn guest house' | |||
| elif value == 'ashley': | |||
| value = 'ashley hotel' | |||
| elif value == 'arbury lodge guesthouse': | |||
| value = 'arbury lodge guest house' | |||
| elif value == 'la': | |||
| value = 'la margherit' | |||
| elif value == 'no': | |||
| value = '' | |||
| elif slot == 'internet': | |||
| if value == 'does not': | |||
| value = 'no' | |||
| elif value in ['y', 'free', 'free internet']: | |||
| value = 'yes' | |||
| elif value in ['4']: | |||
| value = '' | |||
| elif slot == 'parking': | |||
| if value == 'n': | |||
| value = 'no' | |||
| elif value in ['free parking']: | |||
| value = 'yes' | |||
| elif value in ['y']: | |||
| value = 'yes' | |||
| elif slot in ['pricerange', 'price range']: | |||
| slot = 'pricerange' | |||
| if value == 'moderately': | |||
| value = 'moderate' | |||
| elif value in ['any']: | |||
| value = "do n't care" | |||
| elif value in ['inexpensive']: | |||
| value = 'cheap' | |||
| elif value in ['2', '4']: | |||
| value = '' | |||
| elif slot == 'stars': | |||
| if value == 'two': | |||
| value = '2' | |||
| elif value == 'three': | |||
| value = '3' | |||
| elif value in [ | |||
| '4-star', '4 stars', '4 star', 'four star', 'four stars' | |||
| ]: | |||
| value = '4' | |||
| elif slot == 'type': | |||
| if value == '0 star rarting': | |||
| value = '' | |||
| elif value == 'guesthouse': | |||
| value = 'guest house' | |||
| elif value not in ['hotel', 'guest house', "do n't care"]: | |||
| value = '' | |||
| elif domain == 'restaurant': | |||
| if slot == 'area': | |||
| if value in [ | |||
| 'center', 'scentre', 'center of town', 'city center', | |||
| 'cb30aq', 'town center', 'centre of cambridge', | |||
| 'city centre' | |||
| ]: | |||
| value = 'centre' | |||
| elif value == 'west part of town': | |||
| value = 'west' | |||
| elif value == 'n': | |||
| value = 'north' | |||
| elif value in ['the south']: | |||
| value = 'south' | |||
| elif value not in [ | |||
| 'centre', 'south', "do n't care", 'west', 'east', 'north' | |||
| ]: | |||
| value = '' | |||
| elif slot == 'day': | |||
| if value == 'monda': | |||
| value = 'monday' | |||
| elif value == 't': | |||
| value = 'tuesday' | |||
| elif slot in ['pricerange', 'price range']: | |||
| slot = 'pricerange' | |||
| if value in ['moderately', 'mode', 'mo']: | |||
| value = 'moderate' | |||
| elif value in ['not']: | |||
| value = '' | |||
| elif value in ['inexpensive', 'ch']: | |||
| value = 'cheap' | |||
| elif slot == 'food': | |||
| if value == 'barbecue': | |||
| value = 'barbeque' | |||
| elif slot == 'time': | |||
| if value == '9:00': | |||
| value = '09:00' | |||
| elif value == '9:45': | |||
| value = '09:45' | |||
| elif value == '1330': | |||
| value = '13:30' | |||
| elif value == '1430': | |||
| value = '14:30' | |||
| elif value == '9:15': | |||
| value = '09:15' | |||
| elif value == '9:30': | |||
| value = '09:30' | |||
| elif value == '1830': | |||
| value = '18:30' | |||
| elif value == '9': | |||
| value = '09:00' | |||
| elif value == '2:00': | |||
| value = '14:00' | |||
| elif value == '1:00': | |||
| value = '13:00' | |||
| elif value == '3:00': | |||
| value = '15:00' | |||
| elif domain == 'taxi': | |||
| if slot in ['arriveBy', 'arrive by']: | |||
| slot = 'arriveby' | |||
| if value == '1530': | |||
| value = '15:30' | |||
| elif value == '15 minutes': | |||
| value = '' | |||
| elif slot in ['leaveAt', 'leave at']: | |||
| slot = 'leaveat' | |||
| if value == '1:00': | |||
| value = '01:00' | |||
| elif value == '21:4': | |||
| value = '21:04' | |||
| elif value == '4:15': | |||
| value = '04:15' | |||
| elif value == '5:45': | |||
| value = '05:45' | |||
| elif value == '0700': | |||
| value = '07:00' | |||
| elif value == '4:45': | |||
| value = '04:45' | |||
| elif value == '8:30': | |||
| value = '08:30' | |||
| elif value == '9:30': | |||
| value = '09:30' | |||
| value = value.replace('.', ':') | |||
| elif domain == 'train': | |||
| if slot in ['arriveBy', 'arrive by']: | |||
| slot = 'arriveby' | |||
| if value == '1': | |||
| value = '01:00' | |||
| elif value in ['does not care', 'doesnt care', "doesn't care"]: | |||
| value = "do n't care" | |||
| elif value == '8:30': | |||
| value = '08:30' | |||
| elif value == 'not 15:45': | |||
| value = '' | |||
| value = value.replace('.', ':') | |||
| elif slot == 'day': | |||
| if value == 'doesnt care' or value == "doesn't care": | |||
| value = "do n't care" | |||
| elif slot in ['leaveAt', 'leave at']: | |||
| slot = 'leaveat' | |||
| if value == '2:30': | |||
| value = '02:30' | |||
| elif value == '7:54': | |||
| value = '07:54' | |||
| elif value == 'after 5:45 pm': | |||
| value = '17:45' | |||
| elif value in [ | |||
| 'early evening', 'friday', 'sunday', 'tuesday', 'afternoon' | |||
| ]: | |||
| value = '' | |||
| elif value == '12': | |||
| value = '12:00' | |||
| elif value == '1030': | |||
| value = '10:30' | |||
| elif value == '1700': | |||
| value = '17:00' | |||
| elif value in [ | |||
| 'does not care', 'doesnt care', 'do nt care', | |||
| "doesn't care" | |||
| ]: | |||
| value = "do n't care" | |||
| value = value.replace('.', ':') | |||
| if value in ['dont care', "don't care", 'do nt care', "doesn't care"]: | |||
| value = "do n't care" | |||
| if ontology.normlize_slot_names.get(slot): | |||
| slot = ontology.normlize_slot_names[slot] | |||
| return slot, value | |||
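End to end, the function canonicalizes both the slot name and the value. Illustrative calls, assuming `data_dir` holds the dataset's `mapping.pair` (with no entry touching these strings) and that `ontology.normlize_slot_names` leaves these slots unchanged:

```python
>>> clean_slot_values(data_dir, 'hotel', 'price range', 'moderately')
('pricerange', 'moderate')
>>> clean_slot_values(data_dir, 'train', 'arrive by', '8:30')
('arriveby', '08:30')
```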
| @@ -4,8 +4,11 @@ from collections import OrderedDict | |||
| import json | |||
| import numpy as np | |||
| from modelscope.utils.logger import get_logger | |||
| from . import ontology | |||
| logger = get_logger() | |||
| def max_lens(X): | |||
| lens = [len(X)] | |||
| @@ -117,8 +120,8 @@ class MultiWOZVocab(object): | |||
| def construct(self): | |||
| freq_dict_sorted = sorted( | |||
| self._freq_dict.keys(), key=lambda x: -self._freq_dict[x]) | |||
| print('Vocabulary size including oov: %d' % | |||
| (len(freq_dict_sorted) + len(self._idx2word))) | |||
| logger.info('Vocabulary size including oov: %d' % | |||
| (len(freq_dict_sorted) + len(self._idx2word))) | |||
| if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size: | |||
| logger.warning( | |||
| 'actual label set smaller than that configured: {}/{}'.format( | |||
| @@ -148,8 +151,9 @@ class MultiWOZVocab(object): | |||
| for w, idx in self._word2idx.items(): | |||
| self._idx2word[idx] = w | |||
| self.vocab_size_oov = len(self._idx2word) | |||
| print('vocab file loaded from "' + vocab_path + '"') | |||
| print('Vocabulary size including oov: %d' % (self.vocab_size_oov)) | |||
| logger.info('vocab file loaded from "' + vocab_path + '"') | |||
| logger.info('Vocabulary size including oov: %d' % | |||
| (self.vocab_size_oov)) | |||
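With the module-level `logger = get_logger()`, these messages gain timestamps and level filtering for free. A small follow-up suggestion: pass format args to the logger instead of pre-formatting with `%`, so the string is only built when the level is enabled:

```python
from modelscope.utils.logger import get_logger

logger = get_logger()
vocab_size_oov = 3000  # hypothetical value
logger.info('Vocabulary size including oov: %d', vocab_size_oov)
```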
| def save_vocab(self, vocab_path): | |||
| _freq_dict = OrderedDict( | |||
| @@ -0,0 +1,68 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import unittest | |||
| import torch | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Preprocessors, Trainers | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import DownloadMode, ModelFile | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestDialogModelingTrainer(unittest.TestCase): | |||
| model_id = 'damo/nlp_space_pretrained-dialog-model' | |||
| output_dir = './dialog_finetune_result' | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||
| # download data set | |||
| data_multiwoz = MsDataset.load( | |||
| 'MultiWoz2.0', download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) | |||
| data_dir = os.path.join( | |||
| data_multiwoz._hf_ds.config_kwargs['split_config']['train'], | |||
| 'data') | |||
| # download model | |||
| model_dir = snapshot_download(self.model_id) | |||
| # dialog finetune config | |||
| def cfg_modify_fn(cfg): | |||
| config = { | |||
| 'seed': 10, | |||
| 'gpu': 4, | |||
| 'use_data_distributed': False, | |||
| 'valid_metric_name': '-loss', | |||
| 'num_epochs': 60, | |||
| 'save_dir': self.output_dir, | |||
| 'token_loss': True, | |||
| 'batch_size': 32, | |||
| 'log_steps': 10, | |||
| 'valid_steps': 0, | |||
| 'save_checkpoint': True, | |||
| 'save_summary': False, | |||
| 'shuffle': True, | |||
| 'sort_pool_size': 0 | |||
| } | |||
| cfg.Trainer = config | |||
| cfg.use_gpu = torch.cuda.is_available() and config['gpu'] >= 1 | |||
| return cfg | |||
| # trainer config | |||
| kwargs = dict( | |||
| model_dir=model_dir, | |||
| cfg_name='gen_train_config.json', | |||
| data_dir=data_dir, | |||
| cfg_modify_fn=cfg_modify_fn) | |||
| trainer = build_trainer( | |||
| name=Trainers.dialog_modeling_trainer, default_args=kwargs) | |||
| trainer.train() | |||
| checkpoint_path = os.path.join(self.output_dir, | |||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||
| assert os.path.exists(checkpoint_path) | |||
| trainer.evaluate(checkpoint_path=checkpoint_path) | |||
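For reviewers unfamiliar with the hook: `cfg_modify_fn` is applied to the configuration the trainer loads from `cfg_name` before training starts, which is how the test overrides `Trainer` settings without editing `gen_train_config.json`. A minimal sketch of the pattern, assuming the trainer reads the JSON into a `modelscope.utils.config.Config`:

```python
from modelscope.utils.config import Config

def load_cfg(cfg_path, cfg_modify_fn=None):
    cfg = Config.from_file(cfg_path)   # parse e.g. gen_train_config.json
    if cfg_modify_fn is not None:
        cfg = cfg_modify_fn(cfg)       # apply test-side overrides
    return cfg
```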