Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10061562master
| @@ -241,6 +241,7 @@ class Trainers(object): | |||||
| # nlp trainers | # nlp trainers | ||||
| bert_sentiment_analysis = 'bert-sentiment-analysis' | bert_sentiment_analysis = 'bert-sentiment-analysis' | ||||
| dialog_modeling_trainer = 'dialog-modeling-trainer' | |||||
| dialog_intent_trainer = 'dialog-intent-trainer' | dialog_intent_trainer = 'dialog-intent-trainer' | ||||
| nlp_base_trainer = 'nlp-base-trainer' | nlp_base_trainer = 'nlp-base-trainer' | ||||
| nlp_veco_trainer = 'nlp-veco-trainer' | nlp_veco_trainer = 'nlp-veco-trainer' | ||||
| @@ -1,6 +1,6 @@ | |||||
| from .configuration_space import SpaceConfig | from .configuration_space import SpaceConfig | ||||
| from .gen_unified_transformer import GenUnifiedTransformer | from .gen_unified_transformer import GenUnifiedTransformer | ||||
| from .generator import Generator as SpaceGenerator | |||||
| from .generator import SpaceGenerator | |||||
| from .intent_unified_transformer import IntentUnifiedTransformer | from .intent_unified_transformer import IntentUnifiedTransformer | ||||
| from .model_base import SpaceModelBase | from .model_base import SpaceModelBase | ||||
| from .modeling_space import (SpaceForDST, SpaceForMaskedLM, | from .modeling_space import (SpaceForDST, SpaceForMaskedLM, | ||||
| @@ -38,24 +38,24 @@ def gather(var, idx): | |||||
| return var | return var | ||||
| class Generator(object): | |||||
| class SpaceGenerator(object): | |||||
| """ Genrator class. """ | """ Genrator class. """ | ||||
| _registry = dict() | _registry = dict() | ||||
| @classmethod | @classmethod | ||||
| def register(cls, name): | def register(cls, name): | ||||
| Generator._registry[name] = cls | |||||
| SpaceGenerator._registry[name] = cls | |||||
| return | return | ||||
| @staticmethod | @staticmethod | ||||
| def by_name(name): | def by_name(name): | ||||
| return Generator._registry[name] | |||||
| return SpaceGenerator._registry[name] | |||||
| @staticmethod | @staticmethod | ||||
| def create(config, *args, **kwargs): | def create(config, *args, **kwargs): | ||||
| """ Create generator. """ | """ Create generator. """ | ||||
| generator_cls = Generator.by_name(config.Generator.generator) | |||||
| generator_cls = SpaceGenerator.by_name(config.Generator.generator) | |||||
| return generator_cls(config, *args, **kwargs) | return generator_cls(config, *args, **kwargs) | ||||
| def __init__(self, config, reader): | def __init__(self, config, reader): | ||||
| @@ -83,7 +83,7 @@ class Generator(object): | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| class BeamSearch(Generator): | |||||
| class BeamSearch(SpaceGenerator): | |||||
| """ BeamSearch generator. """ | """ BeamSearch generator. """ | ||||
| def __init__(self, config, reader): | def __init__(self, config, reader): | ||||
| @@ -41,7 +41,7 @@ class SpaceForDialogModeling(TorchModel): | |||||
| self.text_field = kwargs.pop( | self.text_field = kwargs.pop( | ||||
| 'text_field', | 'text_field', | ||||
| MultiWOZBPETextField(self.model_dir, config=self.config)) | |||||
| MultiWOZBPETextField(config=self.config, model_dir=self.model_dir)) | |||||
| self.generator = SpaceGenerator.create( | self.generator = SpaceGenerator.create( | ||||
| self.config, reader=self.text_field) | self.config, reader=self.text_field) | ||||
| self.model = SpaceModelBase.create( | self.model = SpaceModelBase.create( | ||||
| @@ -35,7 +35,7 @@ class DialogModelingPreprocessor(Preprocessor): | |||||
| self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available() | self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available() | ||||
| self.text_field = MultiWOZBPETextField( | self.text_field = MultiWOZBPETextField( | ||||
| self.model_dir, config=self.config) | |||||
| config=self.config, model_dir=self.model_dir) | |||||
| @type_assert(object, Dict) | @type_assert(object, Dict) | ||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
| @@ -2,9 +2,11 @@ | |||||
| import os | import os | ||||
| import random | import random | ||||
| from asyncio import constants | |||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from itertools import chain | from itertools import chain | ||||
| import json | |||||
| import numpy as np | import numpy as np | ||||
| from modelscope.preprocessors.space.tokenizer import Tokenizer | from modelscope.preprocessors.space.tokenizer import Tokenizer | ||||
| @@ -117,7 +119,8 @@ class BPETextField(object): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0] | return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0] | ||||
| def __init__(self, config): | def __init__(self, config): | ||||
| self.gpu = 0 | |||||
| self.train, self.dev, self.test = [], [], [] | |||||
| self.gpu = config.Trainer.gpu | |||||
| self.tokenizer = None | self.tokenizer = None | ||||
| self.vocab = None | self.vocab = None | ||||
| self.db = None | self.db = None | ||||
| @@ -249,13 +252,9 @@ class BPETextField(object): | |||||
| for dial in data: | for dial in data: | ||||
| batch.append(dial) | batch.append(dial) | ||||
| if len(batch) == self.batch_size: | if len(batch) == self.batch_size: | ||||
| # print('batch size: %d, batch num +1'%(len(batch))) | |||||
| all_batches.append(batch) | all_batches.append(batch) | ||||
| batch = [] | batch = [] | ||||
| # if remainder > 1/2 batch_size, just put them in the previous batch, otherwise form a new batch | |||||
| # print('last batch size: %d, batch num +1'%(len(batch))) | |||||
| # if (len(batch) % len(cfg.cuda_device)) != 0: | |||||
| # batch = batch[:-(len(batch) % len(cfg.cuda_device))] | |||||
| # TODO deal with deleted data | # TODO deal with deleted data | ||||
| if self.gpu <= 1: | if self.gpu <= 1: | ||||
| if len(batch) > 0.5 * self.batch_size: | if len(batch) > 0.5 * self.batch_size: | ||||
| @@ -308,7 +307,7 @@ class BPETextField(object): | |||||
| class MultiWOZBPETextField(BPETextField): | class MultiWOZBPETextField(BPETextField): | ||||
| def __init__(self, model_dir, config): | |||||
| def __init__(self, config, **kwargs): | |||||
| super(MultiWOZBPETextField, self).__init__(config) | super(MultiWOZBPETextField, self).__init__(config) | ||||
| import spacy | import spacy | ||||
| @@ -327,8 +326,12 @@ class MultiWOZBPETextField(BPETextField): | |||||
| ) | ) | ||||
| self.nlp = spacy.load('en_core_web_sm') | self.nlp = spacy.load('en_core_web_sm') | ||||
| if config.do_train: | |||||
| db_dir = kwargs['data_dir'] | |||||
| else: | |||||
| db_dir = kwargs['model_dir'] | |||||
| self.db = MultiWozDB( | self.db = MultiWozDB( | ||||
| model_dir, { | |||||
| db_dir, { | |||||
| 'attraction': 'db/attraction_db_processed.json', | 'attraction': 'db/attraction_db_processed.json', | ||||
| 'hospital': 'db/hospital_db_processed.json', | 'hospital': 'db/hospital_db_processed.json', | ||||
| 'hotel': 'db/hotel_db_processed.json', | 'hotel': 'db/hotel_db_processed.json', | ||||
| @@ -337,14 +340,14 @@ class MultiWOZBPETextField(BPETextField): | |||||
| 'taxi': 'db/taxi_db_processed.json', | 'taxi': 'db/taxi_db_processed.json', | ||||
| 'train': 'db/train_db_processed.json', | 'train': 'db/train_db_processed.json', | ||||
| }) | }) | ||||
| self._build_vocab(model_dir) | |||||
| self._build_vocab(db_dir) | |||||
| special_tokens = [ | special_tokens = [ | ||||
| self.pad_token, self.bos_token, self.eos_token, self.unk_token | self.pad_token, self.bos_token, self.eos_token, self.unk_token | ||||
| ] | ] | ||||
| special_tokens.extend(self.add_sepcial_tokens()) | special_tokens.extend(self.add_sepcial_tokens()) | ||||
| self.tokenizer = Tokenizer( | self.tokenizer = Tokenizer( | ||||
| vocab_path=os.path.join(model_dir, ModelFile.VOCAB_FILE), | |||||
| vocab_path=os.path.join(kwargs['model_dir'], ModelFile.VOCAB_FILE), | |||||
| special_tokens=special_tokens, | special_tokens=special_tokens, | ||||
| tokenizer_type=config.BPETextField.tokenizer_type) | tokenizer_type=config.BPETextField.tokenizer_type) | ||||
| self.understand_ids = self.tokenizer.convert_tokens_to_ids( | self.understand_ids = self.tokenizer.convert_tokens_to_ids( | ||||
| @@ -352,6 +355,26 @@ class MultiWOZBPETextField(BPETextField): | |||||
| self.policy_ids = self.tokenizer.convert_tokens_to_ids( | self.policy_ids = self.tokenizer.convert_tokens_to_ids( | ||||
| self.policy_tokens) | self.policy_tokens) | ||||
| if config.do_train: | |||||
| test_list = [ | |||||
| line.strip().lower() for line in open( | |||||
| os.path.join(kwargs['data_dir'], 'testListFile.json'), | |||||
| 'r').readlines() | |||||
| ] | |||||
| dev_list = [ | |||||
| line.strip().lower() for line in open( | |||||
| os.path.join(kwargs['data_dir'], 'valListFile.json'), | |||||
| 'r').readlines() | |||||
| ] | |||||
| self.dev_files, self.test_files = {}, {} | |||||
| for fn in test_list: | |||||
| self.test_files[fn.replace('.json', '')] = 1 | |||||
| for fn in dev_list: | |||||
| self.dev_files[fn.replace('.json', '')] = 1 | |||||
| self._load_data(kwargs['data_dir']) | |||||
| return | return | ||||
| def get_ids(self, data: str): | def get_ids(self, data: str): | ||||
| @@ -414,7 +437,6 @@ class MultiWOZBPETextField(BPETextField): | |||||
| name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} | name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} | ||||
| dial = name_to_set[set_name] | dial = name_to_set[set_name] | ||||
| turn_bucket = self._bucket_by_turn(dial) | turn_bucket = self._bucket_by_turn(dial) | ||||
| # self._shuffle_turn_bucket(turn_bucket) | |||||
| all_batches = [] | all_batches = [] | ||||
| if set_name not in self.set_stats: | if set_name not in self.set_stats: | ||||
| @@ -433,19 +455,13 @@ class MultiWOZBPETextField(BPETextField): | |||||
| except Exception: | except Exception: | ||||
| log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( | log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( | ||||
| k, len(turn_bucket[k]), len(batches), 0.0) | k, len(turn_bucket[k]), len(batches), 0.0) | ||||
| # print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches))) | |||||
| num_training_steps += k * len(batches) | num_training_steps += k * len(batches) | ||||
| num_turns += k * len(turn_bucket[k]) | num_turns += k * len(turn_bucket[k]) | ||||
| num_dials += len(turn_bucket[k]) | num_dials += len(turn_bucket[k]) | ||||
| all_batches += batches | all_batches += batches | ||||
| log_str += 'total batch num: %d\n' % len(all_batches) | log_str += 'total batch num: %d\n' % len(all_batches) | ||||
| # print('total batch num: %d'%len(all_batches)) | |||||
| # print('dialog count: %d'%dia_count) | |||||
| # return all_batches | |||||
| # log stats | |||||
| # logging.info(log_str) | |||||
| # cfg.num_training_steps = num_training_steps * cfg.epoch_num | |||||
| self.set_stats[set_name][ | self.set_stats[set_name][ | ||||
| 'num_training_steps_per_epoch'] = num_training_steps # turn-level steps | 'num_training_steps_per_epoch'] = num_training_steps # turn-level steps | ||||
| self.set_stats[set_name]['num_turns'] = num_turns | self.set_stats[set_name]['num_turns'] = num_turns | ||||
| @@ -484,6 +500,71 @@ class MultiWOZBPETextField(BPETextField): | |||||
| self.vocab.load_vocab(vp) | self.vocab.load_vocab(vp) | ||||
| return self.vocab.vocab_size | return self.vocab.vocab_size | ||||
| def _load_data(self, data_dir, save_temp=True): | |||||
| """ | |||||
| load processed data and encode, or load already encoded data | |||||
| """ | |||||
| def load_data_from_resource(data_resource): | |||||
| data = json.loads( | |||||
| open( | |||||
| os.path.join(data_dir, data_resource), | |||||
| 'r', | |||||
| encoding='utf-8').read().lower()) | |||||
| train, dev, test = [], [], [] | |||||
| for fn, dial in data.items(): | |||||
| if '.json' in fn: | |||||
| fn = fn.replace('.json', '') | |||||
| if self.dev_files.get(fn): | |||||
| dev.append(self._get_encoded_data(fn, dial)) | |||||
| elif self.test_files.get(fn): | |||||
| test.append(self._get_encoded_data(fn, dial)) | |||||
| else: | |||||
| train.append(self._get_encoded_data(fn, dial)) | |||||
| return train, dev, test | |||||
| data_processed = 'new_db_se_blank_encoded_domain.data.json' | |||||
| data_resource = 'data_for_damd.json' | |||||
| if save_temp: # save encoded data | |||||
| # encoded: no sos, se_encoded: sos and eos | |||||
| encoded_file = os.path.join(data_dir, data_processed) | |||||
| if os.path.exists(encoded_file): | |||||
| logger.info( | |||||
| 'Reading encoded data from {}'.format(encoded_file)) | |||||
| self.data = json.loads( | |||||
| open( | |||||
| os.path.join(data_dir, data_resource), | |||||
| 'r', | |||||
| encoding='utf-8').read().lower()) | |||||
| encoded_data = json.loads( | |||||
| open(encoded_file, 'r', encoding='utf-8').read()) | |||||
| self.train = encoded_data['train'] | |||||
| self.dev = encoded_data['dev'] | |||||
| self.test = encoded_data['test'] | |||||
| else: | |||||
| logger.info( | |||||
| 'Encoding data now and save the encoded data in {}'.format( | |||||
| encoded_file)) | |||||
| # not exists, encode data and save | |||||
| self.train, self.dev, self.test = load_data_from_resource( | |||||
| data_resource) | |||||
| # save encoded data | |||||
| encoded_data = { | |||||
| 'train': self.train, | |||||
| 'dev': self.dev, | |||||
| 'test': self.test | |||||
| } | |||||
| json.dump(encoded_data, open(encoded_file, 'w'), indent=2) | |||||
| else: # directly read processed data and encode | |||||
| self.train, self.dev, self.test = load_data_from_resource( | |||||
| data_resource) | |||||
| random.seed(10) | |||||
| random.shuffle(self.train) | |||||
| logger.info('train size:{}, dev size:{}, test size:{}'.format( | |||||
| len(self.train), len(self.dev), len(self.test))) | |||||
| def _get_convert_str(self, sent): | def _get_convert_str(self, sent): | ||||
| assert isinstance(sent, str) | assert isinstance(sent, str) | ||||
| return ' '.join([ | return ' '.join([ | ||||
| @@ -491,14 +572,65 @@ class MultiWOZBPETextField(BPETextField): | |||||
| for tok in sent.split() | for tok in sent.split() | ||||
| ]) | ]) | ||||
| def _get_encoded_data(self, fn, dial): | |||||
| encoded_dial = [] | |||||
| for idx, t in enumerate(dial['log']): # tokenize to list of ids | |||||
| enc = {} | |||||
| enc['dial_id'] = fn | |||||
| enc_info_list = [ | |||||
| ('user', self.sos_u_id, 'user', self.eos_u_id), | |||||
| ('usdx', self.sos_u_id, 'user', self.eos_u_id), | |||||
| ('resp', self.sos_r_id, 'resp', self.eos_r_id), | |||||
| ('bspn', self.sos_b_id, 'constraint', self.eos_b_id), | |||||
| ('bsdx', self.sos_b_id, 'cons_delex', self.eos_b_id), | |||||
| ('aspn', self.sos_a_id, 'sys_act', self.eos_a_id) | |||||
| ] | |||||
| for enc_key, start_token, item_key, end_token in enc_info_list: | |||||
| enc[enc_key] = [ | |||||
| start_token | |||||
| ] + self.tokenizer.convert_tokens_to_ids( | |||||
| self.tokenizer.tokenize( | |||||
| self._get_convert_str(t[item_key]))) + [end_token] | |||||
| enc['turn_num'] = t['turn_num'] | |||||
| if idx > 0 and t['turn_domain'] == '[general]': | |||||
| enc['dspn'] = encoded_dial[idx - 1]['dspn'] | |||||
| enc['pointer'] = encoded_dial[idx - 1]['pointer'][:4] + [ | |||||
| int(i) for i in t['pointer'].split(',') | |||||
| ][-2:] | |||||
| enc['turn_domain'] = encoded_dial[idx - 1]['turn_domain'] | |||||
| enc['db'] = encoded_dial[idx - 1]['db'] | |||||
| else: | |||||
| if t['turn_domain'] == '[general]': | |||||
| assert not t['constraint'], f'{fn}-{idx}' | |||||
| enc['dspn'] = [ | |||||
| self.sos_d_id | |||||
| ] + self.tokenizer.convert_tokens_to_ids( | |||||
| self.tokenizer.tokenize( | |||||
| self._get_convert_str( | |||||
| t['turn_domain']))) + [self.eos_d_id] | |||||
| enc['pointer'] = [int(i) for i in t['pointer'].split(',')] | |||||
| enc['turn_domain'] = t['turn_domain'].split() | |||||
| db_pointer = self.bspan_to_DBpointer(t['constraint'], | |||||
| t['turn_domain'].split()) | |||||
| enc['db'] = [ | |||||
| self.sos_db_id | |||||
| ] + self.tokenizer.convert_tokens_to_ids( | |||||
| self.tokenizer.tokenize( | |||||
| self._get_convert_str(db_pointer))) + [self.eos_db_id] | |||||
| encoded_dial.append(enc) | |||||
| return encoded_dial | |||||
| def bspan_to_DBpointer(self, bspan, turn_domain): | def bspan_to_DBpointer(self, bspan, turn_domain): | ||||
| constraint_dict = self.bspan_to_constraint_dict(bspan) | constraint_dict = self.bspan_to_constraint_dict(bspan) | ||||
| # print(constraint_dict) | |||||
| matnums = self.db.get_match_num(constraint_dict) | matnums = self.db.get_match_num(constraint_dict) | ||||
| match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] | match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] | ||||
| match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom | match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom | ||||
| match = matnums[match_dom] | match = matnums[match_dom] | ||||
| # vector = self.db.addDBPointer(match_dom, match) | |||||
| vector = self.db.addDBIndicator(match_dom, match) | vector = self.db.addDBIndicator(match_dom, match) | ||||
| return vector | return vector | ||||
| @@ -691,3 +823,67 @@ class MultiWOZBPETextField(BPETextField): | |||||
| inputs['labels'] = [context] # use previous turn | inputs['labels'] = [context] # use previous turn | ||||
| return inputs, prompt_id | return inputs, prompt_id | ||||
| def restore(self, resp, domain, constraint_dict, mat_ents): | |||||
| restored = resp | |||||
| restored = restored.replace('[value_reference]', '53022') | |||||
| restored = restored.replace('[value_car]', 'BMW') | |||||
| for d in domain: | |||||
| constraint = constraint_dict.get(d, None) | |||||
| if constraint: | |||||
| replace_res_list = [('stay', '[value_stay]'), | |||||
| ('day', '[value_day]'), | |||||
| ('people', '[value_people]'), | |||||
| ('time', '[value_time]'), | |||||
| ('type', '[value_type]')] | |||||
| for key, value_key in replace_res_list: | |||||
| if key in constraint: | |||||
| restored = restored.replace(value_key, constraint[key]) | |||||
| if d in mat_ents and len(mat_ents[d]) == 0: | |||||
| for s in constraint: | |||||
| if s == 'pricerange' and d in [ | |||||
| 'hotel', 'restaurant' | |||||
| ] and 'price]' in restored: | |||||
| restored = restored.replace( | |||||
| '[value_price]', constraint['pricerange']) | |||||
| if s + ']' in restored: | |||||
| restored = restored.replace( | |||||
| '[value_%s]' % s, constraint[s]) | |||||
| if '[value_choice' in restored and mat_ents.get(d): | |||||
| restored = restored.replace('[value_choice]', | |||||
| str(len(mat_ents[d]))) | |||||
| if '[value_choice' in restored: | |||||
| restored = restored.replace('[value_choice]', '3') | |||||
| try: | |||||
| ent = mat_ents.get(domain[-1], []) | |||||
| if ent: | |||||
| ent = ent[0] | |||||
| for t in restored.split(): | |||||
| if '[value' in t: | |||||
| slot = t[7:-1] | |||||
| if ent.get(slot): | |||||
| if domain[-1] == 'hotel' and slot == 'price': | |||||
| slot = 'pricerange' | |||||
| restored = restored.replace(t, ent[slot]) | |||||
| elif slot == 'price': | |||||
| if ent.get('pricerange'): | |||||
| restored = restored.replace( | |||||
| t, ent['pricerange']) | |||||
| else: | |||||
| logger.info(restored, domain) | |||||
| except Exception: | |||||
| logger.error(resp) | |||||
| logger.error(restored) | |||||
| quit() | |||||
| restored = restored.replace('[value_phone]', '62781111') | |||||
| restored = restored.replace('[value_postcode]', 'CG9566') | |||||
| restored = restored.replace('[value_address]', 'Parkside, Cambridge') | |||||
| return restored | |||||
| @@ -0,0 +1,130 @@ | |||||
| import os | |||||
| import time | |||||
| from typing import Callable, Dict, Optional, Tuple, Union | |||||
| import numpy as np | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.models.nlp.space.model.generator import SpaceGenerator | |||||
| from modelscope.models.nlp.space.model.model_base import SpaceModelBase | |||||
| from modelscope.preprocessors.space.fields.gen_field import \ | |||||
| MultiWOZBPETextField | |||||
| from modelscope.trainers.base import BaseTrainer | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.trainers.nlp.space.eval import MultiWOZEvaluator | |||||
| from modelscope.trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer | |||||
| from modelscope.utils.config import Config, ModelFile | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| def setup_seed(seed: int): | |||||
| import random | |||||
| import torch | |||||
| torch.manual_seed(seed) | |||||
| torch.cuda.manual_seed_all(seed) | |||||
| np.random.seed(seed) | |||||
| random.seed(seed) | |||||
| torch.backends.cudnn.deterministic = True | |||||
| @TRAINERS.register_module(module_name=Trainers.dialog_modeling_trainer) | |||||
| class DialogModelingTrainer(BaseTrainer): | |||||
| def __init__(self, | |||||
| cfg_file: Optional[str] = None, | |||||
| cfg_modify_fn: Optional[Callable] = None, | |||||
| *args, | |||||
| **kwargs): | |||||
| super().__init__(os.path.join(kwargs['model_dir'], kwargs['cfg_name'])) | |||||
| self.cfg_modify_fn = cfg_modify_fn | |||||
| self.cfg = self.rebuild_config(self.cfg) | |||||
| setup_seed(self.cfg.Trainer.seed) | |||||
| # set reader and evaluator | |||||
| self.bpe = MultiWOZBPETextField(self.cfg, **kwargs) | |||||
| self.cfg.Model.num_token_embeddings = self.bpe.vocab_size | |||||
| self.cfg.Model.num_turn_embeddings = self.bpe.max_ctx_turn + 1 | |||||
| if 'work_dir' in kwargs: | |||||
| self.cfg.Trainer.save_dir = kwargs['work_dir'] | |||||
| else: | |||||
| self.cfg.Trainer.save_dir = './default_save_dir' | |||||
| # set data and data status | |||||
| self.train_data = self.bpe.get_batches('train') | |||||
| self.dev_data = self.bpe.get_batches('dev') | |||||
| self.evaluator = MultiWOZEvaluator(reader=self.bpe, **kwargs) | |||||
| # set generator | |||||
| self.generator = SpaceGenerator.create(self.cfg, reader=self.bpe) | |||||
| self._load_model(**kwargs) | |||||
| def _load_model(self, **kwargs): | |||||
| def to_tensor(array): | |||||
| """ | |||||
| numpy array -> tensor | |||||
| """ | |||||
| import torch | |||||
| array = torch.tensor(array) | |||||
| return array.cuda( | |||||
| ) if self.cfg.use_gpu and torch.cuda.is_available() else array | |||||
| # construct model | |||||
| if 'model' in kwargs: | |||||
| self.model = kwargs['model'] | |||||
| else: | |||||
| self.model = SpaceModelBase.create( | |||||
| kwargs['model_dir'], | |||||
| self.cfg, | |||||
| reader=self.bpe, | |||||
| generator=self.generator) | |||||
| import torch | |||||
| # multi-gpu | |||||
| if self.cfg.Trainer.gpu > 1 and torch.cuda.device_count() > 1: | |||||
| self.model = torch.nn.DataParallel(self.model) | |||||
| # construct trainer | |||||
| self.trainer = MultiWOZTrainer( | |||||
| self.model, | |||||
| to_tensor, | |||||
| self.cfg, | |||||
| reader=self.bpe, | |||||
| evaluator=self.evaluator) | |||||
| self.trainer.set_optimizers() | |||||
| # load model, optimizer and lr_scheduler | |||||
| self.trainer.load() | |||||
| def rebuild_config(self, cfg: Config): | |||||
| if self.cfg_modify_fn is not None: | |||||
| return self.cfg_modify_fn(cfg) | |||||
| return cfg | |||||
| def train(self, *args, **kwargs): | |||||
| logger.info('Train') | |||||
| self.trainer.train(train_data=self.train_data, dev_data=self.dev_data) | |||||
| def evaluate(self, | |||||
| checkpoint_path: Optional[str] = None, | |||||
| *args, | |||||
| **kwargs) -> Dict[str, float]: | |||||
| logger.info('Evaluate') | |||||
| self.cfg.do_infer = True | |||||
| # get best checkpoint path | |||||
| pos = checkpoint_path.rfind('/') | |||||
| checkpoint_name = checkpoint_path[pos + 1:] | |||||
| checkpoint_dir = checkpoint_path[:pos] | |||||
| assert checkpoint_name == ModelFile.TORCH_MODEL_BIN_FILE | |||||
| kwargs['model_dir'] = checkpoint_dir | |||||
| self._load_model(**kwargs) | |||||
| self.trainer.infer(data_type='test') | |||||
| @@ -0,0 +1,952 @@ | |||||
| # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||||
| # Copyright from https://github.com/thu-spmi/LABES | |||||
| # Copyright from https://github.com/TonyNemo/UBAR-MultiWOZ | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import math | |||||
| from collections import Counter | |||||
| import json | |||||
| import numpy as np | |||||
| from nltk.util import ngrams | |||||
| from sklearn.metrics import f1_score | |||||
| from modelscope.utils.nlp.space import ontology, utils | |||||
| from modelscope.utils.nlp.space.clean_dataset import clean_slot_values | |||||
| def similar(a, b): | |||||
| return a == b or a in b or b in a or a.split()[0] == b.split( | |||||
| )[0] or a.split()[-1] == b.split()[-1] | |||||
| def setsub(a, b): | |||||
| junks_a = [] | |||||
| useless_constraint = [ | |||||
| 'temperature', 'week', 'est ', 'quick', 'reminder', 'near' | |||||
| ] | |||||
| for i in a: | |||||
| flg = False | |||||
| for j in b: | |||||
| if similar(i, j): | |||||
| flg = True | |||||
| if not flg: | |||||
| junks_a.append(i) | |||||
| for junk in junks_a: | |||||
| flg = False | |||||
| for item in useless_constraint: | |||||
| if item in junk: | |||||
| flg = True | |||||
| if not flg: | |||||
| return False | |||||
| return True | |||||
| def setsim(a, b): | |||||
| a, b = set(a), set(b) | |||||
| return setsub(a, b) and setsub(b, a) | |||||
| def DA_evaluate(preds, labels): | |||||
| preds = np.array(preds) | |||||
| labels = np.array(labels) | |||||
| results = {} | |||||
| for avg_name in ['micro']: | |||||
| my_f1_score = f1_score(y_true=labels, y_pred=preds, average=avg_name) | |||||
| results['f1_{}'.format(avg_name)] = my_f1_score | |||||
| return results | |||||
| class BLEUScorer(object): | |||||
| # BLEU score calculator via GentScorer interface | |||||
| # it calculates the BLEU-4 by taking the entire corpus in | |||||
| # Calulate based multiple candidates against multiple references | |||||
| def __init__(self): | |||||
| pass | |||||
| def score(self, parallel_corpus): | |||||
| # containers | |||||
| count = [0, 0, 0, 0] | |||||
| clip_count = [0, 0, 0, 0] | |||||
| r = 0 | |||||
| c = 0 | |||||
| weights = [0.25, 0.25, 0.25, 0.25] | |||||
| # accumulate ngram statistics | |||||
| for hyps, refs in parallel_corpus: | |||||
| hyps = [hyp.split() for hyp in hyps] | |||||
| refs = [ref.split() for ref in refs] | |||||
| for hyp in hyps: | |||||
| for i in range(4): | |||||
| # accumulate ngram counts | |||||
| hypcnts = Counter(ngrams(hyp, i + 1)) | |||||
| cnt = sum(hypcnts.values()) | |||||
| count[i] += cnt | |||||
| # compute clipped counts | |||||
| max_counts = {} | |||||
| for ref in refs: | |||||
| refcnts = Counter(ngrams(ref, i + 1)) | |||||
| for ng in hypcnts: | |||||
| max_counts[ng] = max( | |||||
| max_counts.get(ng, 0), refcnts[ng]) | |||||
| clipcnt = \ | |||||
| dict((ng, min(count, max_counts[ng])) for ng, count in hypcnts.items()) | |||||
| clip_count[i] += sum(clipcnt.values()) | |||||
| # accumulate r & c | |||||
| bestmatch = [1000, 1000] | |||||
| for ref in refs: | |||||
| if bestmatch[0] == 0: | |||||
| break | |||||
| diff = abs(len(ref) - len(hyp)) | |||||
| if diff < bestmatch[0]: | |||||
| bestmatch[0] = diff | |||||
| bestmatch[1] = len(ref) | |||||
| r += bestmatch[1] | |||||
| c += len(hyp) | |||||
| # computing bleu score | |||||
| p0 = 1e-7 | |||||
| bp = \ | |||||
| 1 if c > r else math.exp(1 - float(r) / float(c)) | |||||
| p_ns = \ | |||||
| [float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)] | |||||
| s = \ | |||||
| math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n) | |||||
| bleu = bp * math.exp(s) | |||||
| return bleu * 100 | |||||
| """" | |||||
| For the data preparation and evaluation on MultiWOZ2.0/2.1, | |||||
| we refer to the code of UBAR (https://github.com/TonyNemo/UBAR-MultiWOZ) | |||||
| """ | |||||
| class MultiWOZEvaluator(object): | |||||
| def __init__(self, reader, **kwargs): | |||||
| self.reader = reader | |||||
| self.domains = ontology.all_domains | |||||
| self.all_data = self.reader.data | |||||
| self.test_data = self.reader.test | |||||
| self.bleu_scorer = BLEUScorer() | |||||
| self.all_info_slot = [] | |||||
| for d, s_list in ontology.informable_slots.items(): | |||||
| for s in s_list: | |||||
| self.all_info_slot.append(d + '-' + s) | |||||
| # only evaluate these slots for dialog success | |||||
| self.requestables = ['phone', 'address', 'postcode', 'reference', 'id'] | |||||
| self.db_dir = kwargs['data_dir'] | |||||
| def pack_dial(self, data): | |||||
| dials = {} | |||||
| for turn in data: | |||||
| dial_id = turn['dial_id'] | |||||
| if dial_id not in dials: | |||||
| dials[dial_id] = [] | |||||
| dials[dial_id].append(turn) | |||||
| return dials | |||||
| def validation_metric(self, data, fout=None): | |||||
| bleu = self.bleu_metric(data) | |||||
| # accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data) | |||||
| success, match, req_offer_counts, dial_num = \ | |||||
| self.context_to_response_eval(data, same_eval_as_cambridge=True, fout=fout) | |||||
| return bleu, success, match | |||||
| def bleu_metric(self, data, eval_dial_list=None): | |||||
| gen, truth = [], [] | |||||
| for row in data: | |||||
| if eval_dial_list and row[ | |||||
| 'dial_id'] + '.json' not in eval_dial_list: | |||||
| continue | |||||
| gen.append(row['resp_gen']) | |||||
| truth.append(row['resp']) | |||||
| wrap_generated = [[_] for _ in gen] | |||||
| wrap_truth = [[_] for _ in truth] | |||||
| if gen and truth: | |||||
| try: | |||||
| sc = self.bleu_scorer.score(zip(wrap_generated, wrap_truth)) | |||||
| except Exception: | |||||
| sc = 0.0 | |||||
| else: | |||||
| sc = 0.0 | |||||
| return sc | |||||
def context_to_response_eval(self,
                             data,
                             eval_dial_list=None,
                             same_eval_as_cambridge=False,
                             fout=None):
    """Compute dialog-level Inform (match) and Success rates.

    Args:
        data: flat list of decoded turn dicts.
        eval_dial_list: optional whitelist of dialog file names
            ('<dial_id>.json'); dialogs outside it are skipped.
        same_eval_as_cambridge: forwarded to the per-dialog evaluator;
            restricts requestable matching to the predicted domain.
        fout: optional open file; the per-dialog evaluator logs failed
            dialogs to it.

    Returns:
        (success_rate, match_rate, counts, dial_num) where the rates are
        percentages over the evaluated dialogs and ``counts`` holds
        per-requestable '<slot>_total' / '<slot>_offer' tallies.
    """
    dials = self.pack_dial(data)
    counts = {}
    for req in self.requestables:
        counts[req + '_total'] = 0
        counts[req + '_offer'] = 0
    dial_num, successes, matches = 0, 0, 0
    for dial_id in dials:
        if eval_dial_list and dial_id + '.json' not in eval_dial_list:
            continue
        dial = dials[dial_id]
        reqs = {}
        goal = {}
        # normalize the id to the '.json'-suffixed form when all_data is
        # keyed that way (checked against an arbitrary existing key)
        if '.json' not in dial_id and '.json' in list(
                self.all_data.keys())[0]:
            dial_id = dial_id + '.json'
        # accumulate the parsed user goal domain by domain
        for domain in ontology.all_domains:
            if self.all_data[dial_id]['goal'].get(domain):
                true_goal = self.all_data[dial_id]['goal']
                goal = self._parseGoal(goal, true_goal, domain)
        for domain in goal.keys():
            reqs[domain] = goal[domain]['requestable']
        success, match, stats, counts = \
            self._evaluateGeneratedDialogue(dial, goal, reqs, counts,
                                            same_eval_as_cambridge=same_eval_as_cambridge, fout=fout)
    successes += success
        matches += match
        dial_num += 1
    # 1e-10 guards against division by zero when no dialog was evaluated
    succ_rate = successes / (float(dial_num) + 1e-10) * 100
    match_rate = matches / (float(dial_num) + 1e-10) * 100
    return succ_rate, match_rate, counts, dial_num
def _evaluateGeneratedDialogue(self,
                               dialog,
                               goal,
                               real_requestables,
                               counts,
                               soft_acc=False,
                               same_eval_as_cambridge=False,
                               fout=None):
    """Evaluates the dialogue created by the model.

    First we load the user goal of the dialogue, then for each turn
    generated by the system we look for key-words.
    For the Inform rate we look whether the entity was proposed.
    For the Success rate we look for requestables slots.

    Returns (success, match, stats, counts); `stats` maps each domain to
    [match_stat, success_stat, seen_flag].  NOTE(review): when `fout` is
    given, success is computed for every dialog; otherwise it is only
    computed when match == 1.0 (original MultiWOZ convention).
    """
    # for computing corpus success
    requestables = self.requestables
    # CHECK IF MATCH HAPPENED
    provided_requestables = {}
    venue_offered = {}
    domains_in_goal = []
    log = []
    bspans = {}
    for domain in goal.keys():
        venue_offered[domain] = []
        provided_requestables[domain] = []
        domains_in_goal.append(domain)
    for t, turn in enumerate(dialog):
        if t == 0:
            # turn 0 is the dialog-level context, not a system response
            continue
        if fout is not None:
            # keep a per-turn trace so failed dialogs can be dumped later
            log.append({
                'turn_num': turn['turn_num'],
                'turn_domain': turn['dspn'],
                'user': turn['user'],
                'aspn': turn['aspn'],
                'aspn_gen': turn['aspn_gen'],
                'resp': turn['resp'],
                'resp_gen': turn['resp_gen'],
                'pointer': turn['pointer'],
            })
        sent_t = turn['resp_gen']
        for domain in goal.keys():
            # for computing success
            if same_eval_as_cambridge:
                # [restaurant_name], [hotel_name] instead of [value_name]
                if self.reader.use_true_domain_for_ctr_eval:
                    dom_pred = [d[1:-1] for d in turn['dspn'].split()]
                else:
                    dom_pred = [d[1:-1] for d in turn['dspn_gen'].split()]
                if domain not in dom_pred:  # fail
                    continue
            if '[value_name]' in sent_t or '[value_id]' in sent_t:
                if domain in [
                        'restaurant', 'hotel', 'attraction', 'train'
                ]:
                    # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
                    if not self.reader.use_true_curr_bspn and not self.reader.use_true_bspn_for_ctr_eval:
                        bspn = turn['bspn_gen']
                    else:
                        bspn = turn['bspn']
                    constraint_dict = self.reader.bspan_to_constraint_dict(
                        bspn)
                    if constraint_dict.get(domain):
                        venues = self.reader.db.queryJsons(
                            domain,
                            constraint_dict[domain],
                            return_name=True)
                    else:
                        venues = []
                    if len(venue_offered[domain]) == 0 and venues:
                        venue_offered[domain] = venues
                        bspans[domain] = constraint_dict[domain]
                    else:
                        # replace the offered venues only if the new query
                        # returned at least one venue not seen before
                        flag = False
                        for ven in venues:
                            if ven not in venue_offered[domain]:
                                flag = True
                                break
                        if flag and venues:  # sometimes there are no results so sample won't work
                            venue_offered[domain] = venues
                            bspans[domain] = constraint_dict[domain]
                else:  # not limited so we can provide one
                    venue_offered[domain] = '[value_name]'
            # ATTENTION: assumption here - we didn't provide phone or address twice! etc
            for requestable in requestables:
                if requestable == 'reference':
                    if '[value_reference]' in sent_t:
                        if domain in ['restaurant', 'hotel', 'train']:
                            if 'booked' in turn['pointer'] or 'ok' in turn[
                                    'pointer'] or '[value_reference]' in turn[
                                        'resp']:
                                # if pointer was allowing for that?
                                provided_requestables[domain].append(
                                    'reference')
                        else:
                            provided_requestables[domain].append(
                                'reference')
                else:
                    if '[value_' + requestable + ']' in sent_t:
                        provided_requestables[domain].append(requestable)
    # if name was given in the task
    for domain in goal.keys():
        # if name was provided for the user, the match is being done automatically
        if 'name' in goal[domain]['informable']:
            venue_offered[domain] = '[value_name]'
        # special domains - entity does not need to be provided
        if domain in ['taxi', 'police', 'hospital']:
            venue_offered[domain] = '[value_name]'
        if domain == 'train':
            if not venue_offered[domain] and 'id' not in goal[domain][
                    'requestable']:
                venue_offered[domain] = '[value_name]'
    """
    Given all inform and requestable slots
    we go through each domain from the user goal
    and check whether right entity was provided and
    all requestable slots were given to the user.
    The dialogue is successful if that's the case for all domains.
    """
    # HARD EVAL
    stats = {
        'restaurant': [0, 0, 0],
        'hotel': [0, 0, 0],
        'attraction': [0, 0, 0],
        'train': [0, 0, 0],
        'taxi': [0, 0, 0],
        'hospital': [0, 0, 0],
        'police': [0, 0, 0]
    }
    match = 0
    success = 0
    # MATCH
    for domain in goal.keys():
        match_stat = 0
        if domain in ['restaurant', 'hotel', 'attraction', 'train']:
            goal_venues = self.reader.db.queryJsons(
                domain, goal[domain]['informable'], return_name=True)
            # venue_offered[domain] is either the '[value_name]' marker
            # string or a list of venue names from the DB
            if type(venue_offered[domain]
                    ) is str and '_name' in venue_offered[domain]:
                match += 1
                match_stat = 1
            elif len(venue_offered[domain]) > 0 and len(
                    set(venue_offered[domain]) & set(goal_venues)) > 0:
                match += 1
                match_stat = 1
        else:
            if '_name]' in venue_offered[domain]:
                match += 1
                match_stat = 1
        stats[domain][0] = match_stat
        stats[domain][2] = 1
    if soft_acc:
        match = float(match) / len(goal.keys())
    else:
        # hard match: every goal domain must be matched
        if match == len(goal.keys()):
            match = 1.0
        else:
            match = 0.0
    # per-requestable corpus tallies (used for offer rates upstream)
    for domain in domains_in_goal:
        for request in real_requestables[domain]:
            counts[request + '_total'] += 1
            if request in provided_requestables[domain]:
                counts[request + '_offer'] += 1
    # SUCCESS
    if fout is not None:
        for domain in domains_in_goal:
            success_stat = 0
            domain_success = 0
            if len(real_requestables[domain]) == 0:
                success += 1
                success_stat = 1
                stats[domain][1] = success_stat
                continue
            # if values in sentences are super set of requestables
            for request in real_requestables[domain]:
                if request in provided_requestables[domain]:
                    domain_success += 1
            if domain_success == len(real_requestables[domain]):
                success += 1
                success_stat = 1
            stats[domain][1] = success_stat
        # final eval
        if soft_acc:
            success = float(success) / len(real_requestables)
        else:
            if success >= len(real_requestables):
                success = 1
            else:
                success = 0
    else:
        # without a log file, success is only evaluated for matched dialogs
        if match == 1.0:
            for domain in domains_in_goal:
                success_stat = 0
                domain_success = 0
                if len(real_requestables[domain]) == 0:
                    success += 1
                    success_stat = 1
                    stats[domain][1] = success_stat
                    continue
                # if values in sentences are super set of requestables
                for request in real_requestables[domain]:
                    if request in provided_requestables[domain]:
                        domain_success += 1
                if domain_success == len(real_requestables[domain]):
                    success += 1
                    success_stat = 1
                stats[domain][1] = success_stat
            # final eval
            if soft_acc:
                success = float(success) / len(real_requestables)
            else:
                if success >= len(real_requestables):
                    success = 1
                else:
                    success = 0
    # dump failed dialogs (with their turn trace) for inspection
    if fout is not None and success == 0:
        sample = {
            dialog[0]['dial_id']: {
                'log': log,
                'real_requestables': real_requestables,
                'provided_requestables': provided_requestables
            }
        }
        line = json.dumps(sample)
        fout.write(line)
        fout.write('\n')
    return success, match, stats, counts
| def _parseGoal(self, goal, true_goal, domain): | |||||
| """Parses user goal into dictionary format.""" | |||||
| goal[domain] = {} | |||||
| goal[domain] = {'informable': {}, 'requestable': [], 'booking': []} | |||||
| if 'info' in true_goal[domain]: | |||||
| if domain == 'train': | |||||
| # we consider dialogues only where train had to be booked! | |||||
| if 'book' in true_goal[domain]: | |||||
| goal[domain]['requestable'].append('reference') | |||||
| if 'reqt' in true_goal[domain]: | |||||
| if 'id' in true_goal[domain]['reqt']: | |||||
| goal[domain]['requestable'].append('id') | |||||
| else: | |||||
| if 'reqt' in true_goal[domain]: | |||||
| for s in true_goal[domain]['reqt']: # addtional requests: | |||||
| if s in [ | |||||
| 'phone', 'address', 'postcode', 'reference', | |||||
| 'id' | |||||
| ]: | |||||
| # ones that can be easily delexicalized | |||||
| goal[domain]['requestable'].append(s) | |||||
| if 'book' in true_goal[domain]: | |||||
| goal[domain]['requestable'].append('reference') | |||||
| for s, v in true_goal[domain]['info'].items(): | |||||
| s_, v_ = clean_slot_values(self.db_dir, domain, s, v) | |||||
| if len(v_.split()) > 1: | |||||
| v_ = ' '.join( | |||||
| [token.text for token in self.reader.nlp(v_)]).strip() | |||||
| goal[domain]['informable'][s_] = v_ | |||||
| if 'book' in true_goal[domain]: | |||||
| goal[domain]['booking'] = true_goal[domain]['book'] | |||||
| return goal | |||||
class GenericEvaluator:
    """Shared evaluation logic for the In-Car Assistant / CamRest corpora.

    Subclasses must provide `informable_slots`, `requestable_slots`,
    `entities_flat` and a `clean` method, and implement `run_metrics`.
    (Evaluation follows the LABES code for comparability.)
    """

    def __init__(self, reader):
        self.reader = reader
        self.metric_dict = {}

    def pack_dial(self, data):
        """Group flat turn records into an ordered {dial_id: [turns]} map."""
        dials = {}
        for turn in data:
            dials.setdefault(turn['dial_id'], []).append(turn)
        return dials

    def run_metrics(self, results):
        """Subclasses compute their corpus-specific metric suite here."""
        raise ValueError('Please specify the evaluator first')

    def bleu_metric(self, data, type='bleu'):
        """Corpus BLEU over cleaned generated vs. reference responses.

        `type` is unused but kept for caller compatibility.
        """
        gen, truth = [], []
        for row in data:
            gen.append(self.clean(row['resp_gen']))
            truth.append(self.clean(row['resp']))
        wrap_generated = [[_] for _ in gen]
        wrap_truth = [[_] for _ in truth]
        sc = BLEUScorer().score(zip(wrap_generated, wrap_truth))
        return sc

    def _normalize_constraint(self,
                              constraint,
                              ignore_dontcare=False,
                              intersection=True):
        """
        Normalize belief span, e.g. delete repeated words
        :param constraint - {'food': 'asian oriental', 'pricerange': 'cheap'}
        :param ignore_dontcare: drop slots whose value is 'dontcare'
        :param intersection: if true, only keeps the values that appear in the
            ontology; we set intersection=True as in previous works
        :returns: normalized constraint dict
            e.g. - {'food': 'asian oriental', 'pricerange': 'cheap', 'area': ''}
        """
        normalized = {s: '' for s in self.informable_slots}
        for s, v in constraint.items():
            if ignore_dontcare and v == 'dontcare':
                continue
            if intersection and v != 'dontcare' and v not in self.entities_flat:
                continue
            normalized[s] = v
        return normalized

    def _normalize_act(self, aspn, intersection=False):
        """Split a '|'-separated act span into word sets keyed by act_order."""
        aspn_list = aspn.split('|')
        normalized = {}
        for i, v in enumerate(aspn_list):
            seq = v.strip()
            word_set = set()
            for w in seq.split():
                if intersection:
                    if self.reader.act_order[i] == 'av':
                        # act values: keep only '[value_*]' placeholders
                        if '[value' in w:
                            word_set.add(w)
                    else:
                        # act slots: keep only known requestable slots
                        if w in self.requestable_slots:
                            word_set.add(w)
                else:
                    word_set.add(w)
            normalized[self.reader.act_order[i]] = word_set
        return normalized

    def tracker_metric(self, data, normalize=True):
        """Turn-level belief tracking metrics.

        Turns whose user utterance contains 'thank' or 'bye' are excluded.
        Returns (precision, recall, f1, joint_goal_accuracy,
        per_slot_accuracy, db_accuracy).
        """
        tp, fp, fn, db_correct = 0, 0, 0, 0
        goal_accr, slot_accr, total = 0, {}, 1e-8  # 1e-8 avoids div-by-zero
        for s in self.informable_slots:
            slot_accr[s] = 0
        for row in data:
            if normalize:
                gen = self._normalize_constraint(row['bspn_gen'])
                truth = self._normalize_constraint(row['bspn'])
            else:
                gen = self._normalize_constraint(
                    row['bspn_gen'], intersection=False)
                truth = self._normalize_constraint(
                    row['bspn'], intersection=False)
            valid = 'thank' not in row['user'] and 'bye' not in row['user']
            if valid:
                # NOTE(review): `in` does substring matching on slot values
                # (and '' in s is always True) — kept as in LABES for
                # comparability with prior results.
                for slot, value in gen.items():
                    if value in truth[slot]:
                        tp += 1
                    else:
                        fp += 1
                for slot, value in truth.items():
                    if value not in gen[slot]:
                        fn += 1
            if truth and valid:
                total += 1
                for s in self.informable_slots:
                    if gen[s] == truth[s]:
                        slot_accr[s] += 1
                if gen == truth:
                    goal_accr += 1
                if row.get('db_gen') and row.get('db_match'):
                    if row['db_gen'] == row['db_match']:
                        db_correct += 1
        precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        goal_accr /= total
        db_correct /= total
        for s in slot_accr:
            slot_accr[s] /= total
        return precision, recall, f1, goal_accr, slot_accr, db_correct

    def request_metric(self, data):
        """Dialog-level requestable-slot F1.

        Slot names are extracted from '[value_<slot>]' placeholders in
        the cleaned responses; '[value_name]' is excluded.
        """
        dials = self.pack_dial(data)
        tp, fp, fn = 0, 0, 0
        for dial_id in dials:
            truth_req, gen_req = set(), set()
            for turn in dials[dial_id]:
                resp_gen_token = self.clean(turn['resp_gen']).split()
                resp_token = self.clean(turn['resp']).split()
                for w in resp_gen_token:
                    if '[value_' in w and w.endswith(
                            ']') and w != '[value_name]':
                        gen_req.add(w[1:-1].split('_')[1])
                for w in resp_token:
                    if '[value_' in w and w.endswith(
                            ']') and w != '[value_name]':
                        truth_req.add(w[1:-1].split('_')[1])
            for req in gen_req:
                if req in truth_req:
                    tp += 1
                else:
                    fp += 1
            for req in truth_req:
                if req not in gen_req:
                    fn += 1
        precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        return f1, precision, recall

    def act_metric(self, data):
        """Turn-level act precision/recall/F1, overall and per slot/value.

        Bug fix: the per-key counters were guarded with `tp.get(key)`
        (same for fp/fn), which returns the initial 0 and is falsy, so
        per-slot and per-value counts could never be incremented and
        their F1 was always 0.  A membership test is the intended guard.
        """
        tp = {'all_s': 0, 'all_v': 0}
        fp = {'all_s': 0, 'all_v': 0}
        fn = {'all_s': 0, 'all_v': 0}
        for s in self.requestable_slots:
            tp[s], fp[s], fn[s] = 0, 0, 0
            value_key = '[value_%s]' % s
            tp[value_key], fp[value_key], fn[value_key] = 0, 0, 0
        for row in data:
            gen = self._normalize_act(row['aspn_gen'])
            truth = self._normalize_act(row['aspn'])
            valid = 'thank' not in row['user'] and 'bye' not in row['user']
            if not valid:
                continue
            # how well the act decoder captures user's requests
            for value in gen['av']:
                if value in truth['av']:
                    tp['all_v'] += 1
                    if value in tp:  # was tp.get(value): falsy at 0
                        tp[value] += 1
                else:
                    fp['all_v'] += 1
                    if value in fp:
                        fp[value] += 1
            for value in truth['av']:
                if value not in gen['av']:
                    fn['all_v'] += 1
                    if value in fn:
                        fn[value] += 1
            # how accurately the act decoder predicts system's question
            if 'as' not in gen:
                continue
            for slot in gen['as']:
                if slot in truth['as']:
                    tp['all_s'] += 1
                    if slot in tp:
                        tp[slot] += 1
                else:
                    fp['all_s'] += 1
                    if slot in fp:
                        fp[slot] += 1
            for slot in truth['as']:
                if slot not in gen['as']:
                    fn['all_s'] += 1
                    if slot in fn:
                        fn[slot] += 1
        result = {}
        for k in tp:
            precision = tp[k] / (tp[k] + fp[k] + 1e-8)
            recall = tp[k] / (tp[k] + fn[k] + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)
            result[k] = [f1, precision, recall]
        return result
| """ | |||||
| For the data preparation and evaluation on In-Car Assistant/CamRest, | |||||
| we refer to the code of LABES (https://github.com/thu-spmi/LABES) | |||||
| """ | |||||
class CamRestEvaluator(GenericEvaluator):
    """Evaluator for the CamRest676 corpus."""

    def __init__(self, reader):
        super().__init__(reader)
        self.entities_flat, self.entitiy_to_slot_dict = self.get_entities(
            self.reader.ontology_path)
        self.informable_slots = self.reader.otlg.informable_slots
        self.requestable_slots = self.reader.otlg.requestable_slots

    def run_metrics(self, results):
        """Compute BLEU, entity match, request F1 and tracker metrics."""
        bleu = self.bleu_metric(results)
        precision, recall, f1, goal_acc, slot_acc, db_acc = \
            self.tracker_metric(results)
        match = self.match_metric(results)
        req_f1, _, _ = self.request_metric(results)
        return {
            'bleu': bleu,
            'match': match,
            'req_f1': req_f1,
            'joint_goal': goal_acc,
            'slot_accu': slot_acc,
            'slot-p/r/f1': (precision, recall, f1),
            'db_acc': db_acc,
        }

    def get_entities(self, entity_path):
        """Load the ontology file; return (flat values, value->slot map)."""
        raw = json.loads(open(entity_path).read().lower())
        flat, value_to_slot = [], {}
        for slot, values in raw['informable'].items():
            flat.extend(values)
            for value in values:
                value_to_slot[value] = slot
        return flat, value_to_slot

    def constraint_same(self, truth_cons, gen_cons):
        """True when both constraint sets are empty or set-similar."""
        if not truth_cons and not gen_cons:
            return True
        if not truth_cons or not gen_cons:
            return False
        return setsim(gen_cons, truth_cons)

    def match_metric(self, data):
        """Dialog-level entity match rate."""
        dials = self.pack_dial(data)
        match, total = 0, 1e-8
        for dial in dials.values():
            truth_cons, gen_cons = {'1': '', '2': '', '3': ''}, None
            for turn in dial:
                # keep the state of the last turn that offered an entity
                if '[value' in turn['resp_gen']:
                    gen_cons = self._normalize_constraint(
                        turn['bspn_gen'], ignore_dontcare=True)
                if '[value' in turn['resp']:
                    truth_cons = self._normalize_constraint(
                        turn['bspn'], ignore_dontcare=True)
            if not gen_cons:
                # no entity offered: fall back to the final turn's state
                gen_cons = self._normalize_constraint(
                    dial[-1]['bspn_gen'], ignore_dontcare=True)
            if list(truth_cons.values()) != ['', '', '']:
                if gen_cons == truth_cons:
                    match += 1
                total += 1
        return match / total

    def clean(self, resp):
        # we use the same clean process as in Sequicity, SEDST, FSDM
        # to ensure comparable results
        sos, eos = self.reader.sos_r_token, self.reader.eos_r_token
        resp = resp.replace(f'{sos} ', '').replace(f' {eos}', '')
        resp = f'{sos} {resp} {eos}'
        for value, slot in self.entitiy_to_slot_dict.items():
            resp = utils.clean_replace(resp, value, '[value_%s]' % slot)
        return resp
class KvretEvaluator(GenericEvaluator):
    """Evaluator for the In-Car Assistant (KVRET) corpus."""

    def __init__(self, reader):
        super().__init__(reader)
        self.entities_flat, self.entitiy_to_slot_dict = self.get_entities(
            self.reader.ontology_path)
        self.informable_slots = self.reader.otlg.informable_slots
        self.requestable_slots = self.reader.otlg.requestable_slots

    def run_metrics(self, results):
        """Compute BLEU, entity match, request F1 and tracker metrics."""
        bleu = self.bleu_metric(results)
        precision, recall, f1, goal_acc, slot_acc, db_acc = \
            self.tracker_metric(results, normalize=True)
        match = self.match_metric(results)
        req_f1, _, _ = self.request_metric(results)
        return {
            'bleu': bleu,
            'match': match,
            'req_f1': req_f1,
            'joint_goal': goal_acc,
            'slot_accu': slot_acc,
            'slot-p/r/f1': (precision, recall, f1),
            'db_acc': db_acc,
        }

    def _normalize_constraint(self,
                              constraint,
                              ignore_dontcare=False,
                              intersection=True):
        """
        Normalize belief span, e.g. delete repeated words
        :param constraint - {'food': 'asian oriental', 'pricerange': 'cheap'}
        :param intersection: if true, only keeps the values that appear in
            the ontology; we set intersection=True as in previous works
        :returns: normalized constraint dict
            e.g. - {'food': 'asian oriental', 'pricerange': 'cheap', 'area': ''}
        """
        junk = [
            'good', 'great', 'quickest', 'shortest', 'route', 'week',
            'fastest', 'nearest', 'next', 'closest', 'way', 'mile',
            'activity', 'restaurant', 'appointment'
        ]
        normalized = {slot: '' for slot in self.informable_slots}
        for slot, value in constraint.items():
            for word in junk:
                value = ' '.join(value.replace(word, '').split())
            if intersection and value not in self.entities_flat:
                continue
            if slot in self.informable_slots:
                normalized[slot] = value
            # TODO only use the slot (not the domain) part of `slot`
            # for matching !!!
        return normalized

    def get_entities(self, entity_path):
        """Return (flat entity list, value->slot dict) from the reader."""
        value_to_slot = self.reader.entity_dict
        flat = []
        for value in value_to_slot:
            if value not in flat:
                flat.append(value)
        return flat, value_to_slot

    def constraint_same(self, truth_cons, gen_cons):
        """True when both constraint sets are empty or set-similar."""
        if not truth_cons and not gen_cons:
            return True
        if not truth_cons or not gen_cons:
            return False
        return setsim(gen_cons, truth_cons)

    def match_metric(self, data):
        """Dialog-level entity match rate."""
        dials = self.pack_dial(data)
        match, total = 0, 1e-8
        empty_state = {str(i): '' for i in range(1, 12)}
        for dial in dials.values():
            truth_cons, gen_cons = dict(empty_state), None
            for turn in dial:
                # keep the state of the last turn that offered an entity
                if '[value' in turn['resp_gen']:
                    gen_cons = self._normalize_constraint(
                        turn['bspn_gen'], ignore_dontcare=True)
                if '[value' in turn['resp']:
                    truth_cons = self._normalize_constraint(
                        turn['bspn'], ignore_dontcare=True)
            if not gen_cons:
                # no entity offered: fall back to the final turn's state
                gen_cons = self._normalize_constraint(
                    dial[-1]['bspn_gen'], ignore_dontcare=True)
            if list(truth_cons.values()) != [''] * 11:
                gen_values = [v for v in gen_cons.values() if v]
                truth_values = [v for v in truth_cons.values() if v]
                if self.constraint_same(gen_values, truth_values):
                    match += 1
                total += 1
        return match / total

    def clean(self, resp):
        # we use the same clean process as in Sequicity, SEDST, FSDM
        # to ensure comparable results
        sos, eos = self.reader.sos_r_token, self.reader.eos_r_token
        resp = resp.replace(f'{sos} ', '').replace(f' {eos}', '')
        resp = f'{sos} {resp} {eos}'
        for value, slot in self.entitiy_to_slot_dict.items():
            resp = utils.clean_replace(resp, value, '[value_%s]' % slot)
        return resp
| @@ -15,27 +15,11 @@ from transformers.optimization import AdamW, get_linear_schedule_with_warmup | |||||
| from modelscope.trainers.nlp.space.metrics.metrics_tracker import \ | from modelscope.trainers.nlp.space.metrics.metrics_tracker import \ | ||||
| MetricsTracker | MetricsTracker | ||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.nlp.space import ontology | from modelscope.utils.nlp.space import ontology | ||||
| def get_logger(log_path, name='default'): | |||||
| logger = logging.getLogger(name) | |||||
| logger.propagate = False | |||||
| logger.setLevel(logging.DEBUG) | |||||
| formatter = logging.Formatter('%(message)s') | |||||
| sh = logging.StreamHandler(sys.stdout) | |||||
| sh.setFormatter(formatter) | |||||
| logger.addHandler(sh) | |||||
| fh = logging.FileHandler(log_path, mode='w') | |||||
| fh.setFormatter(formatter) | |||||
| logger.addHandler(fh) | |||||
| return logger | |||||
| class Trainer(object): | class Trainer(object): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -51,15 +35,16 @@ class Trainer(object): | |||||
| self.do_train = config.do_train | self.do_train = config.do_train | ||||
| self.do_infer = config.do_infer | self.do_infer = config.do_infer | ||||
| self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ | |||||
| 0] == '-' | |||||
| self.valid_metric_name = config.Trainer.valid_metric_name[1:] | |||||
| self.num_epochs = config.Trainer.num_epochs | |||||
| # self.save_dir = config.Trainer.save_dir | |||||
| self.log_steps = config.Trainer.log_steps | |||||
| self.valid_steps = config.Trainer.valid_steps | |||||
| self.save_checkpoint = config.Trainer.save_checkpoint | |||||
| self.save_summary = config.Trainer.save_summary | |||||
| if self.do_train: | |||||
| self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ | |||||
| 0] == '-' | |||||
| self.valid_metric_name = config.Trainer.valid_metric_name[1:] | |||||
| self.num_epochs = config.Trainer.num_epochs | |||||
| self.save_dir = config.Trainer.save_dir | |||||
| self.log_steps = config.Trainer.log_steps | |||||
| self.valid_steps = config.Trainer.valid_steps | |||||
| self.save_checkpoint = config.Trainer.save_checkpoint | |||||
| self.save_summary = config.Trainer.save_summary | |||||
| self.lr = config.Model.lr | self.lr = config.Model.lr | ||||
| self.weight_decay = config.Model.weight_decay | self.weight_decay = config.Model.weight_decay | ||||
| self.batch_size = config.Trainer.batch_size | self.batch_size = config.Trainer.batch_size | ||||
| @@ -71,22 +56,21 @@ class Trainer(object): | |||||
| self.optimizer = optimizer | self.optimizer = optimizer | ||||
| self.model = model | self.model = model | ||||
| self.func_model = self.model.module if self.gpu > 1 else self.model | |||||
| self.func_model = self.model.module if self.gpu > 1 and config.use_gpu else self.model | |||||
| self.reader = reader | self.reader = reader | ||||
| self.evaluator = evaluator | self.evaluator = evaluator | ||||
| self.tokenizer = reader.tokenizer | self.tokenizer = reader.tokenizer | ||||
| # if not os.path.exists(self.save_dir): | |||||
| # os.makedirs(self.save_dir) | |||||
| # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") | |||||
| self.logger = logger or get_logger('trainer.log', 'trainer') | |||||
| self.logger = get_logger() | |||||
| self.batch_metrics_tracker = MetricsTracker() | self.batch_metrics_tracker = MetricsTracker() | ||||
| self.token_metrics_tracker = MetricsTracker() | self.token_metrics_tracker = MetricsTracker() | ||||
| self.best_valid_metric = float( | |||||
| 'inf' if self.is_decreased_valid_metric else '-inf') | |||||
| if self.do_train: | |||||
| if not os.path.exists(self.save_dir): | |||||
| os.makedirs(self.save_dir) | |||||
| self.best_valid_metric = float( | |||||
| 'inf' if self.is_decreased_valid_metric else '-inf') | |||||
| self.epoch = 0 | self.epoch = 0 | ||||
| def decode_generated_bspn_resp(self, generated): | def decode_generated_bspn_resp(self, generated): | ||||
| @@ -248,9 +232,12 @@ class Trainer(object): | |||||
| # Save current best model | # Save current best model | ||||
| if is_best: | if is_best: | ||||
| best_model_file = os.path.join(self.save_dir, 'best.model') | |||||
| best_model_file = os.path.join(self.save_dir, | |||||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||||
| torch.save(self.model.state_dict(), best_model_file) | torch.save(self.model.state_dict(), best_model_file) | ||||
| best_train_file = os.path.join(self.save_dir, 'best.train') | |||||
| best_train_file = os.path.join( | |||||
| self.save_dir, | |||||
| '{}.train'.format(ModelFile.TORCH_MODEL_BIN_FILE)) | |||||
| torch.save(train_state, best_train_file) | torch.save(train_state, best_train_file) | ||||
| self.logger.info( | self.logger.info( | ||||
| f"Saved best model state to '{best_model_file}' with new best valid metric " | f"Saved best model state to '{best_model_file}' with new best valid metric " | ||||
| @@ -324,8 +311,7 @@ class Trainer(object): | |||||
| self.func_model.load_state_dict(model_state_dict) | self.func_model.load_state_dict(model_state_dict) | ||||
| self.logger.info( | self.logger.info( | ||||
| f"Loaded model state from '{self.func_model.init_checkpoint}.model'" | |||||
| ) | |||||
| f"Loaded model state from '{self.func_model.init_checkpoint}'") | |||||
| def _load_train_state(): | def _load_train_state(): | ||||
| train_file = f'{self.func_model.init_checkpoint}.train' | train_file = f'{self.func_model.init_checkpoint}.train' | ||||
| @@ -558,19 +544,17 @@ class MultiWOZTrainer(Trainer): | |||||
| generated_bs = outputs[0].cpu().numpy().tolist() | generated_bs = outputs[0].cpu().numpy().tolist() | ||||
| bspn_gen = self.decode_generated_bspn(generated_bs) | bspn_gen = self.decode_generated_bspn(generated_bs) | ||||
| # check DB result | # check DB result | ||||
| if self.reader.use_true_db_pointer: # To control whether current db is ground truth | |||||
| if self.reader.use_true_db_pointer: | |||||
| db = turn['db'] | db = turn['db'] | ||||
| else: | else: | ||||
| db_result = self.reader.bspan_to_DBpointer( | db_result = self.reader.bspan_to_DBpointer( | ||||
| self.tokenizer.decode(bspn_gen), | self.tokenizer.decode(bspn_gen), | ||||
| turn['turn_domain']) | turn['turn_domain']) | ||||
| assert len(turn['db']) == 4 | |||||
| book_result = turn['db'][2] | |||||
| assert len(turn['db']) == 3 | |||||
| assert isinstance(db_result, str) | assert isinstance(db_result, str) | ||||
| db = \ | db = \ | ||||
| [self.reader.sos_db_id] + \ | [self.reader.sos_db_id] + \ | ||||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | self.tokenizer.convert_tokens_to_ids([db_result]) + \ | ||||
| [book_result] + \ | |||||
| [self.reader.eos_db_id] | [self.reader.eos_db_id] | ||||
| prompt_id = self.reader.sos_a_id | prompt_id = self.reader.sos_a_id | ||||
| @@ -636,7 +620,7 @@ class MultiWOZTrainer(Trainer): | |||||
| score = 0.5 * (success + match) + bleu | score = 0.5 * (success + match) + bleu | ||||
| # log results | # log results | ||||
| metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\ | |||||
| metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' % \ | |||||
| (match, success, bleu, score) | (match, success, bleu, score) | ||||
| message_prefix = f'[Infer][{self.epoch}]' | message_prefix = f'[Infer][{self.epoch}]' | ||||
| time_cost = f'TIME-{time.time() - begin_time:.3f}' | time_cost = f'TIME-{time.time() - begin_time:.3f}' | ||||
| @@ -0,0 +1,333 @@ | |||||
| import os | |||||
| import re | |||||
| from . import ontology | |||||
def clean_text_split_dot(text):
    """Split dots that are glued to words into standalone ' . ' tokens.

    Examples: 'abc.xyz' -> 'abc . xyz'; trailing 'abc. ' -> 'abc . '.
    """
    rules = (
        (r'([a-zT]+)\.([a-z])', r'\1 . \2'),  # dot sandwiched between letters
        (r'(\w+)\.\.? ', r'\1 . '),           # word followed by '.'/'..' and a space
    )
    for pattern, repl in rules:
        text = re.sub(pattern, repl, text)
    return text
def clean_text(data_dir, text):
    """Normalize a raw utterance string.

    Lower-cases the text, canonicalizes quotes and punctuation, normalizes
    clock times, applies a table of known typo/abbreviation fixes, splits
    word-glued dots, and finally applies whole-token replacements read from
    the tab-separated ``mapping.pair`` file.

    Args:
        data_dir: Directory containing the ``mapping.pair`` file
            (one ``<from>\\t<to>`` pair per line).
        text: Raw utterance text.

    Returns:
        The cleaned utterance string.
    """
    text = text.strip()
    text = text.lower()
    text = text.replace(u'’', "'")
    text = text.replace(u'‘', "'")
    text = text.replace(';', ',')
    text = text.replace('"', ' ')
    text = text.replace('/', ' and ')
    text = text.replace("don't", "do n't")
    text = clean_time(text)
    # Known bad spellings/spacings -> canonical forms; keys are regexes.
    baddata = {
        r'c\.b (\d), (\d) ([a-z])\.([a-z])': r'cb\1\2\3\4',
        'c.b. 1 7 d.y': 'cb17dy',
        'c.b.1 7 d.y': 'cb17dy',
        'c.b 25, 9 a.q': 'cb259aq',
        'isc.b 25, 9 a.q': 'is cb259aq',
        'c.b2, 1 u.f': 'cb21uf',
        'c.b 1,2 q.a': 'cb12qa',
        '0-122-336-5664': '01223365664',
        'postcodecb21rs': 'postcode cb21rs',
        r'i\.d': 'id',
        ' i d ': 'id',
        # NOTE(review): dead entry — the text is lower-cased above, so this
        # mixed-case pattern can never match; kept to preserve the table.
        'Telephone:01223358966': 'Telephone: 01223358966',
        'depature': 'departure',
        'depearting': 'departing',
        '-type': ' type',
        r'b[\s]?&[\s]?b': 'bed and breakfast',
        'b and b': 'bed and breakfast',
        r'guesthouse[s]?': 'guest house',
        r'swimmingpool[s]?': 'swimming pool',
        "wo n\'t": 'will not',
        " \'d ": ' would ',
        " \'m ": ' am ',
        " \'re' ": ' are ',
        " \'ll' ": ' will ',
        " \'ve ": ' have ',
        r'^\'': '',
        r'\'$': '',
    }
    for tmpl, good in baddata.items():
        text = re.sub(tmpl, good, text)
    # Split word-glued dots: 'abc.xyz' -> 'abc . xyz', 'abc. ' -> 'abc . '.
    text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2',
                  text)
    text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text)
    with open(os.path.join(data_dir, 'mapping.pair'), 'r') as fin:
        # Iterate the file handle lazily instead of materializing readlines().
        for line in fin:
            fromx, tox = line.replace('\n', '').split('\t')
            # Pad with spaces so only whole-token occurrences get replaced.
            text = ' ' + text + ' '
            text = text.replace(' ' + fromx + ' ', ' ' + tox + ' ')[1:-1]
    return text
def clean_time(utter):
    """Normalize clock times in an utterance toward zero-padded 24h 'HH:MM'.

    Applies an ordered sequence of regex rewrites; order matters (hours are
    zero-padded before the pm shift, so '3:30pm' -> '03:30pm' -> '15:30').
    """
    rules = (
        # '9 am' -> '9am': drop the space before am/pm.
        (r'(\d+) ([ap]\.?m)', lambda m: m.group(1) + m.group(2)),
        # '3:30[am]' -> '03:30': zero-pad a single-digit hour, dropping 'am'.
        (r'((?<!\d)\d:\d+)(am)?', r'0\1'),
        # '9am' -> '09:00'.
        (r'((?<!\d)\d)am', r'0\1:00'),
        # '9pm' -> '21:00': shift a bare hour to the 24h clock.
        (r'((?<!\d)\d)pm', lambda m: str(int(m.group(1)) + 12) + ':00'),
        # '03:30pm' -> '15:30': shift a padded hh:mm to the 24h clock.
        (r'(\d+)(:\d+)pm', lambda m: str(int(m.group(1)) + 12) + m.group(2)),
        # Strip any remaining 'am'/'a.m' suffix from a number.
        (r'(\d+)a\.?m', r'\1'),
    )
    for pattern, repl in rules:
        utter = re.sub(pattern, repl, utter)
    return utter
def clean_slot_values(data_dir, domain, slot, value):
    """Canonicalize a dialog-state ``(domain, slot, value)`` annotation.

    Cleans the raw value text, then applies per-domain/per-slot fixes for
    known annotation noise (typos, abbreviations, time formats). Slot
    aliases such as 'price range' / 'arrive by' / 'leave at' are normalized
    to their canonical names as a side effect.

    Args:
        data_dir: Directory holding 'mapping.pair' (used by ``clean_text``).
        domain: Lower-case domain name, e.g. 'hotel'.
        slot: Raw slot name.
        value: Raw slot value.

    Returns:
        Tuple ``(slot, value)`` of normalized slot name and value; an empty
        value string means the slot should be dropped.
    """
    value = clean_text(data_dir, value)
    if not value:
        value = ''
    elif value == 'not mentioned':
        value = ''
        # value = 'not mentioned'  # if in DST setting
    elif domain == 'attraction':
        if slot == 'name':
            if value == 't':
                value = ''
            if value == 'trinity':
                value = 'trinity college'
        elif slot == 'area':
            if value in ['town centre', 'cent', 'center', 'ce']:
                value = 'centre'
            elif value in [
                    'ely', 'in town', 'museum', 'norwich', 'same area as hotel'
            ]:
                value = ''
            elif value in ['we']:
                value = 'west'
        elif slot == 'type':
            if value in ['m', 'mus', 'musuem']:
                value = 'museum'
            elif value in ['art', 'architectural']:
                value = 'architecture'
            elif value in ['churches']:
                value = 'church'
            elif value in ['coll']:
                value = 'college'
            elif value in ['concert', 'concerthall']:
                value = 'concert hall'
            elif value in ['night club']:
                value = 'nightclub'
            elif value in [
                    'mutiple sports', 'mutliple sports', 'sports', 'galleria'
            ]:
                value = 'multiple sports'
            elif value in ['ol', 'science', 'gastropub', 'la raza']:
                value = ''
            elif value in ['swimmingpool', 'pool']:
                value = 'swimming pool'
            elif value in ['fun']:
                value = 'entertainment'
    elif domain == 'hotel':
        if slot == 'area':
            if value in [
                    'cen', 'centre of town', 'near city center', 'center'
            ]:
                value = 'centre'
            elif value in ['east area', 'east side']:
                value = 'east'
            elif value in ['in the north', 'north part of town']:
                value = 'north'
            elif value in ['we']:
                value = 'west'
        elif slot == 'day':
            if value == 'monda':
                value = 'monday'
            elif value == 't':
                value = 'tuesday'
        elif slot == 'name':
            if value == 'uni':
                value = 'university arms hotel'
            elif value == 'university arms':
                value = 'university arms hotel'
            elif value == 'acron':
                value = 'acorn guest house'
            elif value == 'ashley':
                value = 'ashley hotel'
            elif value == 'arbury lodge guesthouse':
                value = 'arbury lodge guest house'
            elif value == 'la':
                # NOTE(review): 'la margherit' looks truncated — confirm
                # against the hotel database before relying on it.
                value = 'la margherit'
            elif value == 'no':
                value = ''
        elif slot == 'internet':
            if value == 'does not':
                value = 'no'
            elif value in ['y', 'free', 'free internet']:
                value = 'yes'
            elif value in ['4']:
                value = ''
        elif slot == 'parking':
            if value == 'n':
                value = 'no'
            elif value in ['free parking']:
                value = 'yes'
            elif value in ['y']:
                value = 'yes'
        elif slot in ['pricerange', 'price range']:
            slot = 'pricerange'
            if value == 'moderately':
                value = 'moderate'
            elif value in ['any']:
                value = "do n't care"
            # (removed a duplicated, unreachable `elif value in ['any']`
            #  branch that repeated the assignment above)
            elif value in ['inexpensive']:
                value = 'cheap'
            elif value in ['2', '4']:
                value = ''
        elif slot == 'stars':
            if value == 'two':
                value = '2'
            elif value == 'three':
                value = '3'
            elif value in [
                    '4-star', '4 stars', '4 star', 'four star', 'four stars'
            ]:
                value = '4'
        elif slot == 'type':
            if value == '0 star rarting':
                value = ''
            elif value == 'guesthouse':
                value = 'guest house'
            elif value not in ['hotel', 'guest house', "do n't care"]:
                value = ''
    elif domain == 'restaurant':
        if slot == 'area':
            if value in [
                    'center', 'scentre', 'center of town', 'city center',
                    'cb30aq', 'town center', 'centre of cambridge',
                    'city centre'
            ]:
                value = 'centre'
            elif value == 'west part of town':
                value = 'west'
            elif value == 'n':
                value = 'north'
            elif value in ['the south']:
                value = 'south'
            elif value not in [
                    'centre', 'south', "do n't care", 'west', 'east', 'north'
            ]:
                value = ''
        elif slot == 'day':
            if value == 'monda':
                value = 'monday'
            elif value == 't':
                value = 'tuesday'
        elif slot in ['pricerange', 'price range']:
            slot = 'pricerange'
            if value in ['moderately', 'mode', 'mo']:
                value = 'moderate'
            elif value in ['not']:
                value = ''
            elif value in ['inexpensive', 'ch']:
                value = 'cheap'
        elif slot == 'food':
            if value == 'barbecue':
                value = 'barbeque'
        # (removed an unreachable `elif slot == 'pricerange'` branch: that
        #  slot is already matched by the combined pricerange branch above)
        elif slot == 'time':
            if value == '9:00':
                value = '09:00'
            elif value == '9:45':
                value = '09:45'
            elif value == '1330':
                value = '13:30'
            elif value == '1430':
                value = '14:30'
            elif value == '9:15':
                value = '09:15'
            elif value == '9:30':
                value = '09:30'
            elif value == '1830':
                value = '18:30'
            elif value == '9':
                value = '09:00'
            elif value == '2:00':
                value = '14:00'
            elif value == '1:00':
                value = '13:00'
            elif value == '3:00':
                value = '15:00'
    elif domain == 'taxi':
        if slot in ['arriveBy', 'arrive by']:
            slot = 'arriveby'
            if value == '1530':
                value = '15:30'
            elif value == '15 minutes':
                value = ''
        elif slot in ['leaveAt', 'leave at']:
            slot = 'leaveat'
            if value == '1:00':
                value = '01:00'
            elif value == '21:4':
                value = '21:04'
            elif value == '4:15':
                value = '04:15'
            elif value == '5:45':
                value = '05:45'
            elif value == '0700':
                value = '07:00'
            elif value == '4:45':
                value = '04:45'
            elif value == '8:30':
                value = '08:30'
            elif value == '9:30':
                value = '09:30'
            value = value.replace('.', ':')
    elif domain == 'train':
        if slot in ['arriveBy', 'arrive by']:
            slot = 'arriveby'
            if value == '1':
                value = '01:00'
            elif value in ['does not care', 'doesnt care', "doesn't care"]:
                value = "do n't care"
            elif value == '8:30':
                value = '08:30'
            elif value == 'not 15:45':
                value = ''
            value = value.replace('.', ':')
        elif slot == 'day':
            if value == 'doesnt care' or value == "doesn't care":
                value = "do n't care"
        elif slot in ['leaveAt', 'leave at']:
            slot = 'leaveat'
            if value == '2:30':
                value = '02:30'
            elif value == '7:54':
                value = '07:54'
            elif value == 'after 5:45 pm':
                value = '17:45'
            elif value in [
                    'early evening', 'friday', 'sunday', 'tuesday', 'afternoon'
            ]:
                value = ''
            elif value == '12':
                value = '12:00'
            elif value == '1030':
                value = '10:30'
            elif value == '1700':
                value = '17:00'
            elif value in [
                    'does not care', 'doesnt care', 'do nt care',
                    "doesn't care"
            ]:
                value = "do n't care"
            value = value.replace('.', ':')
    # Final catch-all normalization of "don't care" spellings.
    if value in ['dont care', "don't care", 'do nt care', "doesn't care"]:
        value = "do n't care"
    # Map slot aliases to canonical names via the ontology table.
    if ontology.normlize_slot_names.get(slot):
        slot = ontology.normlize_slot_names[slot]
    return slot, value
| @@ -4,8 +4,11 @@ from collections import OrderedDict | |||||
| import json | import json | ||||
| import numpy as np | import numpy as np | ||||
| from modelscope.utils.logger import get_logger | |||||
| from . import ontology | from . import ontology | ||||
| logger = get_logger() | |||||
| def max_lens(X): | def max_lens(X): | ||||
| lens = [len(X)] | lens = [len(X)] | ||||
| @@ -117,8 +120,8 @@ class MultiWOZVocab(object): | |||||
| def construct(self): | def construct(self): | ||||
| freq_dict_sorted = sorted( | freq_dict_sorted = sorted( | ||||
| self._freq_dict.keys(), key=lambda x: -self._freq_dict[x]) | self._freq_dict.keys(), key=lambda x: -self._freq_dict[x]) | ||||
| print('Vocabulary size including oov: %d' % | |||||
| (len(freq_dict_sorted) + len(self._idx2word))) | |||||
| logger.info('Vocabulary size including oov: %d' % | |||||
| (len(freq_dict_sorted) + len(self._idx2word))) | |||||
| if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size: | if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size: | ||||
| logging.warning( | logging.warning( | ||||
| 'actual label set smaller than that configured: {}/{}'.format( | 'actual label set smaller than that configured: {}/{}'.format( | ||||
| @@ -148,8 +151,9 @@ class MultiWOZVocab(object): | |||||
| for w, idx in self._word2idx.items(): | for w, idx in self._word2idx.items(): | ||||
| self._idx2word[idx] = w | self._idx2word[idx] = w | ||||
| self.vocab_size_oov = len(self._idx2word) | self.vocab_size_oov = len(self._idx2word) | ||||
| print('vocab file loaded from "' + vocab_path + '"') | |||||
| print('Vocabulary size including oov: %d' % (self.vocab_size_oov)) | |||||
| logger.info('vocab file loaded from "' + vocab_path + '"') | |||||
| logger.info('Vocabulary size including oov: %d' % | |||||
| (self.vocab_size_oov)) | |||||
| def save_vocab(self, vocab_path): | def save_vocab(self, vocab_path): | ||||
| _freq_dict = OrderedDict( | _freq_dict = OrderedDict( | ||||
| @@ -0,0 +1,68 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import unittest | |||||
| import torch | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Preprocessors, Trainers | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.constant import DownloadMode, ModelFile | |||||
| from modelscope.utils.test_utils import test_level | |||||
class TestDialogModelingTrainer(unittest.TestCase):
    """End-to-end finetuning test for the dialog-modeling trainer."""

    # Pretrained dialog model on the ModelScope hub.
    model_id = 'damo/nlp_space_pretrained-dialog-model'
    # Directory the trainer writes checkpoints into.
    # NOTE(review): 'fintune' looks like a typo for 'finetune' — kept as-is
    # since it is a runtime path; confirm before renaming.
    output_dir = './dialog_fintune_result'

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        """Finetune on MultiWOZ 2.0, then evaluate the saved checkpoint."""
        # Download the dataset and locate its 'data' directory.
        data_multiwoz = MsDataset.load(
            'MultiWoz2.0', download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
        data_dir = os.path.join(
            data_multiwoz._hf_ds.config_kwargs['split_config']['train'],
            'data')

        # Download the pretrained model snapshot.
        model_dir = snapshot_download(self.model_id)

        # Finetuning configuration override applied on top of the model's
        # 'gen_train_config.json'.
        def cfg_modify_fn(cfg):
            config = {
                'seed': 10,
                'gpu': 4,
                'use_data_distributed': False,
                'valid_metric_name': '-loss',
                'num_epochs': 60,
                'save_dir': self.output_dir,
                'token_loss': True,
                'batch_size': 32,
                'log_steps': 10,
                'valid_steps': 0,
                'save_checkpoint': True,
                'save_summary': False,
                'shuffle': True,
                'sort_pool_size': 0
            }
            cfg.Trainer = config
            cfg.use_gpu = torch.cuda.is_available() and config['gpu'] >= 1
            return cfg

        # Build the trainer and run training.
        kwargs = dict(
            model_dir=model_dir,
            cfg_name='gen_train_config.json',
            data_dir=data_dir,
            cfg_modify_fn=cfg_modify_fn)
        trainer = build_trainer(
            name=Trainers.dialog_modeling_trainer, default_args=kwargs)
        trainer.train()

        # Use a unittest assertion instead of a bare `assert` (which is
        # stripped under `python -O` and gives no failure message).
        checkpoint_path = os.path.join(self.output_dir,
                                       ModelFile.TORCH_MODEL_BIN_FILE)
        self.assertTrue(os.path.exists(checkpoint_path))
        trainer.evaluate(checkpoint_path=checkpoint_path)