@@ -129,9 +129,9 @@ class GenUnifiedTransformer(UnifiedTransformer):
 enc_out = src_embed
 cache = {}
-for l, layer in enumerate(self.layers):
-cache[f'layer_{l}'] = {}
-enc_out = layer(enc_out, mask, cache[f'layer_{l}'])
+for _l, layer in enumerate(self.layers):
+cache[f'layer_{_l}'] = {}
+enc_out = layer(enc_out, mask, cache[f'layer_{_l}'])
 state['cache'] = cache
 state['mask'] = mask[:, :1]
@@ -176,9 +176,9 @@ class GenUnifiedTransformer(UnifiedTransformer):
 mask = self._join_mask(enc_mask, dec_mask)
 cache = {}
-for l, layer in enumerate(self.layers):
-cache[f'layer_{l}'] = {}
-enc_out = layer(enc_out, mask, cache[f'layer_{l}'])
+for _l, layer in enumerate(self.layers):
+cache[f'layer_{_l}'] = {}
+enc_out = layer(enc_out, mask, cache[f'layer_{_l}'])
 state['cache'] = cache
 state['mask'] = mask[:, -1:] # state["mask"] = mask[:, :1]
@@ -220,8 +220,8 @@ class GenUnifiedTransformer(UnifiedTransformer):
 mask = torch.cat([mask, 1 - pred_mask], dim=2)
 # shape: [batch_size, 1, hidden_dim]
-for l, layer in enumerate(self.layers):
-pred_embed = layer(pred_embed, mask, cache[f'layer_{l}'])
+for _l, layer in enumerate(self.layers):
+pred_embed = layer(pred_embed, mask, cache[f'layer_{_l}'])
 # shape: [batch_size, vocab_size]
 pred_probs = self._dec_head(dec_embed=pred_embed[:, 0])
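The three hunks above all thread a per-layer cache dict through the transformer layers: it is filled during the context pass and reused at every incremental decoding step. A minimal sketch of that pattern, with hypothetical names (layers, src_embed, step_embed, the masks) standing in for the real module attributes:

# Illustrative sketch only; names are placeholders, not the repo's actual API.
cache = {f'layer_{i}': {} for i in range(len(layers))}

# Context pass: each layer stores its key/value states in its own sub-dict.
out = src_embed
for i, layer in enumerate(layers):
    out = layer(out, context_mask, cache[f'layer_{i}'])

# Incremental decoding: the same sub-dicts are reused (and grown) at each step.
for i, layer in enumerate(layers):
    step_embed = layer(step_embed, step_mask, cache[f'layer_{i}'])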
@@ -101,11 +101,11 @@ class IntentUnifiedTransformer(UnifiedTransformer):
 if self.with_contrastive:
 features = features if self.with_pool else self.pooler(features)
 batch_size = features.size(0) // 2
-features = torch.cat([
-features[:batch_size].unsqueeze(1),
-features[batch_size:].unsqueeze(1)
-],
-dim=1)
+features = \
+torch.cat(
+[features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)],
+dim=1
+)
 features = F.normalize(features, dim=-1, p=2)
 outputs['features'] = features
@@ -202,11 +202,11 @@ class UnifiedTransformer(ModelBase):
 def _refactor_feature(self, features):
 features = self.pooler(features) if self.with_pool else features
 batch_size = features.size(0) // 2
-features = torch.cat([
-features[:batch_size].unsqueeze(1),
-features[batch_size:].unsqueeze(1)
-],
-dim=1)
+features = \
+torch.cat(
+[features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)],
+dim=1
+)
 features = F.normalize(features, dim=-1, p=2)
 return features
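Both reformatted blocks above perform the same operation: splitting a doubled batch of pooled features into its two views, stacking them, and L2-normalizing, which is the layout a contrastive objective typically expects. A standalone sketch under that assumption (function name hypothetical):

import torch
import torch.nn.functional as F

def pair_and_normalize(features: torch.Tensor) -> torch.Tensor:
    # features: [2 * batch_size, hidden]; the first and second halves are
    # assumed to be two views of the same batch_size examples.
    batch_size = features.size(0) // 2
    paired = torch.cat(
        [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)],
        dim=1)                                  # -> [batch_size, 2, hidden]
    return F.normalize(paired, dim=-1, p=2)     # unit norm along hidden dim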
@@ -19,9 +19,9 @@ class MetricsTracker(object):
 if val is not None:
 val = float(val) # [val] -> val
 self.metrics_val[key] = val
-avg_val = (self.metrics_avg.get(key, 0) * self.num_samples +
-val * num_samples) / (
-self.num_samples + num_samples)
+avg_val = \
+(self.metrics_avg.get(key, 0) * self.num_samples + val * num_samples) / \
+(self.num_samples + num_samples)
 self.metrics_avg[key] = avg_val
 self.num_samples += num_samples
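The tracker update above is a sample-weighted running mean kept per metric; the same arithmetic as a tiny standalone helper (name hypothetical):

def update_running_average(avg_old: float, n_old: int,
                           batch_avg: float, n_batch: int) -> float:
    # Fold a batch's average value into the running average, weighting each
    # side by how many samples it summarizes.
    return (avg_old * n_old + batch_avg * n_batch) / (n_old + n_batch)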
@@ -117,15 +117,15 @@ class Trainer(object):
 decoded = {}
 eos_a_id = self.reader.eos_a_id
 eos_r_id = self.reader.eos_r_id
-eos_b_id = self.reader.eos_b_id
+# eos_b_id = self.reader.eos_b_id
 # eos_r may not exist if gpt2 generated repetitive words.
 if eos_r_id in generated:
 eos_r_idx = generated.index(eos_r_id)
 else:
 eos_r_idx = len(generated) - 1
-self.logger.info('eos_r not in generated: ' +
-self.tokenizer.decode(generated))
+msg = 'eos_r not in generated: ' + self.tokenizer.decode(generated)
+self.logger.info(msg)
 if self.reader.use_true_curr_aspn: # only predict resp
 decoded['resp'] = generated[:eos_r_idx + 1]
@@ -173,8 +173,10 @@ class Trainer(object):
 ]
 optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
-num_training_steps = self.reader.set_stats['train']['num_training_steps_per_epoch'] * \
-self.num_epochs // self.gradient_accumulation_steps
+num_training_steps = \
+self.reader.set_stats['train']['num_training_steps_per_epoch'] \
+* self.num_epochs \
+// self.gradient_accumulation_steps
 num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int(
 num_training_steps * 0.1)
 lr_scheduler = get_linear_schedule_with_warmup(
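For context, the setup in this hunk computes total optimization steps (per-epoch steps times epochs, divided by gradient accumulation) and defaults warmup to 10% of them. A hedged sketch of the same wiring with torch's AdamW and the transformers linear-warmup scheduler, using placeholder values rather than the repo's config:

import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# Placeholder numbers; the real values come from reader.set_stats and config.
steps_per_epoch, num_epochs, grad_accum, warmup_steps = 1000, 10, 4, -1

num_training_steps = steps_per_epoch * num_epochs // grad_accum
num_warmup_steps = warmup_steps if warmup_steps >= 0 else int(num_training_steps * 0.1)

model = torch.nn.Linear(8, 8)                 # stand-in for the real model
optimizer = AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps)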
@@ -198,10 +200,10 @@ class Trainer(object):
 self.logger.info(' Batch size = %d', self.batch_size)
 self.logger.info(' Gradient Accumulation steps = %d',
 self.gradient_accumulation_steps)
-self.logger.info(
-' Total optimization steps = %d',
-set_stats['num_training_steps_per_epoch'] * self.num_epochs //
-self.gradient_accumulation_steps)
+steps = set_stats[
+'num_training_steps_per_epoch'] * self.num_epochs // self.gradient_accumulation_steps
+msg = ' Total optimization steps = %d' % steps
+self.logger.info(msg)
 # begin training
 num_epochs = self.num_epochs - self.epoch
@@ -346,10 +348,10 @@ class Trainer(object):
 f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
 f'best_valid_metric={self.best_valid_metric:.3f})')
 else:
-self.logger.info(f'Loaded no train state')
+self.logger.info('Loaded no train state')
 if self.func_model.init_checkpoint is None:
-self.logger.info(f'Loaded no model !!!')
+self.logger.info('Loaded no model !!!')
 return
 if self.do_train:
@@ -388,8 +390,9 @@ class MultiWOZTrainer(Trainer):
 self.epoch += 1
 self.batch_metrics_tracker.clear()
 self.token_metrics_tracker.clear()
-num_training_steps = self.reader.set_stats['train']['num_training_steps_per_epoch'] // \
-self.gradient_accumulation_steps # similar to the original num_batches
+num_training_steps = \
+self.reader.set_stats['train']['num_training_steps_per_epoch'] // \
+self.gradient_accumulation_steps # similar to the original num_batches
 self.model.zero_grad()
 data_iterator = self.reader.get_data_iterator(all_batches=train_data)
@@ -417,8 +420,9 @@ class MultiWOZTrainer(Trainer):
 metrics = {}
 token_num = torch.sum(token_num)
-token_nll = torch.sum(nll) * (batch_size /
-self.gpu) / token_num
+token_nll = \
+torch.sum(nll) * (batch_size / self.gpu) / \
+token_num
 nll = torch.mean(nll)
 metrics['token_num'] = token_num
 metrics['token_nll'] = token_nll
@@ -567,10 +571,11 @@ class MultiWOZTrainer(Trainer):
 assert len(turn['db']) == 4
 book_result = turn['db'][2]
 assert isinstance(db_result, str)
-db = [self.reader.sos_db_id] + \
-self.tokenizer.convert_tokens_to_ids([db_result]) + \
-[book_result] + \
-[self.reader.eos_db_id]
+db = \
+[self.reader.sos_db_id] + \
+self.tokenizer.convert_tokens_to_ids([db_result]) + \
+[book_result] + \
+[self.reader.eos_db_id]
 prompt_id = self.reader.sos_a_id
 prev_input = torch.tensor(bspn_gen + db)
@@ -694,10 +699,11 @@ class MultiWOZTrainer(Trainer):
 self.tokenizer.decode(bspn_gen), ['[taxi]'])
 print(db_result)
 book_result = 21
-db = [self.reader.sos_db_id] + \
-self.tokenizer.convert_tokens_to_ids([db_result]) + \
-[book_result] + \
-[self.reader.eos_db_id]
+db = \
+[self.reader.sos_db_id] + \
+self.tokenizer.convert_tokens_to_ids([db_result]) + \
+[book_result] + \
+[self.reader.eos_db_id]
 prompt_id = self.reader.sos_a_id
 prev_input = torch.tensor(bspn_gen + db)
@@ -148,7 +148,7 @@ class Trainer(object):
 self.batch_size_nolabel)
 self.logger.info(' Total optimization steps = %d', num_training_steps)
 self.logger.info(' Total warmup steps = %d', num_warmup_steps)
-self.logger.info(f'************************************')
+self.logger.info('************************************')
 def train(self,
 train_label_iter,
@@ -298,10 +298,10 @@ class Trainer(object):
 f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
 f'best_valid_metric={self.best_valid_metric:.3f})')
 else:
-self.logger.info(f'Loaded no train state')
+self.logger.info('Loaded no train state')
 if self.func_model.init_checkpoint is None:
-self.logger.info(f'Loaded no model !!!')
+self.logger.info('Loaded no model !!!')
 return
 _load_model_state()
@@ -324,8 +324,8 @@ class IntentTrainer(Trainer):
 k = 3
 y_pred_topk = np.sort(y_pred, axis=1)[:, -k:]
 y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True)
-y_pred_uncertainty = -(y_pred_topk *
-np.log(y_pred_topk)).sum(1) / np.log(k)
+y_pred_uncertainty =\
+-(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k)
 # Pick a threshold to split predictions into high- and low-confidence parts
 # print(np.sort(y_pred_uncertainty)[-100:].tolist())
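The block above turns the top-k class probabilities into a normalized entropy, so values near 1 mean the model is nearly uniform over its top k guesses; a self-contained sketch of the same computation (function name hypothetical):

import numpy as np

def topk_entropy_uncertainty(y_pred: np.ndarray, k: int = 3) -> np.ndarray:
    # y_pred: [num_samples, num_classes] predicted probabilities.
    topk = np.sort(y_pred, axis=1)[:, -k:]          # k largest probs per row
    topk = topk / topk.sum(axis=1, keepdims=True)   # renormalize to sum to 1
    # Entropy of the top-k distribution, divided by log(k) to land in [0, 1].
    return -(topk * np.log(topk)).sum(axis=1) / np.log(k)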
@@ -368,8 +368,9 @@ class IntentTrainer(Trainer):
 right += 1
 # Report the corrected accuracy
-acc_final = (acc_confident * len(y_pred_confident) +
-right) / len(y_pred)
+acc_final = \
+(acc_confident * len(y_pred_confident) + right) / \
+len(y_pred)
 if len(y_pred_unconfident):
 message += ' new unconfident acc: %s' % (
 right / len(y_pred_unconfident))
@@ -508,7 +509,7 @@ class IntentTrainer(Trainer):
 report_for_unlabeled_data, cur_valid_metric=-accuracy)
 def forward(self, batch):
-pred, true = [], []
+pred = []
 with torch.no_grad():
 batch = type(batch)(
@@ -808,10 +809,10 @@ class IntentTrainer(Trainer):
 f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
 f'best_valid_metric={self.best_valid_metric:.3f})')
 else:
-self.logger.info(f'Loaded no train state')
+self.logger.info('Loaded no train state')
 if self.func_model.init_checkpoint is None:
-self.logger.info(f'Loaded no model !!!')
+self.logger.info('Loaded no model !!!')
 return
 if self.do_train: