| @@ -129,9 +129,9 @@ class GenUnifiedTransformer(UnifiedTransformer): | |||||
| enc_out = src_embed | enc_out = src_embed | ||||
| cache = {} | cache = {} | ||||
| for l, layer in enumerate(self.layers): | |||||
| cache[f'layer_{l}'] = {} | |||||
| enc_out = layer(enc_out, mask, cache[f'layer_{l}']) | |||||
| for _l, layer in enumerate(self.layers): | |||||
| cache[f'layer_{_l}'] = {} | |||||
| enc_out = layer(enc_out, mask, cache[f'layer_{_l}']) | |||||
| state['cache'] = cache | state['cache'] = cache | ||||
| state['mask'] = mask[:, :1] | state['mask'] = mask[:, :1] | ||||
| @@ -176,9 +176,9 @@ class GenUnifiedTransformer(UnifiedTransformer): | |||||
| mask = self._join_mask(enc_mask, dec_mask) | mask = self._join_mask(enc_mask, dec_mask) | ||||
| cache = {} | cache = {} | ||||
| for l, layer in enumerate(self.layers): | |||||
| cache[f'layer_{l}'] = {} | |||||
| enc_out = layer(enc_out, mask, cache[f'layer_{l}']) | |||||
| for _l, layer in enumerate(self.layers): | |||||
| cache[f'layer_{_l}'] = {} | |||||
| enc_out = layer(enc_out, mask, cache[f'layer_{_l}']) | |||||
| state['cache'] = cache | state['cache'] = cache | ||||
| state['mask'] = mask[:, -1:] # state["mask"] = mask[:, :1] | state['mask'] = mask[:, -1:] # state["mask"] = mask[:, :1] | ||||
| @@ -220,8 +220,8 @@ class GenUnifiedTransformer(UnifiedTransformer): | |||||
| mask = torch.cat([mask, 1 - pred_mask], dim=2) | mask = torch.cat([mask, 1 - pred_mask], dim=2) | ||||
| # shape: [batch_size, 1, hidden_dim] | # shape: [batch_size, 1, hidden_dim] | ||||
| for l, layer in enumerate(self.layers): | |||||
| pred_embed = layer(pred_embed, mask, cache[f'layer_{l}']) | |||||
| for _l, layer in enumerate(self.layers): | |||||
| pred_embed = layer(pred_embed, mask, cache[f'layer_{_l}']) | |||||
| # shape: [batch_size, vocab_size] | # shape: [batch_size, vocab_size] | ||||
| pred_probs = self._dec_head(dec_embed=pred_embed[:, 0]) | pred_probs = self._dec_head(dec_embed=pred_embed[:, 0]) | ||||
| @@ -101,11 +101,11 @@ class IntentUnifiedTransformer(UnifiedTransformer): | |||||
| if self.with_contrastive: | if self.with_contrastive: | ||||
| features = features if self.with_pool else self.pooler(features) | features = features if self.with_pool else self.pooler(features) | ||||
| batch_size = features.size(0) // 2 | batch_size = features.size(0) // 2 | ||||
| features = torch.cat([ | |||||
| features[:batch_size].unsqueeze(1), | |||||
| features[batch_size:].unsqueeze(1) | |||||
| ], | |||||
| dim=1) | |||||
| features = \ | |||||
| torch.cat( | |||||
| [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)], | |||||
| dim=1 | |||||
| ) | |||||
| features = F.normalize(features, dim=-1, p=2) | features = F.normalize(features, dim=-1, p=2) | ||||
| outputs['features'] = features | outputs['features'] = features | ||||
| @@ -202,11 +202,11 @@ class UnifiedTransformer(ModelBase): | |||||
| def _refactor_feature(self, features): | def _refactor_feature(self, features): | ||||
| features = self.pooler(features) if self.with_pool else features | features = self.pooler(features) if self.with_pool else features | ||||
| batch_size = features.size(0) // 2 | batch_size = features.size(0) // 2 | ||||
| features = torch.cat([ | |||||
| features[:batch_size].unsqueeze(1), | |||||
| features[batch_size:].unsqueeze(1) | |||||
| ], | |||||
| dim=1) | |||||
| features = \ | |||||
| torch.cat( | |||||
| [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)], | |||||
| dim=1 | |||||
| ) | |||||
| features = F.normalize(features, dim=-1, p=2) | features = F.normalize(features, dim=-1, p=2) | ||||
| return features | return features | ||||
| @@ -19,9 +19,9 @@ class MetricsTracker(object): | |||||
| if val is not None: | if val is not None: | ||||
| val = float(val) # [val] -> val | val = float(val) # [val] -> val | ||||
| self.metrics_val[key] = val | self.metrics_val[key] = val | ||||
| avg_val = (self.metrics_avg.get(key, 0) * self.num_samples + | |||||
| val * num_samples) / ( | |||||
| self.num_samples + num_samples) | |||||
| avg_val = \ | |||||
| (self.metrics_avg.get(key, 0) * self.num_samples + val * num_samples) / \ | |||||
| (self.num_samples + num_samples) | |||||
| self.metrics_avg[key] = avg_val | self.metrics_avg[key] = avg_val | ||||
| self.num_samples += num_samples | self.num_samples += num_samples | ||||
| @@ -117,15 +117,15 @@ class Trainer(object): | |||||
| decoded = {} | decoded = {} | ||||
| eos_a_id = self.reader.eos_a_id | eos_a_id = self.reader.eos_a_id | ||||
| eos_r_id = self.reader.eos_r_id | eos_r_id = self.reader.eos_r_id | ||||
| eos_b_id = self.reader.eos_b_id | |||||
| # eos_b_id = self.reader.eos_b_id | |||||
| # eos_r may not exists if gpt2 generated repetitive words. | # eos_r may not exists if gpt2 generated repetitive words. | ||||
| if eos_r_id in generated: | if eos_r_id in generated: | ||||
| eos_r_idx = generated.index(eos_r_id) | eos_r_idx = generated.index(eos_r_id) | ||||
| else: | else: | ||||
| eos_r_idx = len(generated) - 1 | eos_r_idx = len(generated) - 1 | ||||
| self.logger.info('eos_r not in generated: ' + | |||||
| self.tokenizer.decode(generated)) | |||||
| msg = 'eos_r not in generated: ' + self.tokenizer.decode(generated) | |||||
| self.logger.info(msg) | |||||
| if self.reader.use_true_curr_aspn: # only predict resp | if self.reader.use_true_curr_aspn: # only predict resp | ||||
| decoded['resp'] = generated[:eos_r_idx + 1] | decoded['resp'] = generated[:eos_r_idx + 1] | ||||
| @@ -173,8 +173,10 @@ class Trainer(object): | |||||
| ] | ] | ||||
| optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) | optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) | ||||
| num_training_steps = self.reader.set_stats['train']['num_training_steps_per_epoch'] * \ | |||||
| self.num_epochs // self.gradient_accumulation_steps | |||||
| num_training_steps = \ | |||||
| self.reader.set_stats['train']['num_training_steps_per_epoch'] \ | |||||
| * self.num_epochs \ | |||||
| // self.gradient_accumulation_steps | |||||
| num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( | num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( | ||||
| num_training_steps * 0.1) | num_training_steps * 0.1) | ||||
| lr_scheduler = get_linear_schedule_with_warmup( | lr_scheduler = get_linear_schedule_with_warmup( | ||||
| @@ -198,10 +200,10 @@ class Trainer(object): | |||||
| self.logger.info(' Batch size = %d', self.batch_size) | self.logger.info(' Batch size = %d', self.batch_size) | ||||
| self.logger.info(' Gradient Accumulation steps = %d', | self.logger.info(' Gradient Accumulation steps = %d', | ||||
| self.gradient_accumulation_steps) | self.gradient_accumulation_steps) | ||||
| self.logger.info( | |||||
| ' Total optimization steps = %d', | |||||
| set_stats['num_training_steps_per_epoch'] * self.num_epochs // | |||||
| self.gradient_accumulation_steps) | |||||
| steps = set_stats[ | |||||
| 'num_training_steps_per_epoch'] * self.num_epochs // self.gradient_accumulation_steps | |||||
| msg = ' Total optimization steps = %d' % steps | |||||
| self.logger.info(msg) | |||||
| # begin training | # begin training | ||||
| num_epochs = self.num_epochs - self.epoch | num_epochs = self.num_epochs - self.epoch | ||||
| @@ -346,10 +348,10 @@ class Trainer(object): | |||||
| f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | ||||
| f'best_valid_metric={self.best_valid_metric:.3f})') | f'best_valid_metric={self.best_valid_metric:.3f})') | ||||
| else: | else: | ||||
| self.logger.info(f'Loaded no train state') | |||||
| self.logger.info('Loaded no train state') | |||||
| if self.func_model.init_checkpoint is None: | if self.func_model.init_checkpoint is None: | ||||
| self.logger.info(f'Loaded no model !!!') | |||||
| self.logger.info('Loaded no model !!!') | |||||
| return | return | ||||
| if self.do_train: | if self.do_train: | ||||
| @@ -388,8 +390,9 @@ class MultiWOZTrainer(Trainer): | |||||
| self.epoch += 1 | self.epoch += 1 | ||||
| self.batch_metrics_tracker.clear() | self.batch_metrics_tracker.clear() | ||||
| self.token_metrics_tracker.clear() | self.token_metrics_tracker.clear() | ||||
| num_training_steps = self.reader.set_stats['train']['num_training_steps_per_epoch'] // \ | |||||
| self.gradient_accumulation_steps # similar to the original num_batches | |||||
| num_training_steps = \ | |||||
| self.reader.set_stats['train']['num_training_steps_per_epoch'] // \ | |||||
| self.gradient_accumulation_steps # similar to the original num_batches | |||||
| self.model.zero_grad() | self.model.zero_grad() | ||||
| data_iterator = self.reader.get_data_iterator(all_batches=train_data) | data_iterator = self.reader.get_data_iterator(all_batches=train_data) | ||||
| @@ -417,8 +420,9 @@ class MultiWOZTrainer(Trainer): | |||||
| metrics = {} | metrics = {} | ||||
| token_num = torch.sum(token_num) | token_num = torch.sum(token_num) | ||||
| token_nll = torch.sum(nll) * (batch_size / | |||||
| self.gpu) / token_num | |||||
| token_nll = \ | |||||
| torch.sum(nll) * (batch_size / self.gpu) / \ | |||||
| token_num | |||||
| nll = torch.mean(nll) | nll = torch.mean(nll) | ||||
| metrics['token_num'] = token_num | metrics['token_num'] = token_num | ||||
| metrics['token_nll'] = token_nll | metrics['token_nll'] = token_nll | ||||
| @@ -567,10 +571,11 @@ class MultiWOZTrainer(Trainer): | |||||
| assert len(turn['db']) == 4 | assert len(turn['db']) == 4 | ||||
| book_result = turn['db'][2] | book_result = turn['db'][2] | ||||
| assert isinstance(db_result, str) | assert isinstance(db_result, str) | ||||
| db = [self.reader.sos_db_id] + \ | |||||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||||
| [book_result] + \ | |||||
| [self.reader.eos_db_id] | |||||
| db = \ | |||||
| [self.reader.sos_db_id] + \ | |||||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||||
| [book_result] + \ | |||||
| [self.reader.eos_db_id] | |||||
| prompt_id = self.reader.sos_a_id | prompt_id = self.reader.sos_a_id | ||||
| prev_input = torch.tensor(bspn_gen + db) | prev_input = torch.tensor(bspn_gen + db) | ||||
| @@ -694,10 +699,11 @@ class MultiWOZTrainer(Trainer): | |||||
| self.tokenizer.decode(bspn_gen), ['[taxi]']) | self.tokenizer.decode(bspn_gen), ['[taxi]']) | ||||
| print(db_result) | print(db_result) | ||||
| book_result = 21 | book_result = 21 | ||||
| db = [self.reader.sos_db_id] + \ | |||||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||||
| [book_result] + \ | |||||
| [self.reader.eos_db_id] | |||||
| db = \ | |||||
| [self.reader.sos_db_id] + \ | |||||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||||
| [book_result] + \ | |||||
| [self.reader.eos_db_id] | |||||
| prompt_id = self.reader.sos_a_id | prompt_id = self.reader.sos_a_id | ||||
| prev_input = torch.tensor(bspn_gen + db) | prev_input = torch.tensor(bspn_gen + db) | ||||
| @@ -148,7 +148,7 @@ class Trainer(object): | |||||
| self.batch_size_nolabel) | self.batch_size_nolabel) | ||||
| self.logger.info(' Total optimization steps = %d', num_training_steps) | self.logger.info(' Total optimization steps = %d', num_training_steps) | ||||
| self.logger.info(' Total warmup steps = %d', num_warmup_steps) | self.logger.info(' Total warmup steps = %d', num_warmup_steps) | ||||
| self.logger.info(f'************************************') | |||||
| self.logger.info('************************************') | |||||
| def train(self, | def train(self, | ||||
| train_label_iter, | train_label_iter, | ||||
| @@ -298,10 +298,10 @@ class Trainer(object): | |||||
| f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | ||||
| f'best_valid_metric={self.best_valid_metric:.3f})') | f'best_valid_metric={self.best_valid_metric:.3f})') | ||||
| else: | else: | ||||
| self.logger.info(f'Loaded no train state') | |||||
| self.logger.info('Loaded no train state') | |||||
| if self.func_model.init_checkpoint is None: | if self.func_model.init_checkpoint is None: | ||||
| self.logger.info(f'Loaded no model !!!') | |||||
| self.logger.info('Loaded no model !!!') | |||||
| return | return | ||||
| _load_model_state() | _load_model_state() | ||||
| @@ -324,8 +324,8 @@ class IntentTrainer(Trainer): | |||||
| k = 3 | k = 3 | ||||
| y_pred_topk = np.sort(y_pred, axis=1)[:, -k:] | y_pred_topk = np.sort(y_pred, axis=1)[:, -k:] | ||||
| y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True) | y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True) | ||||
| y_pred_uncertainty = -(y_pred_topk * | |||||
| np.log(y_pred_topk)).sum(1) / np.log(k) | |||||
| y_pred_uncertainty =\ | |||||
| -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k) | |||||
| # 选择阈值,划分高、低置信度两部分 | # 选择阈值,划分高、低置信度两部分 | ||||
| # print(np.sort(y_pred_uncertainty)[-100:].tolist()) | # print(np.sort(y_pred_uncertainty)[-100:].tolist()) | ||||
| @@ -368,8 +368,9 @@ class IntentTrainer(Trainer): | |||||
| right += 1 | right += 1 | ||||
| # 输出修正后的准确率 | # 输出修正后的准确率 | ||||
| acc_final = (acc_confident * len(y_pred_confident) + | |||||
| right) / len(y_pred) | |||||
| acc_final = \ | |||||
| (acc_confident * len(y_pred_confident) + right) / \ | |||||
| len(y_pred) | |||||
| if len(y_pred_unconfident): | if len(y_pred_unconfident): | ||||
| message += ' new unconfident acc: %s' % ( | message += ' new unconfident acc: %s' % ( | ||||
| right / len(y_pred_unconfident)) | right / len(y_pred_unconfident)) | ||||
| @@ -508,7 +509,7 @@ class IntentTrainer(Trainer): | |||||
| report_for_unlabeled_data, cur_valid_metric=-accuracy) | report_for_unlabeled_data, cur_valid_metric=-accuracy) | ||||
| def forward(self, batch): | def forward(self, batch): | ||||
| pred, true = [], [] | |||||
| pred = [] | |||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| batch = type(batch)( | batch = type(batch)( | ||||
| @@ -808,10 +809,10 @@ class IntentTrainer(Trainer): | |||||
| f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | ||||
| f'best_valid_metric={self.best_valid_metric:.3f})') | f'best_valid_metric={self.best_valid_metric:.3f})') | ||||
| else: | else: | ||||
| self.logger.info(f'Loaded no train state') | |||||
| self.logger.info('Loaded no train state') | |||||
| if self.func_model.init_checkpoint is None: | if self.func_model.init_checkpoint is None: | ||||
| self.logger.info(f'Loaded no model !!!') | |||||
| self.logger.info('Loaded no model !!!') | |||||
| return | return | ||||
| if self.do_train: | if self.do_train: | ||||