| @@ -183,7 +183,7 @@ class BeamSearch(Generator): | |||||
| scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32') | scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32') | ||||
| scores_after_end[ | scores_after_end[ | ||||
| self.pad_id] = 0 # 希望<eos>之后只生成<pad>,故使词表中log(p(<pad>))最高(0) | |||||
| self.pad_id] = 0 # we want only <pad> to be generated after <eos>, so log(p(<pad>)) is set to the maximum (0) | |||||
| scores_after_end = torch.from_numpy(scores_after_end) | scores_after_end = torch.from_numpy(scores_after_end) | ||||
| if self.use_gpu: | if self.use_gpu: | ||||
| @@ -245,10 +245,8 @@ class BeamSearch(Generator): | |||||
| scores = scores.reshape(batch_size, beam_size * self.vocab_size) | scores = scores.reshape(batch_size, beam_size * self.vocab_size) | ||||
| topk_scores, topk_indices = torch.topk(scores, beam_size) | topk_scores, topk_indices = torch.topk(scores, beam_size) | ||||
| # topk_indices: [batch_size, beam_size * self.vocab_size] (已reshape) | |||||
| # 判断当前时间步产生词的前一个词在哪个beam中,对vocab_size取商 | |||||
| # topk_indices: [batch_size, beam_size], indices into the flattened beam_size * self.vocab_size axis (scores already reshaped) | |||||
| parent_idx = topk_indices.floor_divide(self.vocab_size) | parent_idx = topk_indices.floor_divide(self.vocab_size) | ||||
| # 对vocab_size取余 | |||||
| preds = topk_indices % self.vocab_size | preds = topk_indices % self.vocab_size | ||||
| # Gather state / sequence_scores | # Gather state / sequence_scores | ||||
| @@ -262,14 +260,14 @@ class BeamSearch(Generator): | |||||
| predictions = predictions.reshape(batch_size, beam_size, step) | predictions = predictions.reshape(batch_size, beam_size, step) | ||||
| predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) | predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) | ||||
| # 希望生成的整个句子已完结,所以要求最后一个token为<eos>或者<pad>(跟在<eos>之后),否则惩罚 | |||||
| # The last token should be <eos> or <pad> | |||||
| pre_ids = predictions[:, :, -1] | pre_ids = predictions[:, :, -1] | ||||
| pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ | pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ | ||||
| (1 - torch.not_equal(pre_ids, self.pad_id).float()) | (1 - torch.not_equal(pre_ids, self.pad_id).float()) | ||||
| sequence_scores = sequence_scores * pre_eos_mask + ( | sequence_scores = sequence_scores * pre_eos_mask + ( | ||||
| 1 - pre_eos_mask) * (-1e10) | 1 - pre_eos_mask) * (-1e10) | ||||
| # 先获得ascending排序的index,便于之后对predictions和sequence_scores排序(针对beam size轴) | |||||
| # first get the ascending-order indices, then use them to sort "predictions" and "sequence_scores" (along the beam-size axis) | |||||
| indices = torch.argsort(sequence_scores, dim=1) | indices = torch.argsort(sequence_scores, dim=1) | ||||
| indices = indices + pos_index | indices = indices + pos_index | ||||
| indices = indices.reshape(-1) | indices = indices.reshape(-1) | ||||
| @@ -122,11 +122,7 @@ class UnifiedTransformer(SpaceModelBase): | |||||
| auto_regressive=False): | auto_regressive=False): | ||||
| """ | """ | ||||
| Create attention mask. | Create attention mask. | ||||
| 创建从序列形式到矩阵形式的mask:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] | |||||
| mask除了要考虑attention mask(自回归),还需要考虑pad的mask(自回归和双向) | |||||
| 注: | |||||
| 1. 一个句子中的非<pad>词看整个句子,该句中只有<pad>词才被mask | |||||
| 2. 一个句子中的<pad>词看整个句子,该句的所有词都应该被mask | |||||
| Convert the mask from sequence form to matrix form: [batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] | |||||
| @param : input_mask | @param : input_mask | ||||
| @type : Variable(shape: [batch_size, max_seq_len]) | @type : Variable(shape: [batch_size, max_seq_len]) | ||||
| @@ -142,13 +138,11 @@ class UnifiedTransformer(SpaceModelBase): | |||||
| mask = mask1 * mask2 | mask = mask1 * mask2 | ||||
| if append_head: | if append_head: | ||||
| # 拼接上句首位置([M]/z)的mask | |||||
| mask = torch.cat([mask[:, :1, :], mask], dim=1) | mask = torch.cat([mask[:, :1, :], mask], dim=1) | ||||
| mask = torch.cat([mask[:, :, :1], mask], dim=2) | mask = torch.cat([mask[:, :, :1], mask], dim=2) | ||||
| seq_len += 1 | seq_len += 1 | ||||
| if auto_regressive: | if auto_regressive: | ||||
| # 将tgt端的<pad> mask和自回归attention mask融合 | |||||
| seq_mask = self.sequence_mask[:seq_len, :seq_len] | seq_mask = self.sequence_mask[:seq_len, :seq_len] | ||||
| seq_mask = seq_mask.to(mask.device) | seq_mask = seq_mask.to(mask.device) | ||||
| mask = mask * seq_mask | mask = mask * seq_mask | ||||
| @@ -159,7 +153,7 @@ class UnifiedTransformer(SpaceModelBase): | |||||
| def _join_mask(self, mask1, mask2): | def _join_mask(self, mask1, mask2): | ||||
| """ | """ | ||||
| Merge source attention mask and target attention mask. | Merge source attention mask and target attention mask. | ||||
| 合并后的整个mask矩阵可以分为四个部分:左上lu/右上ru/左下lb/右下rb | |||||
| The merged mask matrix has four parts: upper-left (lu) / upper-right (ru) / lower-left (lb) / lower-right (rb) | |||||
| @param : mask1 : source attention mask | @param : mask1 : source attention mask | ||||
| @type : Variable(shape: [batch_size, max_src_len, max_src_len]) | @type : Variable(shape: [batch_size, max_src_len, max_src_len]) | ||||
| @@ -570,7 +570,6 @@ class multiwoz22Processor(DSTProcessor): | |||||
| def delex_utt(self, utt, values, unk_token='[UNK]'): | def delex_utt(self, utt, values, unk_token='[UNK]'): | ||||
| utt_norm = self.tokenize(utt) | utt_norm = self.tokenize(utt) | ||||
| for s, vals in values.items(): | for s, vals in values.items(): | ||||
| # TODO vals可能不是数组形式,而是初始化的字符串"none" | |||||
| for v in vals: | for v in vals: | ||||
| if v != 'none': | if v != 'none': | ||||
| v_norm = self.tokenize(v) | v_norm = self.tokenize(v) | ||||
| @@ -36,18 +36,10 @@ class BPETextField(object): | |||||
| @property | @property | ||||
| def bot_id(self): | def bot_id(self): | ||||
| """ | |||||
| 用于区分user和bot两个角色 | |||||
| 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' | |||||
| """ | |||||
| return 0 | return 0 | ||||
| @property | @property | ||||
| def user_id(self): | def user_id(self): | ||||
| """ | |||||
| 用于区分user和bot两个角色 | |||||
| 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' | |||||
| """ | |||||
| return 1 | return 1 | ||||
| @property | @property | ||||
| @@ -186,7 +178,7 @@ class BPETextField(object): | |||||
| ] | ] | ||||
| src_role.append(list(chain(*role))[-self.max_len:]) | src_role.append(list(chain(*role))[-self.max_len:]) | ||||
| # src端序列和tgt端序列需要分开pad,以保证解码时第一个词对齐 | |||||
| # src and tgt sequences must be padded separately, so that the first word is aligned during decoding | |||||
| src_token = list2np(src_token, padding=self.pad_id) | src_token = list2np(src_token, padding=self.pad_id) | ||||
| src_pos = list2np(src_pos, padding=self.pad_id) | src_pos = list2np(src_pos, padding=self.pad_id) | ||||
| src_turn = list2np(src_turn, padding=self.pad_id) | src_turn = list2np(src_turn, padding=self.pad_id) | ||||
| @@ -439,7 +431,7 @@ class MultiWOZBPETextField(BPETextField): | |||||
| # logging.info(log_str) | # logging.info(log_str) | ||||
| # cfg.num_training_steps = num_training_steps * cfg.epoch_num | # cfg.num_training_steps = num_training_steps * cfg.epoch_num | ||||
| self.set_stats[set_name][ | self.set_stats[set_name][ | ||||
| 'num_training_steps_per_epoch'] = num_training_steps # turn-level的steps | |||||
| 'num_training_steps_per_epoch'] = num_training_steps # turn-level steps | |||||
| self.set_stats[set_name]['num_turns'] = num_turns | self.set_stats[set_name]['num_turns'] = num_turns | ||||
| self.set_stats[set_name]['num_dials'] = num_dials | self.set_stats[set_name]['num_dials'] = num_dials | ||||
| @@ -548,9 +540,6 @@ class MultiWOZBPETextField(BPETextField): | |||||
| def convert_batch_turn(self, turn_batch, pv_batch, first_turn=False): | def convert_batch_turn(self, turn_batch, pv_batch, first_turn=False): | ||||
| """ | """ | ||||
| URURU:这里的含义是指轮级别的训练(数据整理),区别于session级别的训练方式(convert_batch_session); | |||||
| 但不同于eval时的含义,eval时二者都是逐轮依次生成的,那时URURU的含义请见相关的函数注释; | |||||
| convert the current and the last turn | convert the current and the last turn | ||||
| concat [U_0,R_0,...,U_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t] | concat [U_0,R_0,...,U_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t] | ||||
| first turn: [U_t, B_t, A_t, R_t] | first turn: [U_t, B_t, A_t, R_t] | ||||
| @@ -154,18 +154,10 @@ class BPETextField(object): | |||||
| @property | @property | ||||
| def bot_id(self): | def bot_id(self): | ||||
| """ | |||||
| 用于区分user和bot两个角色 | |||||
| 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' | |||||
| """ | |||||
| return 0 | return 0 | ||||
| @property | @property | ||||
| def user_id(self): | def user_id(self): | ||||
| """ | |||||
| 用于区分user和bot两个角色 | |||||
| 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' | |||||
| """ | |||||
| return 1 | return 1 | ||||
| def add_sepcial_tokens(self): | def add_sepcial_tokens(self): | ||||
| @@ -862,7 +854,6 @@ class BPETextField(object): | |||||
| ] | ] | ||||
| src_role.append(list(chain(*role))[-self.max_len:]) | src_role.append(list(chain(*role))[-self.max_len:]) | ||||
| # src端序列和tgt端序列需要分开pad,以保证解码时第一个词对齐 | |||||
| src_token = list2np(src_token, padding=self.pad_id) | src_token = list2np(src_token, padding=self.pad_id) | ||||
| src_pos = list2np(src_pos, padding=self.pad_id) | src_pos = list2np(src_pos, padding=self.pad_id) | ||||
| src_turn = list2np(src_turn, padding=self.pad_id) | src_turn = list2np(src_turn, padding=self.pad_id) | ||||
| @@ -1038,7 +1029,6 @@ class IntentBPETextField(BPETextField): | |||||
| ] * l for i, l in enumerate(utt_lens)] | ] * l for i, l in enumerate(utt_lens)] | ||||
| src_role.append(list(chain(*role))[-self.max_len:]) | src_role.append(list(chain(*role))[-self.max_len:]) | ||||
| # src端序列和tgt端序列需要分开pad,以保证解码时第一个词对齐 | |||||
| src_token = list2np(src_token, padding=self.pad_id) | src_token = list2np(src_token, padding=self.pad_id) | ||||
| src_pos = list2np(src_pos, padding=self.pad_id) | src_pos = list2np(src_pos, padding=self.pad_id) | ||||
| src_turn = list2np(src_turn, padding=self.pad_id) | src_turn = list2np(src_turn, padding=self.pad_id) | ||||
| @@ -56,10 +56,6 @@ class Tokenizer(object): | |||||
| self._tokenizer = BertTokenizer( | self._tokenizer = BertTokenizer( | ||||
| vocab_path, never_split=self.special_tokens) | vocab_path, never_split=self.special_tokens) | ||||
| for tok in self.special_tokens: | for tok in self.special_tokens: | ||||
| ''' | |||||
| 需要先保证special_tokens在词表中,这里设置special_tokens的目的是为了这些词能够完整占位,不再切分为子词; | |||||
| 若不在词表中,可以使用词表中的[unused]符号进行转换:spec_convert_dict; | |||||
| ''' | |||||
| assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary" | assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary" | ||||
| self.vocab_size = len(self._tokenizer.vocab) | self.vocab_size = len(self._tokenizer.vocab) | ||||
| elif tokenizer_type == 'GPT2': | elif tokenizer_type == 'GPT2': | ||||
| @@ -10,8 +10,8 @@ class MetricsTracker(object): | |||||
| """ Tracking metrics. """ | """ Tracking metrics. """ | ||||
| def __init__(self): | def __init__(self): | ||||
| self.metrics_val = defaultdict(float) # 记录最新一个batch返回的指标 | |||||
| self.metrics_avg = defaultdict(float) # 维护一个epoch内已训练batches的平均指标 | |||||
| self.metrics_val = defaultdict(float) # metrics returned by the most recent batch | |||||
| self.metrics_avg = defaultdict(float) # running average of metrics over the batches trained so far in the epoch | |||||
| self.num_samples = 0 | self.num_samples = 0 | ||||
| def update(self, metrics, num_samples): | def update(self, metrics, num_samples): | ||||
| @@ -563,7 +563,7 @@ class MultiWOZTrainer(Trainer): | |||||
| generated_bs = outputs[0].cpu().numpy().tolist() | generated_bs = outputs[0].cpu().numpy().tolist() | ||||
| bspn_gen = self.decode_generated_bspn(generated_bs) | bspn_gen = self.decode_generated_bspn(generated_bs) | ||||
| # check DB result | # check DB result | ||||
| if self.reader.use_true_db_pointer: # 控制当前轮的db是否为ground truth | |||||
| if self.reader.use_true_db_pointer: # whether the db for the current turn is the ground truth | |||||
| db = turn['db'] | db = turn['db'] | ||||
| else: | else: | ||||
| db_result = self.reader.bspan_to_DBpointer( | db_result = self.reader.bspan_to_DBpointer( | ||||
| @@ -314,18 +314,18 @@ class IntentTrainer(Trainer): | |||||
| self.can_norm = config.Trainer.can_norm | self.can_norm = config.Trainer.can_norm | ||||
| def can_normalization(self, y_pred, y_true, ex_data_iter): | def can_normalization(self, y_pred, y_true, ex_data_iter): | ||||
| # 预测结果,计算修正前准确率 | |||||
| # compute the accuracy of the predictions before correction | |||||
| acc_original = np.mean([y_pred.argmax(1) == y_true]) | acc_original = np.mean([y_pred.argmax(1) == y_true]) | ||||
| message = 'original acc: %s' % acc_original | message = 'original acc: %s' % acc_original | ||||
| # 评价每个预测结果的不确定性 | |||||
| # estimate the uncertainty of each prediction | |||||
| k = 3 | k = 3 | ||||
| y_pred_topk = np.sort(y_pred, axis=1)[:, -k:] | y_pred_topk = np.sort(y_pred, axis=1)[:, -k:] | ||||
| y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True) | y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True) | ||||
| y_pred_uncertainty =\ | y_pred_uncertainty =\ | ||||
| -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k) | -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k) | ||||
| # 选择阈值,划分高、低置信度两部分 | |||||
| # choose a threshold to split the predictions into high- and low-confidence sets | |||||
| # print(np.sort(y_pred_uncertainty)[-100:].tolist()) | # print(np.sort(y_pred_uncertainty)[-100:].tolist()) | ||||
| threshold = 0.7 | threshold = 0.7 | ||||
| y_pred_confident = y_pred[y_pred_uncertainty < threshold] | y_pred_confident = y_pred[y_pred_uncertainty < threshold] | ||||
| @@ -333,8 +333,7 @@ class IntentTrainer(Trainer): | |||||
| y_true_confident = y_true[y_pred_uncertainty < threshold] | y_true_confident = y_true[y_pred_uncertainty < threshold] | ||||
| y_true_unconfident = y_true[y_pred_uncertainty >= threshold] | y_true_unconfident = y_true[y_pred_uncertainty >= threshold] | ||||
| # 显示两部分各自的准确率 | |||||
| # 一般而言,高置信度集准确率会远高于低置信度的 | |||||
| # compute ACC again for high and low confidence sets | |||||
| acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \ | acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \ | ||||
| if len(y_true_confident) else 0. | if len(y_true_confident) else 0. | ||||
| acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \ | acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \ | ||||
| @@ -344,7 +343,7 @@ class IntentTrainer(Trainer): | |||||
| message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident), | message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident), | ||||
| acc_unconfident) | acc_unconfident) | ||||
| # 从训练集统计先验分布 | |||||
| # get prior distribution from training set | |||||
| prior = np.zeros(self.func_model.num_intent) | prior = np.zeros(self.func_model.num_intent) | ||||
| for _, (batch, batch_size) in ex_data_iter: | for _, (batch, batch_size) in ex_data_iter: | ||||
| for intent_label in batch['intent_label']: | for intent_label in batch['intent_label']: | ||||
| @@ -352,7 +351,7 @@ class IntentTrainer(Trainer): | |||||
| prior /= prior.sum() | prior /= prior.sum() | ||||
| # 逐个修改低置信度样本,并重新评价准确率 | |||||
| # revise each sample from the low confidence set, and compute new ACC | |||||
| right, alpha, iters = 0, 1, 1 | right, alpha, iters = 0, 1, 1 | ||||
| for i, y in enumerate(y_pred_unconfident): | for i, y in enumerate(y_pred_unconfident): | ||||
| Y = np.concatenate([y_pred_confident, y[None]], axis=0) | Y = np.concatenate([y_pred_confident, y[None]], axis=0) | ||||
| @@ -365,7 +364,7 @@ class IntentTrainer(Trainer): | |||||
| if y.argmax() == y_true_unconfident[i]: | if y.argmax() == y_true_unconfident[i]: | ||||
| right += 1 | right += 1 | ||||
| # 输出修正后的准确率 | |||||
| # compute the final accuracy after correction | |||||
| acc_final = \ | acc_final = \ | ||||
| (acc_confident * len(y_pred_confident) + right) / \ | (acc_confident * len(y_pred_confident) + right) / \ | ||||
| len(y_pred) | len(y_pred) | ||||
| @@ -172,8 +172,8 @@ class MultiWozDB(object): | |||||
| continue | continue | ||||
| if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \ | if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \ | ||||
| (domain == 'restaurant' and s in ['day', 'time']): | (domain == 'restaurant' and s in ['day', 'time']): | ||||
| # 因为这些inform slot属于book info,而数据库中没有这些slot; | |||||
| # 能否book是根据user goal中的信息判断,而非通过数据库查询; | |||||
| # These inform slots belong to "book info", and the DB has no such slots; | |||||
| # whether booking is possible is decided from the user goal, not by querying the DB | |||||
| continue | continue | ||||
| skip_case = { | skip_case = { | ||||