import time
from typing import Dict, Optional, Tuple, Union

import numpy as np

from maas_lib.utils.constant import Tasks
from maas_lib.utils.logger import get_logger
from ..base import BaseTrainer
from ..builder import TRAINERS

# __all__ = ["SequenceClassificationTrainer"]

PATH = None
logger = get_logger(PATH)


@TRAINERS.register_module(
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationTrainer(BaseTrainer):

    def __init__(self, cfg_file: str, *args, **kwargs):
        """A trainer for sequence classification.

        Based on a config file (*.yaml or *.json), the trainer trains or
        evaluates a model on a dataset.

        Args:
            cfg_file (str): the path of the config file
        """
        super().__init__(cfg_file)

    def train(self, *args, **kwargs):
        logger.info('Train')
        ...

    def __attr_is_exist(self, attr: str) -> Tuple[str, Union[str, bool]]:
        """Get an attribute from the config; return False if it does not exist.

        Example:
            >>> self.__attr_is_exist("model path")
            out: ("model-path", "/workspace/bert-base-sst2")
            >>> self.__attr_is_exist("model weights")
            out: ("model-weights", False)

        Args:
            attr (str): attribute str, "model path" -> config["model"]["path"]

        Returns:
            Tuple[str, Union[str, bool]]: (the target attribute name,
                the target attribute value or False)
        """
        paths = attr.split(' ')
        attr_str: str = '-'.join(paths)
        # walk down the config tree one key at a time, bailing out with
        # False as soon as a key is missing
        target = self.cfg[paths[0]] if hasattr(self.cfg, paths[0]) else None
        for path_ in paths[1:]:
            if not hasattr(target, path_):
                return attr_str, False
            target = target[path_]
        if target and target != '':
            return attr_str, target
        return attr_str, False

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluate a dataset.

        Evaluate a dataset with the model loaded from `checkpoint_path`;
        if `checkpoint_path` is not given, read it from the config file.

        Args:
            checkpoint_path (Optional[str], optional): the model path.
                Defaults to None.

        Returns:
            Dict[str, float]: the evaluation results.
            Example:
                {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
        """
        import torch
        from easynlp.appzoo import load_dataset
        from easynlp.appzoo.dataset import GeneralDataset
        from easynlp.appzoo.sequence_classification.model import SequenceClassification
        from easynlp.utils import losses
        from sklearn.metrics import f1_score
        from torch.utils.data import DataLoader

        raise_str = 'Attribute {} is not given in config file!'
        metrics = self.__attr_is_exist('evaluation metrics')
        eval_batch_size = self.__attr_is_exist('evaluation batch_size')
        test_dataset_path = self.__attr_is_exist('dataset valid file')
        attrs = [metrics, eval_batch_size, test_dataset_path]
        for attr_ in attrs:
            if not attr_[-1]:
                raise AttributeError(raise_str.format(attr_[0]))

        if not checkpoint_path:
            checkpoint_path = self.__attr_is_exist('evaluation model_path')[-1]
            if not checkpoint_path:
                raise ValueError(
                    'Argument checkpoint_path must be passed if the '
                    'evaluation-model_path is not given in config file!')

        max_sequence_length = kwargs.get(
            'max_sequence_length',
            self.__attr_is_exist('evaluation max_sequence_length')[-1])
        if not max_sequence_length:
            raise ValueError(
                'Argument max_sequence_length must be passed if the '
                'evaluation-max_sequence_length does not exist in config file!')
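        # For illustration only (hypothetical values, not shipped with this
        # repo): the attribute lookups above imply a config layout roughly
        # like the following YAML:
        #
        #   evaluation:
        #     metrics: [accuracy, f1]
        #     batch_size: 32
        #     model_path: /workspace/bert-base-sst2
        #     max_sequence_length: 128
        #   dataset:
        #     valid:
        #       file: glue/sst2   # split on '/' and passed to load_dataset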
        # get the raw online dataset
        raw_dataset = load_dataset(*test_dataset_path[-1].split('/'))
        valid_dataset = raw_dataset['validation']

        # generate a standard dataloader
        pre_dataset = GeneralDataset(valid_dataset, checkpoint_path,
                                     max_sequence_length)
        valid_dataloader = DataLoader(
            pre_dataset,
            batch_size=eval_batch_size[-1],
            shuffle=False,
            collate_fn=pre_dataset.batch_fn)

        # generate a model
        model = SequenceClassification.from_pretrained(checkpoint_path)

        # copy from easynlp (start)
        model.eval()
        total_loss = 0
        total_steps = 0
        total_samples = 0
        hit_num = 0
        total_num = 0
        logits_list = list()
        y_trues = list()
        total_spent_time = 0.0
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model.to(device)
        for _step, batch in enumerate(valid_dataloader):
            try:
                batch = {
                    key:
                    val.to(device) if isinstance(val, torch.Tensor) else val
                    for key, val in batch.items()
                }
            except RuntimeError:
                batch = {key: val for key, val in batch.items()}

            infer_start_time = time.time()
            with torch.no_grad():
                label_ids = batch.pop('label_ids')
                outputs = model(batch)
            infer_end_time = time.time()
            total_spent_time += infer_end_time - infer_start_time

            assert 'logits' in outputs
            logits = outputs['logits']
            y_trues.extend(label_ids.tolist())
            logits_list.extend(logits.tolist())
            hit_num += torch.sum(
                torch.argmax(logits, dim=-1) == label_ids).item()
            total_num += label_ids.shape[0]

            # one-dimensional (or single-column) logits are scored as a
            # regression task, two-dimensional logits as classification
            if len(logits.shape) == 1 or logits.shape[-1] == 1:
                tmp_loss = losses.mse_loss(logits, label_ids)
            elif len(logits.shape) == 2:
                tmp_loss = losses.cross_entropy(logits, label_ids)
            else:
                raise RuntimeError
            total_loss += tmp_loss.mean().item()
            total_steps += 1
            total_samples += valid_dataloader.batch_size

            if (_step + 1) % 100 == 0:
                total_step = len(
                    valid_dataloader.dataset) // valid_dataloader.batch_size
                logger.info('Eval: {}/{} steps finished'.format(
                    _step + 1, total_step))

        logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
            total_spent_time, total_spent_time * 1000 / total_samples))
        eval_loss = total_loss / total_steps
        logger.info('Eval loss: {}'.format(eval_loss))

        logits_list = np.array(logits_list)
        eval_outputs = list()
        for metric in metrics[-1]:
            if metric.endswith('accuracy'):
                acc = hit_num / total_num
                logger.info('Accuracy: {}'.format(acc))
                eval_outputs.append(('accuracy', acc))
            elif metric == 'f1':
                if model.config.num_labels == 2:
                    f1 = f1_score(y_trues, np.argmax(logits_list, axis=-1))
                    logger.info('F1: {}'.format(f1))
                    eval_outputs.append(('f1', f1))
                else:
                    # multi-class: report both macro and micro F1
                    f1 = f1_score(
                        y_trues,
                        np.argmax(logits_list, axis=-1),
                        average='macro')
                    logger.info('Macro F1: {}'.format(f1))
                    eval_outputs.append(('macro-f1', f1))
                    f1 = f1_score(
                        y_trues,
                        np.argmax(logits_list, axis=-1),
                        average='micro')
                    logger.info('Micro F1: {}'.format(f1))
                    eval_outputs.append(('micro-f1', f1))
            else:
                raise NotImplementedError('Metric %s not implemented' % metric)
        # copy from easynlp (end)

        return dict(eval_outputs)
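
# A minimal usage sketch (for illustration: 'configs/bert_sst2.yaml' is a
# hypothetical path; it must point at a real *.yaml/*.json config exposing
# the attributes read above):
#
#     trainer = SequenceClassificationTrainer('configs/bert_sst2.yaml')
#     results = trainer.evaluate()  # or trainer.evaluate(checkpoint_path=...)
#     print(results)  # e.g. {'accuracy': 0.5091..., 'f1': 0.6737...}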