|
- """
- Trainer class.
- """
-
- import logging
- import os
- import sys
- import time
- from collections import OrderedDict
-
- import json
- import numpy as np
- import torch
- from tqdm import tqdm
- from transformers.optimization import AdamW, get_linear_schedule_with_warmup
-
- from ..metrics.metrics_tracker import MetricsTracker
-
-
- def get_logger(log_path, name='default'):
- logger = logging.getLogger(name)
- logger.propagate = False
- logger.setLevel(logging.DEBUG)
-
- formatter = logging.Formatter('%(message)s')
-
- sh = logging.StreamHandler(sys.stdout)
- sh.setFormatter(formatter)
- logger.addHandler(sh)
-
- fh = logging.FileHandler(log_path, mode='w')
- fh.setFormatter(formatter)
- logger.addHandler(fh)
-
- return logger
-
-
- class Trainer(object):
-
- def __init__(self,
- model,
- to_tensor,
- config,
- reader=None,
- logger=None,
- lr_scheduler=None,
- optimizer=None):
- self.model = model
- self.to_tensor = to_tensor
- self.do_train = config.do_train
- self.do_infer = config.do_infer
-
- self.is_decreased_valid_metric = config.Trainer.valid_metric_name[
- 0] == '-'
- self.valid_metric_name = config.Trainer.valid_metric_name[1:]
- self.num_epochs = config.Trainer.num_epochs
- self.save_dir = config.Trainer.save_dir
- self.log_steps = config.Trainer.log_steps
- self.valid_steps = config.Trainer.valid_steps
- self.save_checkpoint = config.Trainer.save_checkpoint
- self.save_summary = config.Trainer.save_summary
- self.learning_method = config.Dataset.learning_method
- self.weight_decay = config.Model.weight_decay
- self.warmup_steps = config.Model.warmup_steps
- self.batch_size_label = config.Trainer.batch_size_label
- self.batch_size_nolabel = config.Trainer.batch_size_nolabel
- self.gpu = config.Trainer.gpu
- self.lr = config.Model.lr
-
- self.model = model
- self.func_model = self.model.module if self.gpu > 1 else self.model
- self.reader = reader
- self.tokenizer = reader.tokenizer
-
- self.lr_scheduler = lr_scheduler
- self.optimizer = optimizer
-
- # if not os.path.exists(self.save_dir):
- # os.makedirs(self.save_dir)
-
- # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer")
- self.logger = logger or get_logger('trainer.log', 'trainer')
-
- self.batch_metrics_tracker_label = MetricsTracker()
- self.token_metrics_tracker_label = MetricsTracker()
- self.batch_metrics_tracker_nolabel = MetricsTracker()
- self.token_metrics_tracker_nolabel = MetricsTracker()
-
- self.best_valid_metric = float(
- 'inf' if self.is_decreased_valid_metric else '-inf')
- self.epoch = 0
- self.batch_num = 0
-
- def set_optimizers(self, num_training_steps_per_epoch):
- """
- Setup the optimizer and the learning rate scheduler.
-
- from transformers.Trainer
-
- parameters from cfg: lr (1e-3); warmup_steps
- """
- # Prepare optimizer and schedule (linear warmup and decay)
- no_decay = ['bias', 'norm.weight']
- optimizer_grouped_parameters = [
- {
- 'params': [
- p for n, p in self.model.named_parameters()
- if not any(nd in n for nd in no_decay)
- ],
- 'weight_decay':
- self.weight_decay,
- },
- {
- 'params': [
- p for n, p in self.model.named_parameters()
- if any(nd in n for nd in no_decay)
- ],
- 'weight_decay':
- 0.0,
- },
- ]
- optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
-
- num_training_steps = num_training_steps_per_epoch * self.num_epochs
- num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int(
- num_training_steps * 0.1)
- lr_scheduler = get_linear_schedule_with_warmup(
- optimizer,
- num_warmup_steps=num_warmup_steps,
- num_training_steps=num_training_steps)
-
- # reset optimizer and lr_scheduler
- self.optimizer = optimizer
- self.lr_scheduler = lr_scheduler
-
- # log info
- self.logger.info(
- f'***** Running training: {self.learning_method} *****')
- self.logger.info(' Num Epochs = %d', self.num_epochs)
- self.logger.info(
- ' Num Training steps(one turn in a batch of dialogs) per epoch = %d',
- num_training_steps_per_epoch)
- self.logger.info(' Batch size for labeled data = %d',
- self.batch_size_label)
- self.logger.info(' Batch size for unlabeled data = %d',
- self.batch_size_nolabel)
- self.logger.info(' Total optimization steps = %d', num_training_steps)
- self.logger.info(' Total warmup steps = %d', num_warmup_steps)
- self.logger.info('************************************')
-
- def train(self,
- train_label_iter,
- train_nolabel_iter=None,
- valid_label_iter=None,
- valid_nolabel_iter=None):
- # begin training
- num_epochs = self.num_epochs - self.epoch
- for epoch in range(num_epochs):
- self.train_epoch(
- train_label_iter=train_label_iter,
- train_nolabel_iter=train_nolabel_iter,
- valid_label_iter=valid_label_iter,
- valid_nolabel_iter=valid_nolabel_iter)
-
- def train_epoch(self, train_label_iter, train_nolabel_iter,
- valid_label_iter, valid_nolabel_iter):
- """
- Train an epoch.
- """
- raise NotImplementedError
-
- def evaluate(self, data_label_iter, data_nolabel_iter, need_save=True):
- raise NotImplementedError
-
- def infer(self, data_iter, num_batches=None):
- raise NotImplementedError
-
- def save(self, is_best=False):
- """ save """
- train_state = {
- 'epoch': self.epoch,
- 'batch_num': self.batch_num,
- 'best_valid_metric': self.best_valid_metric,
- 'optimizer': self.optimizer.state_dict()
- }
- if self.lr_scheduler is not None:
- train_state['lr_scheduler'] = self.lr_scheduler.state_dict()
-
- # Save checkpoint
- if self.save_checkpoint:
- model_file = os.path.join(self.save_dir,
- f'state_epoch_{self.epoch}.model')
- torch.save(self.model.state_dict(), model_file)
- self.logger.info(f"Saved model state to '{model_file}'")
-
- train_file = os.path.join(self.save_dir,
- f'state_epoch_{self.epoch}.train')
- torch.save(train_state, train_file)
- self.logger.info(f"Saved train state to '{train_file}'")
-
- # Save current best model
- if is_best:
- best_model_file = os.path.join(self.save_dir, 'best.model')
- torch.save(self.model.state_dict(), best_model_file)
- best_train_file = os.path.join(self.save_dir, 'best.train')
- torch.save(train_state, best_train_file)
- self.logger.info(
- f"Saved best model state to '{best_model_file}' with new best valid metric "
- f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}'
- )
-
- def load(self):
- """ load """
-
- def _load_model_state():
- model_state_dict = torch.load(
- f'{self.func_model.init_checkpoint}.model',
- map_location=lambda storage, loc: storage)
-
- if 'module.' in list(model_state_dict.keys())[0]:
- new_model_state_dict = OrderedDict()
- for k, v in model_state_dict.items():
- assert k[:7] == 'module.'
- new_model_state_dict[k[7:]] = v
- model_state_dict = new_model_state_dict
-
- new_model_state_dict = OrderedDict()
- parameters = {
- name: param
- for name, param in self.func_model.named_parameters()
- }
- for name, param in model_state_dict.items():
- if name in parameters:
- if param.shape != parameters[name].shape:
- assert hasattr(param, 'numpy')
- arr = param.numpy()
- z = np.random.normal(
- scale=self.func_model.initializer_range,
- size=parameters[name].shape).astype('float32')
- if name == 'embedder.token_embedding.weight':
- z[-param.shape[0]:] = arr
- print(
- f'part of parameter({name}) random normlize initialize'
- )
- else:
- if z.shape[0] < param.shape[0]:
- z = arr[:z.shape[0]]
- print(f'part of parameter({name}) are dropped')
- else:
- z[:param.shape[0]] = arr
- print(
- f'part of parameter({name}) random normlize initialize'
- )
- dtype, device = param.dtype, param.device
- z = torch.tensor(z, dtype=dtype, device=device)
- new_model_state_dict[name] = z
- else:
- new_model_state_dict[name] = param
- else:
- print(f'parameter({name}) are dropped')
- model_state_dict = new_model_state_dict
-
- for name in parameters:
- if name not in model_state_dict:
- if parameters[name].requires_grad:
- print(f'parameter({name}) random normlize initialize')
- z = np.random.normal(
- scale=self.func_model.initializer_range,
- size=parameters[name].shape).astype('float32')
- dtype, device = parameters[name].dtype, parameters[
- name].device
- model_state_dict[name] = torch.tensor(
- z, dtype=dtype, device=device)
- else:
- model_state_dict[name] = parameters[name]
-
- self.func_model.load_state_dict(model_state_dict)
- self.logger.info(
- f"Loaded model state from '{self.func_model.init_checkpoint}.model'"
- )
-
- def _load_train_state():
- train_file = f'{self.func_model.init_checkpoint}.train'
- if os.path.exists(train_file):
- train_state_dict = torch.load(
- train_file, map_location=lambda storage, loc: storage)
- self.epoch = train_state_dict['epoch']
- self.best_valid_metric = train_state_dict['best_valid_metric']
- if self.optimizer is not None and 'optimizer' in train_state_dict:
- self.optimizer.load_state_dict(
- train_state_dict['optimizer'])
- if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict:
- self.lr_scheduler.load_state_dict(
- train_state_dict['lr_scheduler'])
- self.logger.info(
- f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
- f'best_valid_metric={self.best_valid_metric:.3f})')
- else:
- self.logger.info('Loaded no train state')
-
- if self.func_model.init_checkpoint is None:
- self.logger.info('Loaded no model !!!')
- return
-
- _load_model_state()
- _load_train_state()
-
-
- class IntentTrainer(Trainer):
-
- def __init__(self, model, to_tensor, config, reader=None):
- super(IntentTrainer, self).__init__(model, to_tensor, config, reader)
- self.example = config.Model.example
- self.can_norm = config.Trainer.can_norm
-
- def can_normalization(self, y_pred, y_true, ex_data_iter):
- # 预测结果,计算修正前准确率
- acc_original = np.mean([y_pred.argmax(1) == y_true])
- message = 'original acc: %s' % acc_original
-
- # 评价每个预测结果的不确定性
- k = 3
- y_pred_topk = np.sort(y_pred, axis=1)[:, -k:]
- y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True)
- y_pred_uncertainty =\
- -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k)
-
- # 选择阈值,划分高、低置信度两部分
- # print(np.sort(y_pred_uncertainty)[-100:].tolist())
- threshold = 0.7
- y_pred_confident = y_pred[y_pred_uncertainty < threshold]
- y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold]
- y_true_confident = y_true[y_pred_uncertainty < threshold]
- y_true_unconfident = y_true[y_pred_uncertainty >= threshold]
-
- # 显示两部分各自的准确率
- # 一般而言,高置信度集准确率会远高于低置信度的
- acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \
- if len(y_true_confident) else 0.
- acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \
- if len(y_true_unconfident) else 0.
- message += ' (%s) confident acc: %s' % (len(y_true_confident),
- acc_confident)
- message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident),
- acc_unconfident)
-
- # 从训练集统计先验分布
- prior = np.zeros(self.func_model.num_intent)
- for _, (batch, batch_size) in ex_data_iter:
- for intent_label in batch['intent_label']:
- prior[intent_label] += 1.
-
- prior /= prior.sum()
-
- # 逐个修改低置信度样本,并重新评价准确率
- right, alpha, iters = 0, 1, 1
- for i, y in enumerate(y_pred_unconfident):
- Y = np.concatenate([y_pred_confident, y[None]], axis=0)
- for j in range(iters):
- Y = Y**alpha
- Y /= Y.mean(axis=0, keepdims=True)
- Y *= prior[None]
- Y /= Y.sum(axis=1, keepdims=True)
- y = Y[-1]
- if y.argmax() == y_true_unconfident[i]:
- right += 1
-
- # 输出修正后的准确率
- acc_final = \
- (acc_confident * len(y_pred_confident) + right) / \
- len(y_pred)
- if len(y_pred_unconfident):
- message += ' new unconfident acc: %s' % (
- right / len(y_pred_unconfident))
- else:
- message += ' no unconfident predictions'
- message += ' final acc: %s' % acc_final
- return acc_original, acc_final, message
-
- def train_epoch(self, train_label_iter, train_nolabel_iter,
- valid_label_iter, valid_nolabel_iter):
- """
- Train an epoch.
- """
- times = []
- self.epoch += 1
- self.batch_metrics_tracker_label.clear()
- self.token_metrics_tracker_label.clear()
- self.batch_metrics_tracker_nolabel.clear()
- self.token_metrics_tracker_nolabel.clear()
-
- num_label_batches = len(train_label_iter)
- num_nolabel_batches = len(
- train_nolabel_iter) if train_nolabel_iter is not None else 0
- num_batches = max(num_label_batches, num_nolabel_batches)
-
- train_label_iter_loop = iter(train_label_iter)
- train_nolabel_iter_loop = iter(
- train_nolabel_iter) if train_nolabel_iter is not None else None
- report_for_unlabeled_data = True if train_nolabel_iter is not None else False
-
- for batch_id in range(1, num_batches + 1):
- # Do a training iteration
- start_time = time.time()
- batch_list, batch_size_list, with_label_list, loss_list, metrics_list = [], [], [], [], []
- data_file_list = []
-
- # collect batch for labeled data
- try:
- data_file_label, (
- batch_label,
- batch_size_label) = next(train_label_iter_loop)
- except StopIteration:
- train_label_iter_loop = iter(train_label_iter)
- data_file_label, (
- batch_label,
- batch_size_label) = next(train_label_iter_loop)
- batch_list.append(batch_label)
- batch_size_list.append(batch_size_label)
- with_label_list.append(True)
- data_file_list.append(data_file_label)
-
- # collect batch for unlabeled data
- if train_nolabel_iter is not None:
- try:
- data_file_nolabel, (
- batch_nolabel,
- batch_size_nolabel) = next(train_nolabel_iter_loop)
- except StopIteration:
- train_nolabel_iter_loop = iter(train_nolabel_iter)
- data_file_nolabel, (
- batch_nolabel,
- batch_size_nolabel) = next(train_nolabel_iter_loop)
- batch_list.append(batch_nolabel)
- batch_size_list.append(batch_size_nolabel)
- with_label_list.append(False)
- data_file_list.append(data_file_nolabel)
-
- # forward labeled batch and unlabeled batch and collect outputs, respectively
- for (batch, batch_size, with_label, data_file) in \
- zip(batch_list, batch_size_list, with_label_list, data_file_list):
- batch = type(batch)(
- map(lambda kv: (kv[0], self.to_tensor(kv[1])),
- batch.items()))
- if self.example and with_label:
- current_dataset = train_label_iter.data_file_to_dataset[
- data_file]
- example_batch = self.reader.retrieve_examples(
- dataset=current_dataset,
- labels=batch['intent_label'],
- inds=batch['ids'],
- task='intent')
- example_batch = type(example_batch)(
- map(lambda kv: (kv[0], self.to_tensor(kv[1])),
- example_batch.items()))
- for k, v in example_batch.items():
- batch[k] = v
- batch['epoch'] = self.epoch
- batch['num_steps'] = self.batch_num
- metrics = self.model(
- batch,
- is_training=True,
- with_label=with_label,
- data_file=data_file)
- loss, metrics = self.balance_metrics(
- metrics=metrics, batch_size=batch_size)
- loss_list.append(loss)
- metrics_list.append(metrics)
-
- # combine loss for labeled data and unlabeled data
- # TODO change the computation of combined loss of labeled batch and unlabeled batch
- loss = loss_list[0] if len(
- loss_list) == 1 else loss_list[0] + loss_list[1]
-
- # optimization procedure
- self.func_model._optimize(
- loss, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler)
- elapsed = time.time() - start_time
- times.append(elapsed)
- self.batch_num += 1
-
- # track metrics and log temporary message
- for (batch_size, metrics,
- with_label) in zip(batch_size_list, metrics_list,
- with_label_list):
- self.track_and_log_message(
- metrics=metrics,
- batch_id=batch_id,
- batch_size=batch_size,
- num_batches=num_batches,
- times=times,
- with_label=with_label)
-
- # evaluate
- if self.valid_steps > 0 and valid_label_iter is not None and valid_nolabel_iter is not None \
- and batch_id % self.valid_steps == 0:
- self.evaluate(
- data_label_iter=valid_label_iter,
- data_nolabel_iter=valid_nolabel_iter)
-
- # compute accuracy for valid dataset
- accuracy = self.infer(
- data_iter=valid_label_iter, ex_data_iter=train_label_iter)
-
- # report summary message and save checkpoints
- self.save_and_log_message(
- report_for_unlabeled_data, cur_valid_metric=-accuracy)
-
- def forward(self, batch):
- pred = []
-
- with torch.no_grad():
- batch = type(batch)(
- map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items()))
- result = self.model.infer(inputs=batch)
- result = {
- name: result[name].cpu().detach().numpy()
- for name in result
- }
- intent_probs = result['intent_probs']
- if self.can_norm:
- pred += [intent_probs]
- else:
- pred += np.argmax(intent_probs, axis=1).tolist()
-
- return pred
-
- def infer(self, data_iter, num_batches=None, ex_data_iter=None):
- """
- Inference interface.
- """
- self.logger.info('Generation starts ...')
- infer_save_file = os.path.join(self.save_dir,
- f'infer_{self.epoch}.result.json')
-
- # Inference
- batch_cnt = 0
- pred, true = [], []
- outputs, labels = [], []
- begin_time = time.time()
-
- with torch.no_grad():
- if self.example:
- for _, (batch, batch_size) in tqdm(
- ex_data_iter, desc='Building train memory.'):
- batch = type(batch)(
- map(lambda kv: (kv[0], self.to_tensor(kv[1])),
- batch.items()))
- result = self.model.infer(inputs=batch)
- result = {
- name: result[name].cpu().detach().numpy()
- for name in result
- }
- outputs.append(torch.from_numpy(result['features']))
- labels += batch['intent_label'].tolist()
-
- mem = torch.cat(outputs, dim=0)
- mem = mem.cuda() if self.func_model.use_gpu else mem
- labels = torch.LongTensor(labels).unsqueeze(0)
- labels = labels.cuda() if self.func_model.use_gpu else labels
- self.logger.info(f'Memory size: {mem.size()}')
-
- for _, (batch, batch_size) in tqdm(data_iter, total=num_batches):
- batch = type(batch)(
- map(lambda kv: (kv[0], self.to_tensor(kv[1])),
- batch.items()))
- result = self.model.infer(inputs=batch)
- result = {
- name: result[name].cpu().detach().numpy()
- for name in result
- }
-
- if self.example:
- features = torch.from_numpy(result['features'])
- features = features.cuda(
- ) if self.func_model.use_gpu else features
- probs = torch.softmax(features.mm(mem.t()), dim=-1)
- intent_probs = torch.zeros(
- probs.size(0), self.func_model.num_intent)
- intent_probs = intent_probs.cuda(
- ) if self.func_model.use_gpu else intent_probs
- intent_probs = intent_probs.scatter_add(
- -1, labels.repeat(probs.size(0), 1), probs)
- intent_probs = intent_probs.cpu().detach().numpy()
- else:
- intent_probs = result['intent_probs']
-
- if self.can_norm:
- pred += [intent_probs]
- true += batch['intent_label'].cpu().detach().tolist()
- else:
- pred += np.argmax(intent_probs, axis=1).tolist()
- true += batch['intent_label'].cpu().detach().tolist()
-
- batch_cnt += 1
- if batch_cnt == num_batches:
- break
-
- if self.can_norm:
- true = np.array(true)
- pred = np.concatenate(pred, axis=0)
- acc_original, acc_final, message = self.can_normalization(
- y_pred=pred, y_true=true, ex_data_iter=ex_data_iter)
- accuracy = max(acc_original, acc_final)
- infer_results = {
- 'accuracy': accuracy,
- 'pred_labels': pred.tolist(),
- 'message': message
- }
- metrics_message = f'Accuracy: {accuracy} {message}'
- else:
- accuracy = sum(p == t for p, t in zip(pred, true)) / len(pred)
- infer_results = {'accuracy': accuracy, 'pred_labels': pred}
- metrics_message = f'Accuracy: {accuracy}'
-
- self.logger.info(f'Saved inference results to {infer_save_file}')
- with open(infer_save_file, 'w') as fp:
- json.dump(infer_results, fp, indent=2)
- message_prefix = f'[Infer][{self.epoch}]'
- time_cost = f'TIME-{time.time() - begin_time:.3f}'
- message = ' '.join([message_prefix, metrics_message, time_cost])
- self.logger.info(message)
- return accuracy
-
- def track_and_log_message(self, metrics, batch_id, batch_size, num_batches,
- times, with_label):
- # track metrics
- batch_metrics_tracker = self.batch_metrics_tracker_label if with_label else self.batch_metrics_tracker_nolabel
- token_metrics_tracker = self.token_metrics_tracker_label if with_label else self.token_metrics_tracker_nolabel
-
- metrics = {
- k: v.cpu().detach().numpy() if isinstance(v, torch.Tensor) else v
- for k, v in metrics.items()
- }
- mlm_num = metrics.pop('mlm_num', 0)
-
- batch_metrics = {k: v for k, v in metrics.items() if 'token' not in k}
- token_metrics = {k: v for k, v in metrics.items() if 'token' in k}
- batch_metrics_tracker.update(batch_metrics, batch_size)
- token_metrics_tracker.update(token_metrics, mlm_num)
-
- # log message
- if self.log_steps > 0 and batch_id % self.log_steps == 0:
- batch_metrics_message = batch_metrics_tracker.value()
- token_metrics_message = token_metrics_tracker.value()
- label_prefix = 'Labeled' if with_label else 'Unlabeled'
- message_prefix = f'[Train][{self.epoch}][{batch_id}/{num_batches}][{label_prefix}]'
- avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}'
- message = ' '.join([
- message_prefix, batch_metrics_message, token_metrics_message,
- avg_time
- ])
- self.logger.info(message)
-
- def save_and_log_message(self,
- report_for_unlabeled_data,
- cur_valid_metric=None):
- # report message
- batch_metrics_message = self.batch_metrics_tracker_label.summary()
- token_metrics_message = self.token_metrics_tracker_label.summary()
- message_prefix = f'[Valid][{self.epoch}][Labeled]'
- message = ' '.join(
- [message_prefix, batch_metrics_message, token_metrics_message])
- self.logger.info(message)
- if report_for_unlabeled_data:
- batch_metrics_message = self.batch_metrics_tracker_nolabel.summary(
- )
- token_metrics_message = self.token_metrics_tracker_nolabel.summary(
- )
- message_prefix = f'[Valid][{self.epoch}][Unlabeled]'
- message = ' '.join(
- [message_prefix, batch_metrics_message, token_metrics_message])
- self.logger.info(message)
-
- # save checkpoints
- assert cur_valid_metric is not None
- if self.is_decreased_valid_metric:
- is_best = cur_valid_metric < self.best_valid_metric
- else:
- is_best = cur_valid_metric > self.best_valid_metric
- if is_best:
- self.best_valid_metric = cur_valid_metric
- self.save(is_best)
-
- def balance_metrics(self, metrics, batch_size):
- if self.gpu > 1:
- for metric in metrics:
- if metric is not None:
- assert len(metric) == self.gpu
-
- intent_loss, mlm, token_mlm, mlm_num, kl, con = metrics
- metrics = {}
-
- intent_loss = torch.mean(intent_loss)
- metrics['intent_loss'] = intent_loss
- loss = intent_loss
-
- if mlm is not None:
- mlm_num = torch.sum(mlm_num)
- token_mlm = torch.sum(mlm) * (batch_size / self.gpu) / mlm_num
- mlm = torch.mean(mlm)
- metrics['mlm_num'] = mlm_num
- metrics['token_mlm'] = token_mlm
- metrics['mlm'] = mlm
- loss = loss + (token_mlm if self.func_model.token_loss else
- mlm) * self.func_model.mlm_ratio
-
- if kl is not None:
- kl = torch.mean(kl)
- metrics['kl'] = kl
- loss = loss + kl * self.func_model.kl_ratio
-
- if con is not None:
- con = torch.mean(con)
- metrics['con'] = con
- loss = loss + con
-
- metrics['loss'] = loss
-
- assert 'loss' in metrics
- return metrics['loss'], metrics
-
- def load(self):
- """ load """
-
- def _load_model_state():
- model_state_dict = torch.load(
- f'{self.func_model.init_checkpoint}',
- map_location=lambda storage, loc: storage)
-
- if 'module.' in list(model_state_dict.keys())[0]:
- new_model_state_dict = OrderedDict()
- for k, v in model_state_dict.items():
- assert k[:7] == 'module.'
- new_model_state_dict[k[7:]] = v
- model_state_dict = new_model_state_dict
-
- new_model_state_dict = OrderedDict()
- parameters = {
- name: param
- for name, param in self.func_model.named_parameters()
- }
- for name, param in model_state_dict.items():
- if name in parameters:
- if param.shape != parameters[name].shape:
- assert hasattr(param, 'numpy')
- arr = param.numpy()
- z = np.random.normal(
- scale=self.func_model.initializer_range,
- size=parameters[name].shape).astype('float32')
- if name == 'embedder.token_embedding.weight':
- z[-param.shape[0]:] = arr
- print(
- f'part of parameter({name}) random normlize initialize'
- )
- else:
- if z.shape[0] < param.shape[0]:
- z = arr[:z.shape[0]]
- print(f'part of parameter({name}) are dropped')
- else:
- z[:param.shape[0]] = arr
- print(
- f'part of parameter({name}) random normlize initialize'
- )
- dtype, device = param.dtype, param.device
- z = torch.tensor(z, dtype=dtype, device=device)
- new_model_state_dict[name] = z
- else:
- new_model_state_dict[name] = param
- else:
- print(f'parameter({name}) are dropped')
- model_state_dict = new_model_state_dict
-
- for name in parameters:
- if name not in model_state_dict:
- if parameters[name].requires_grad:
- print(f'parameter({name}) random normlize initialize')
- z = np.random.normal(
- scale=self.func_model.initializer_range,
- size=parameters[name].shape).astype('float32')
- dtype, device = parameters[name].dtype, parameters[
- name].device
- model_state_dict[name] = torch.tensor(
- z, dtype=dtype, device=device)
- else:
- model_state_dict[name] = parameters[name]
-
- self.func_model.load_state_dict(model_state_dict)
- self.logger.info(
- f"Loaded model state from '{self.func_model.init_checkpoint}.model'"
- )
-
- def _load_train_state():
- train_file = f'{self.func_model.init_checkpoint}.train'
- if os.path.exists(train_file):
- train_state_dict = torch.load(
- train_file, map_location=lambda storage, loc: storage)
- self.epoch = train_state_dict['epoch']
- self.best_valid_metric = train_state_dict['best_valid_metric']
- if self.optimizer is not None and 'optimizer' in train_state_dict:
- self.optimizer.load_state_dict(
- train_state_dict['optimizer'])
- if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict:
- self.lr_scheduler.load_state_dict(
- train_state_dict['lr_scheduler'])
- self.logger.info(
- f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
- f'best_valid_metric={self.best_valid_metric:.3f})')
- else:
- self.logger.info('Loaded no train state')
-
- if self.func_model.init_checkpoint is None:
- self.logger.info('Loaded no model !!!')
- return
-
- if self.do_train:
- _load_model_state()
- return
-
- if self.do_infer:
- _load_model_state()
- _load_train_state()
|