diff --git a/maas_lib/models/nlp/__init__.py b/maas_lib/models/nlp/__init__.py
index a8489c12..99b56c17 100644
--- a/maas_lib/models/nlp/__init__.py
+++ b/maas_lib/models/nlp/__init__.py
@@ -1,2 +1,3 @@
 from .sequence_classification_model import *  # noqa F403
 from .space.dialog_generation_model import *  # noqa F403
+from .space.dialog_intent_model import *  # noqa F403
diff --git a/maas_lib/models/nlp/space/dialog_generation_model.py b/maas_lib/models/nlp/space/dialog_generation_model.py
index 440c1163..be3d7261 100644
--- a/maas_lib/models/nlp/space/dialog_generation_model.py
+++ b/maas_lib/models/nlp/space/dialog_generation_model.py
@@ -10,7 +10,8 @@ from .model.model_base import ModelBase
 __all__ = ['DialogGenerationModel']
 
 
-@MODELS.register_module(Tasks.dialog_generation, module_name=r'space')
+@MODELS.register_module(
+    Tasks.dialog_generation, module_name=r'space-generation')
 class DialogGenerationModel(Model):
 
     def __init__(self, model_dir: str, *args, **kwargs):
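Note for reviewers: the old shared `module_name=r'space'` would now collide, since the generation and intent components register into the same task-scoped registries. A self-contained toy sketch of the collision the `space-generation` / `space-intent` rename avoids — this is not maas_lib's actual registry implementation, just an illustration:

```python
from collections import defaultdict

REGISTRY = defaultdict(dict)  # task -> {module_name: component class}

def register_module(task, module_name):
    def decorator(cls):
        # a duplicate (task, module_name) pair would shadow or clash with an
        # earlier registration, hence the unique 'space-*' names in this patch
        assert module_name not in REGISTRY[task], f'{module_name} taken'
        REGISTRY[task][module_name] = cls
        return cls
    return decorator

@register_module('dialog-generation', 'space-generation')
class DialogGenerationModel:
    pass

@register_module('dialog-intent', 'space-intent')
class DialogIntentModel:
    pass

print(REGISTRY['dialog-intent']['space-intent'])  # DialogIntentModel
```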
+ """ + + super().__init__(model_dir, *args, **kwargs) + self.model_dir = model_dir + self.text_field = kwargs.pop('text_field') + self.config = kwargs.pop('config') + self.generator = Generator.create(self.config, reader=self.text_field) + self.model = ModelBase.create( + model_dir=model_dir, + config=self.config, + reader=self.text_field, + generator=self.generator) + + def to_tensor(array): + """ + numpy array -> tensor + """ + import torch + array = torch.tensor(array) + return array.cuda() if self.config.use_gpu else array + + self.trainer = IntentTrainer( + model=self.model, + to_tensor=to_tensor, + config=self.config, + reader=self.text_field) + self.trainer.load() + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Any]): the preprocessed data + + Returns: + Dict[str, np.ndarray]: results + Example: + { + 'predictions': array([1]), # lable 0-negative 1-positive + 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), + 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value + } + """ + from numpy import array, float32 + import torch + + return {} diff --git a/maas_lib/models/nlp/space/model/__init__.py b/maas_lib/models/nlp/space/model/__init__.py index e69de29b..7e1b5264 100644 --- a/maas_lib/models/nlp/space/model/__init__.py +++ b/maas_lib/models/nlp/space/model/__init__.py @@ -0,0 +1,3 @@ +from .gen_unified_transformer import GenUnifiedTransformer +from .intent_unified_transformer import IntentUnifiedTransformer +from .unified_transformer import UnifiedTransformer diff --git a/maas_lib/models/nlp/space/model/generator.py b/maas_lib/models/nlp/space/model/generator.py index 2567102f..aa2e5f20 100644 --- a/maas_lib/models/nlp/space/model/generator.py +++ b/maas_lib/models/nlp/space/model/generator.py @@ -7,9 +7,6 @@ import math import numpy as np import torch -from .gen_unified_transformer import GenUnifiedTransformer -from .unified_transformer import UnifiedTransformer - def repeat(var, times): if isinstance(var, list): diff --git a/maas_lib/models/nlp/space/model/intent_unified_transformer.py b/maas_lib/models/nlp/space/model/intent_unified_transformer.py new file mode 100644 index 00000000..dd63df39 --- /dev/null +++ b/maas_lib/models/nlp/space/model/intent_unified_transformer.py @@ -0,0 +1,198 @@ +""" +IntentUnifiedTransformer +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from maas_lib.utils.nlp.space.criterions import compute_kl_loss +from .unified_transformer import UnifiedTransformer + + +class IntentUnifiedTransformer(UnifiedTransformer): + """ + Implement intent unified transformer. + """ + + def __init__(self, model_dir, config, reader, generator): + super(IntentUnifiedTransformer, self).__init__(model_dir, config, + reader, generator) + self.example = config.Model.example + self.num_intent = config.Model.num_intent + self.with_rdrop = config.Model.with_rdrop + self.kl_ratio = config.Model.kl_ratio + self.loss_fct = nn.CrossEntropyLoss() + if self.example: + self.loss_fct = nn.NLLLoss() + else: + self.intent_classifier = nn.Linear(self.hidden_dim, + self.num_intent) + self.loss_fct = nn.CrossEntropyLoss() + + if self.use_gpu: + self.cuda() + return + + def _forward(self, inputs, is_training, with_label): + """ Real forward process of model in different mode(train/test). 
""" + + def aug(v): + assert isinstance(v, torch.Tensor) + return torch.cat([v, v], dim=0) + + outputs = {} + + if self.with_mlm: + mlm_embed = self._encoder_network( + input_token=inputs['mlm_token'], + input_mask=inputs['src_mask'], + input_pos=inputs['src_pos'], + input_type=inputs['src_type'], + input_turn=inputs['src_turn']) + outputs['mlm_probs'] = self._mlm_head(mlm_embed=mlm_embed) + + if self.with_rdrop or self.with_contrastive: + enc_embed, dec_embed = self._encoder_decoder_network( + src_token=aug(inputs['src_token']), + src_mask=aug(inputs['src_mask']), + tgt_token=aug(inputs['tgt_token']), + tgt_mask=aug(inputs['tgt_mask']), + src_pos=aug(inputs['src_pos']), + src_type=aug(inputs['src_type']), + src_turn=aug(inputs['src_turn'])) + else: + enc_embed, dec_embed = self._encoder_decoder_network( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + tgt_token=inputs['tgt_token'], + tgt_mask=inputs['tgt_mask'], + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + features = dec_embed[:, -1] + features = self.pooler(features) if self.with_pool else features + + if self.example: + assert not self.with_rdrop + ex_enc_embed, ex_dec_embed = self._encoder_decoder_network( + src_token=inputs['example_src_token'], + src_mask=inputs['example_src_mask'], + tgt_token=inputs['example_tgt_token'], + tgt_mask=inputs['example_tgt_mask'], + src_pos=inputs['example_src_pos'], + src_type=inputs['example_src_type'], + src_turn=inputs['example_src_turn']) + ex_features = ex_dec_embed[:, -1] + ex_features = self.pooler( + ex_features) if self.with_pool else ex_features + + probs = self.softmax(features.mm(ex_features.t())) + example_intent = inputs['example_intent'].unsqueeze(0) + intent_probs = torch.zeros(probs.size(0), self.num_intent) + intent_probs = intent_probs.cuda( + ) if self.use_gpu else intent_probs + intent_probs = intent_probs.scatter_add( + -1, example_intent.repeat(probs.size(0), 1), probs) + outputs['intent_probs'] = intent_probs + else: + intent_logits = self.intent_classifier(features) + outputs['intent_logits'] = intent_logits + + if self.with_contrastive: + features = features if self.with_pool else self.pooler(features) + batch_size = features.size(0) // 2 + features = torch.cat([ + features[:batch_size].unsqueeze(1), + features[batch_size:].unsqueeze(1) + ], + dim=1) + features = F.normalize(features, dim=-1, p=2) + outputs['features'] = features + + return outputs + + def _collect_metrics(self, inputs, outputs, with_label, data_file): + + metrics = {} + batch_size = inputs['src_token'].size(0) + + intent_label = torch.cat([inputs['intent_label'], inputs['intent_label']], dim=0) \ + if self.with_rdrop or self.with_contrastive else inputs['intent_label'] + + if self.example: + intent_loss = self.loss_fct( + torch.log(outputs['intent_probs'] + 1e-12).view( + -1, self.num_intent), intent_label.type(torch.long)) + else: + intent_loss = self.loss_fct( + outputs['intent_logits'].view(-1, self.num_intent), + intent_label.type(torch.long)) + metrics['intent_loss'] = intent_loss + loss = intent_loss + + if self.with_mlm: + mlm_num = torch.sum(torch.sum(inputs['mlm_mask'], dim=1)) + mlm = self.nll_loss( + torch.log(outputs['mlm_probs'] + 1e-12).permute(0, 2, 1), + inputs['mlm_label']) + mlm = torch.sum(mlm, dim=1) + token_mlm = torch.sum(mlm) / mlm_num + mlm = torch.mean(mlm) + metrics['mlm'] = mlm + metrics['token_mlm'] = token_mlm + metrics['mlm_num'] = mlm_num + loss = loss + (token_mlm + if self.token_loss else mlm) * self.mlm_ratio + 
else:
+            mlm, token_mlm, mlm_num = None, None, None
+
+        if self.with_rdrop:
+            kl = compute_kl_loss(
+                p=outputs['intent_logits'][:batch_size],
+                q=outputs['intent_logits'][batch_size:])
+            metrics['kl'] = kl
+            loss = loss + kl * self.kl_ratio
+        else:
+            kl = None
+
+        # NOTE: the contrastive objective is not implemented here yet; `con`
+        # keeps its slot in the returned tuple for multi-GPU result packing.
+        con = None
+
+        metrics['loss'] = loss
+
+        if self.gpu > 1:
+            return intent_loss, mlm, token_mlm, mlm_num, kl, con
+        else:
+            return metrics
+
+    def _infer(self,
+               inputs,
+               start_id=None,
+               eos_id=None,
+               max_gen_len=None,
+               prev_input=None):
+        """ Real inference process of the model. """
+        results = {}
+        enc_embed, dec_embed = self._encoder_decoder_network(
+            src_token=inputs['src_token'],
+            src_mask=inputs['src_mask'],
+            tgt_token=inputs['tgt_token'],
+            tgt_mask=inputs['tgt_mask'],
+            src_pos=inputs['src_pos'],
+            src_type=inputs['src_type'],
+            src_turn=inputs['src_turn'])
+        features = dec_embed[:, -1]
+        features = self.pooler(features) if self.with_pool else features
+        if self.example:
+            results['features'] = features
+        else:
+            intent_logits = self.intent_classifier(features)
+            intent_probs = self.softmax(intent_logits)
+            results['intent_probs'] = intent_probs
+        return results
+
+
+IntentUnifiedTransformer.register('IntentUnifiedTransformer')
diff --git a/maas_lib/pipelines/nlp/__init__.py b/maas_lib/pipelines/nlp/__init__.py
index 01bd0e2a..8a97070b 100644
--- a/maas_lib/pipelines/nlp/__init__.py
+++ b/maas_lib/pipelines/nlp/__init__.py
@@ -1,2 +1,3 @@
 from .sequence_classification_pipeline import *  # noqa F403
 from .space.dialog_generation_pipeline import *  # noqa F403
+from .space.dialog_intent_pipeline import *  # noqa F403
diff --git a/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py b/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py
index 8a5e1c26..a7b2d057 100644
--- a/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py
+++ b/maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py
@@ -9,7 +9,8 @@ from ...builder import PIPELINES
 __all__ = ['DialogGenerationPipeline']
 
 
-@PIPELINES.register_module(Tasks.dialog_generation, module_name=r'space')
+@PIPELINES.register_module(
+    Tasks.dialog_generation, module_name=r'space-generation')
 class DialogGenerationPipeline(Model):
 
     def __init__(self, model: DialogGenerationModel,
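Taken together with the preprocessor registered later in this patch, the intended wiring of the new `space-intent` pieces mirrors the scaffolding in `tests/pipelines/nlp/test_dialog_intent.py`. A hedged sketch — the checkpoint directory is a placeholder, and `forward` is still a stub at this point:

```python
from maas_lib.models.nlp import DialogIntentModel
from maas_lib.preprocessors import DialogIntentPreprocessor

model_dir = '/path/to/space-dialog-intent'  # placeholder checkpoint dir

# the preprocessor owns the config and BPE text field, which the model reuses
preprocessor = DialogIntentPreprocessor(model_dir=model_dir)
model = DialogIntentModel(
    model_dir=model_dir,
    text_field=preprocessor.text_field,
    config=preprocessor.config)

# DialogIntentPipeline would chain these two and decode via postprocess()
outputs = model.forward(preprocessor('How do I locate my card?'))
```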
diff --git a/maas_lib/pipelines/nlp/space/dialog_intent_pipeline.py b/maas_lib/pipelines/nlp/space/dialog_intent_pipeline.py
new file mode 100644
index 00000000..e9d10551
--- /dev/null
+++ b/maas_lib/pipelines/nlp/space/dialog_intent_pipeline.py
@@ -0,0 +1,50 @@
+from typing import Dict
+
+from maas_lib.models.nlp import DialogIntentModel
+from maas_lib.preprocessors import DialogIntentPreprocessor
+from maas_lib.utils.constant import Tasks
+from ...base import Model, Tensor
+from ...builder import PIPELINES
+
+__all__ = ['DialogIntentPipeline']
+
+
+@PIPELINES.register_module(Tasks.dialog_intent, module_name=r'space-intent')
+class DialogIntentPipeline(Model):
+
+    def __init__(self, model: DialogIntentModel,
+                 preprocessor: DialogIntentPreprocessor, **kwargs):
+        """use `model` and `preprocessor` to create an nlp dialog intent pipeline for prediction
+
+        Args:
+            model (DialogIntentModel): a model instance
+            preprocessor (DialogIntentPreprocessor): a preprocessor instance
+        """
+
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        self.model = model
+        self.tokenizer = preprocessor.tokenizer
+
+    def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
+        """process the prediction results
+
+        Args:
+            inputs (Dict[str, Any]): the model's prediction output
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+
+        # NOTE: this currently mirrors the generation pipeline's token
+        # decoding; mapping to an intent label is still to be wired up.
+        vocab_size = len(self.tokenizer.vocab)
+        pred_list = inputs['predictions']
+        pred_ids = pred_list[0][0].cpu().numpy().tolist()
+        for j in range(len(pred_ids)):
+            if pred_ids[j] >= vocab_size:
+                pred_ids[j] = 100
+        pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
+        # strip BPE continuation markers and special tokens, keep the text
+        # before the first [SEP]
+        pred_string = ''.join(pred).replace('##', '').split('[SEP]')[0] \
+            .replace('[CLS]', '').replace('[SEP]', '').replace('[UNK]', '')
+        return {'pred_string': pred_string}
diff --git a/maas_lib/preprocessors/__init__.py b/maas_lib/preprocessors/__init__.py
index 9ed0d181..4a146843 100644
--- a/maas_lib/preprocessors/__init__.py
+++ b/maas_lib/preprocessors/__init__.py
@@ -4,5 +4,6 @@ from .base import Preprocessor
 from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
-from .nlp.nlp import *  # noqa F403
-from .nlp.space.dialog_generation_preprcessor import *  # noqa F403
+from .nlp import *  # noqa F403
+from .space.dialog_generation_preprocessor import *  # noqa F403
+from .space.dialog_intent_preprocessor import *  # noqa F403
diff --git a/maas_lib/preprocessors/nlp/nlp.py b/maas_lib/preprocessors/nlp.py
similarity index 97%
rename from maas_lib/preprocessors/nlp/nlp.py
rename to maas_lib/preprocessors/nlp.py
index ea496883..0a03328a 100644
--- a/maas_lib/preprocessors/nlp/nlp.py
+++ b/maas_lib/preprocessors/nlp.py
@@ -7,8 +7,8 @@ from transformers import AutoTokenizer
 
 from maas_lib.utils.constant import Fields, InputFields
 from maas_lib.utils.type_assert import type_assert
-from ..base import Preprocessor
-from ..builder import PREPROCESSORS
+from .base import Preprocessor
+from .builder import PREPROCESSORS
 
 __all__ = [
     'Tokenize',
diff --git a/maas_lib/preprocessors/nlp/space/__init__.py b/maas_lib/preprocessors/nlp/space/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/maas_lib/preprocessors/nlp/__init__.py b/maas_lib/preprocessors/space/__init__.py
similarity index 100%
rename from maas_lib/preprocessors/nlp/__init__.py
rename to maas_lib/preprocessors/space/__init__.py
diff --git a/maas_lib/preprocessors/nlp/space/dialog_generation_preprcessor.py b/maas_lib/preprocessors/space/dialog_generation_preprocessor.py
similarity index 90%
rename from maas_lib/preprocessors/nlp/space/dialog_generation_preprcessor.py
rename to maas_lib/preprocessors/space/dialog_generation_preprocessor.py
index f47eed7e..5b127e8e 100644
--- a/maas_lib/preprocessors/nlp/space/dialog_generation_preprcessor.py
+++ b/maas_lib/preprocessors/space/dialog_generation_preprocessor.py
@@ -8,13 +8,13 @@ from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField
 from maas_lib.utils.config import Config
 from maas_lib.utils.constant import Fields, InputFields
 from maas_lib.utils.type_assert import type_assert
-from ...base import Preprocessor
-from ...builder import PREPROCESSORS
+from ..base import Preprocessor
+from ..builder import PREPROCESSORS
 
 __all__ = ['DialogGenerationPreprocessor']
 
 
-@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space')
+@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-generation')
 class DialogGenerationPreprocessor(Preprocessor):
 
     def __init__(self, model_dir: str, *args, **kwargs):
diff --git a/maas_lib/preprocessors/space/dialog_intent_preprocessor.py b/maas_lib/preprocessors/space/dialog_intent_preprocessor.py
new file mode 100644
index 00000000..43ced78c
--- /dev/null
+++ b/maas_lib/preprocessors/space/dialog_intent_preprocessor.py
@@ -0,0 +1,49 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+from typing import Any, Dict
+
+from maas_lib.data.nlp.space.fields.intent_field import IntentBPETextField
+from maas_lib.utils.config import Config
+from maas_lib.utils.constant import Fields
+from maas_lib.utils.type_assert import type_assert
+from ..base import Preprocessor
+from ..builder import PREPROCESSORS
+
+__all__ = ['DialogIntentPreprocessor']
+
+
+@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-intent')
+class DialogIntentPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+
+        self.model_dir: str = model_dir
+        self.config = Config.from_file(
+            os.path.join(self.model_dir, 'configuration.json'))
+        self.text_field = IntentBPETextField(
+            self.model_dir, config=self.config)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """process the raw input data
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        idx = self.text_field.get_ids(data)
+        return {'user_idx': idx}
diff --git a/maas_lib/trainers/nlp/space/trainers/intent_trainer.py b/maas_lib/trainers/nlp/space/trainers/intent_trainer.py
new file mode 100644
index 00000000..f736a739
--- /dev/null
+++ b/maas_lib/trainers/nlp/space/trainers/intent_trainer.py
@@ -0,0 +1,803 @@
+"""
+Trainer class.
+"""
+
+import logging
+import os
+import sys
+import time
+from collections import OrderedDict
+
+import json
+import numpy as np
+import torch
+from tqdm import tqdm
+from transformers.optimization import AdamW, get_linear_schedule_with_warmup
+
+from maas_lib.trainers.nlp.space.metrics.metrics_tracker import MetricsTracker
+
+
+def get_logger(log_path, name='default'):
+    logger = logging.getLogger(name)
+    logger.propagate = False
+    logger.setLevel(logging.DEBUG)
+
+    formatter = logging.Formatter('%(message)s')
+
+    sh = logging.StreamHandler(sys.stdout)
+    sh.setFormatter(formatter)
+    logger.addHandler(sh)
+
+    fh = logging.FileHandler(log_path, mode='w')
+    fh.setFormatter(formatter)
+    logger.addHandler(fh)
+
+    return logger
+
+
+class Trainer(object):
+
+    def __init__(self,
+                 model,
+                 to_tensor,
+                 config,
+                 reader=None,
+                 logger=None,
+                 lr_scheduler=None,
+                 optimizer=None):
+        self.model = model
+        self.to_tensor = to_tensor
+        self.do_train = config.do_train
+        self.do_infer = config.do_infer
+
+        self.is_decreased_valid_metric = config.Trainer.valid_metric_name[
+            0] == '-'
+        self.valid_metric_name = config.Trainer.valid_metric_name[1:]
+        self.num_epochs = config.Trainer.num_epochs
+        self.save_dir = config.Trainer.save_dir
+        self.log_steps = config.Trainer.log_steps
+        self.valid_steps = config.Trainer.valid_steps
+        self.save_checkpoint = config.Trainer.save_checkpoint
+        self.save_summary = config.Trainer.save_summary
+        self.learning_method = config.Dataset.learning_method
+        self.weight_decay = config.Model.weight_decay
+        self.warmup_steps = config.Model.warmup_steps
+        self.batch_size_label = config.Trainer.batch_size_label
+        self.batch_size_nolabel = config.Trainer.batch_size_nolabel
+        self.gpu = config.Trainer.gpu
+        self.lr = config.Model.lr
+
+        # unwrap nn.DataParallel so checkpoint and parameter access are
+        # uniform in single- and multi-GPU runs
+        self.func_model = 
self.model.module if self.gpu > 1 else self.model + self.reader = reader + self.tokenizer = reader.tokenizer + + self.lr_scheduler = lr_scheduler + self.optimizer = optimizer + + # if not os.path.exists(self.save_dir): + # os.makedirs(self.save_dir) + + # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") + self.logger = logger or get_logger('trainer.log', 'trainer') + + self.batch_metrics_tracker_label = MetricsTracker() + self.token_metrics_tracker_label = MetricsTracker() + self.batch_metrics_tracker_nolabel = MetricsTracker() + self.token_metrics_tracker_nolabel = MetricsTracker() + + self.best_valid_metric = float( + 'inf' if self.is_decreased_valid_metric else '-inf') + self.epoch = 0 + self.batch_num = 0 + + def set_optimizers(self, num_training_steps_per_epoch): + """ + Setup the optimizer and the learning rate scheduler. + + from transformers.Trainer + + parameters from cfg: lr (1e-3); warmup_steps + """ + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'norm.weight'] + optimizer_grouped_parameters = [ + { + 'params': [ + p for n, p in self.model.named_parameters() + if not any(nd in n for nd in no_decay) + ], + 'weight_decay': + self.weight_decay, + }, + { + 'params': [ + p for n, p in self.model.named_parameters() + if any(nd in n for nd in no_decay) + ], + 'weight_decay': + 0.0, + }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) + + num_training_steps = num_training_steps_per_epoch * self.num_epochs + num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( + num_training_steps * 0.1) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps) + + # reset optimizer and lr_scheduler + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # log info + self.logger.info( + f'***** Running training: {self.learning_method} *****') + self.logger.info(' Num Epochs = %d', self.num_epochs) + self.logger.info( + ' Num Training steps(one turn in a batch of dialogs) per epoch = %d', + num_training_steps_per_epoch) + self.logger.info(' Batch size for labeled data = %d', + self.batch_size_label) + self.logger.info(' Batch size for unlabeled data = %d', + self.batch_size_nolabel) + self.logger.info(' Total optimization steps = %d', num_training_steps) + self.logger.info(' Total warmup steps = %d', num_warmup_steps) + self.logger.info(f'************************************') + + def train(self, + train_label_iter, + train_nolabel_iter=None, + valid_label_iter=None, + valid_nolabel_iter=None): + # begin training + num_epochs = self.num_epochs - self.epoch + for epoch in range(num_epochs): + self.train_epoch( + train_label_iter=train_label_iter, + train_nolabel_iter=train_nolabel_iter, + valid_label_iter=valid_label_iter, + valid_nolabel_iter=valid_nolabel_iter) + + def train_epoch(self, train_label_iter, train_nolabel_iter, + valid_label_iter, valid_nolabel_iter): + """ + Train an epoch. 
+ """ + raise NotImplementedError + + def evaluate(self, data_label_iter, data_nolabel_iter, need_save=True): + raise NotImplementedError + + def infer(self, data_iter, num_batches=None): + raise NotImplementedError + + def save(self, is_best=False): + """ save """ + train_state = { + 'epoch': self.epoch, + 'batch_num': self.batch_num, + 'best_valid_metric': self.best_valid_metric, + 'optimizer': self.optimizer.state_dict() + } + if self.lr_scheduler is not None: + train_state['lr_scheduler'] = self.lr_scheduler.state_dict() + + # Save checkpoint + if self.save_checkpoint: + model_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.model') + torch.save(self.model.state_dict(), model_file) + self.logger.info(f"Saved model state to '{model_file}'") + + train_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.train') + torch.save(train_state, train_file) + self.logger.info(f"Saved train state to '{train_file}'") + + # Save current best model + if is_best: + best_model_file = os.path.join(self.save_dir, 'best.model') + torch.save(self.model.state_dict(), best_model_file) + best_train_file = os.path.join(self.save_dir, 'best.train') + torch.save(train_state, best_train_file) + self.logger.info( + f"Saved best model state to '{best_model_file}' with new best valid metric " + f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}' + ) + + def load(self): + """ load """ + + def _load_model_state(): + model_state_dict = torch.load( + f'{self.func_model.init_checkpoint}.model', + map_location=lambda storage, loc: storage) + + if 'module.' in list(model_state_dict.keys())[0]: + new_model_state_dict = OrderedDict() + for k, v in model_state_dict.items(): + assert k[:7] == 'module.' + new_model_state_dict[k[7:]] = v + model_state_dict = new_model_state_dict + + new_model_state_dict = OrderedDict() + parameters = { + name: param + for name, param in self.func_model.named_parameters() + } + for name, param in model_state_dict.items(): + if name in parameters: + if param.shape != parameters[name].shape: + assert hasattr(param, 'numpy') + arr = param.numpy() + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + if name == 'embedder.token_embedding.weight': + z[-param.shape[0]:] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + else: + if z.shape[0] < param.shape[0]: + z = arr[:z.shape[0]] + print(f'part of parameter({name}) are dropped') + else: + z[:param.shape[0]] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + dtype, device = param.dtype, param.device + z = torch.tensor(z, dtype=dtype, device=device) + new_model_state_dict[name] = z + else: + new_model_state_dict[name] = param + else: + print(f'parameter({name}) are dropped') + model_state_dict = new_model_state_dict + + for name in parameters: + if name not in model_state_dict: + if parameters[name].requires_grad: + print(f'parameter({name}) random normlize initialize') + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + dtype, device = parameters[name].dtype, parameters[ + name].device + model_state_dict[name] = torch.tensor( + z, dtype=dtype, device=device) + else: + model_state_dict[name] = parameters[name] + + self.func_model.load_state_dict(model_state_dict) + self.logger.info( + f"Loaded model state from '{self.func_model.init_checkpoint}.model'" + ) + + def _load_train_state(): + train_file = 
+
+
+class IntentTrainer(Trainer):
+
+    def __init__(self, model, to_tensor, config, reader=None):
+        super(IntentTrainer, self).__init__(model, to_tensor, config, reader)
+        self.example = config.Model.example
+        self.can_norm = config.Trainer.can_norm
+
+    def can_normalization(self, y_pred, y_true, ex_data_iter):
+        # accuracy of the raw predictions, before correction
+        acc_original = np.mean(y_pred.argmax(1) == y_true)
+        message = 'original acc: %s' % acc_original
+
+        # estimate the uncertainty of each prediction via top-k entropy
+        k = 3
+        y_pred_topk = np.sort(y_pred, axis=1)[:, -k:]
+        y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True)
+        y_pred_uncertainty = -(y_pred_topk *
+                               np.log(y_pred_topk)).sum(1) / np.log(k)
+
+        # choose a threshold to split predictions into high- and
+        # low-confidence parts
+        # print(np.sort(y_pred_uncertainty)[-100:].tolist())
+        threshold = 0.7
+        y_pred_confident = y_pred[y_pred_uncertainty < threshold]
+        y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold]
+        y_true_confident = y_true[y_pred_uncertainty < threshold]
+        y_true_unconfident = y_true[y_pred_uncertainty >= threshold]
+
+        # report accuracy of each part; the high-confidence part is usually
+        # far more accurate than the low-confidence one
+        acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \
+            if len(y_true_confident) else 0.
+        acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \
+            if len(y_true_unconfident) else 0.
+        message += ' (%s) confident acc: %s' % (len(y_true_confident),
+                                                acc_confident)
+        message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident),
+                                                  acc_unconfident)
+
+        # estimate the intent prior distribution from the training set
+        prior = np.zeros(self.func_model.num_intent)
+        for _, (batch, batch_size) in ex_data_iter:
+            for intent_label in batch['intent_label']:
+                prior[intent_label] += 1.
+
+        prior /= prior.sum()
+
+        # correct low-confidence samples one by one against the confident set
+        # and the prior, then re-evaluate accuracy
+        right, alpha, iters = 0, 1, 1
+        for i, y in enumerate(y_pred_unconfident):
+            Y = np.concatenate([y_pred_confident, y[None]], axis=0)
+            for j in range(iters):
+                Y = Y**alpha
+                Y /= Y.mean(axis=0, keepdims=True)
+                Y *= prior[None]
+                Y /= Y.sum(axis=1, keepdims=True)
+            y = Y[-1]
+            if y.argmax() == y_true_unconfident[i]:
+                right += 1
+
+        # final accuracy after correction
+        acc_final = (acc_confident * len(y_pred_confident)
+                     + right) / len(y_pred)
+        if len(y_pred_unconfident):
+            message += ' new unconfident acc: %s' % (
+                right / len(y_pred_unconfident))
+        else:
+            message += ' no unconfident predictions'
+        message += ' final acc: %s' % acc_final
+        return acc_original, acc_final, message
+
+    def train_epoch(self, train_label_iter, train_nolabel_iter,
+                    valid_label_iter, valid_nolabel_iter):
+        """
+        Train an epoch.
+ """ + times = [] + self.epoch += 1 + self.batch_metrics_tracker_label.clear() + self.token_metrics_tracker_label.clear() + self.batch_metrics_tracker_nolabel.clear() + self.token_metrics_tracker_nolabel.clear() + + num_label_batches = len(train_label_iter) + num_nolabel_batches = len( + train_nolabel_iter) if train_nolabel_iter is not None else 0 + num_batches = max(num_label_batches, num_nolabel_batches) + + train_label_iter_loop = iter(train_label_iter) + train_nolabel_iter_loop = iter( + train_nolabel_iter) if train_nolabel_iter is not None else None + report_for_unlabeled_data = True if train_nolabel_iter is not None else False + + for batch_id in range(1, num_batches + 1): + # Do a training iteration + start_time = time.time() + batch_list, batch_size_list, with_label_list, loss_list, metrics_list = [], [], [], [], [] + data_file_list = [] + + # collect batch for labeled data + try: + data_file_label, ( + batch_label, + batch_size_label) = next(train_label_iter_loop) + except StopIteration: + train_label_iter_loop = iter(train_label_iter) + data_file_label, ( + batch_label, + batch_size_label) = next(train_label_iter_loop) + batch_list.append(batch_label) + batch_size_list.append(batch_size_label) + with_label_list.append(True) + data_file_list.append(data_file_label) + + # collect batch for unlabeled data + if train_nolabel_iter is not None: + try: + data_file_nolabel, ( + batch_nolabel, + batch_size_nolabel) = next(train_nolabel_iter_loop) + except StopIteration: + train_nolabel_iter_loop = iter(train_nolabel_iter) + data_file_nolabel, ( + batch_nolabel, + batch_size_nolabel) = next(train_nolabel_iter_loop) + batch_list.append(batch_nolabel) + batch_size_list.append(batch_size_nolabel) + with_label_list.append(False) + data_file_list.append(data_file_nolabel) + + # forward labeled batch and unlabeled batch and collect outputs, respectively + for (batch, batch_size, with_label, data_file) in \ + zip(batch_list, batch_size_list, with_label_list, data_file_list): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + if self.example and with_label: + current_dataset = train_label_iter.data_file_to_dataset[ + data_file] + example_batch = self.reader.retrieve_examples( + dataset=current_dataset, + labels=batch['intent_label'], + inds=batch['ids'], + task='intent') + example_batch = type(example_batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + example_batch.items())) + for k, v in example_batch.items(): + batch[k] = v + batch['epoch'] = self.epoch + batch['num_steps'] = self.batch_num + metrics = self.model( + batch, + is_training=True, + with_label=with_label, + data_file=data_file) + loss, metrics = self.balance_metrics( + metrics=metrics, batch_size=batch_size) + loss_list.append(loss) + metrics_list.append(metrics) + + # combine loss for labeled data and unlabeled data + # TODO change the computation of combined loss of labeled batch and unlabeled batch + loss = loss_list[0] if len( + loss_list) == 1 else loss_list[0] + loss_list[1] + + # optimization procedure + self.func_model._optimize( + loss, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler) + elapsed = time.time() - start_time + times.append(elapsed) + self.batch_num += 1 + + # track metrics and log temporary message + for (batch_size, metrics, + with_label) in zip(batch_size_list, metrics_list, + with_label_list): + self.track_and_log_message( + metrics=metrics, + batch_id=batch_id, + batch_size=batch_size, + num_batches=num_batches, + times=times, + 
with_label=with_label) + + # evaluate + if self.valid_steps > 0 and valid_label_iter is not None and valid_nolabel_iter is not None \ + and batch_id % self.valid_steps == 0: + self.evaluate( + data_label_iter=valid_label_iter, + data_nolabel_iter=valid_nolabel_iter) + + # compute accuracy for valid dataset + accuracy = self.infer( + data_iter=valid_label_iter, ex_data_iter=train_label_iter) + + # report summary message and save checkpoints + self.save_and_log_message( + report_for_unlabeled_data, cur_valid_metric=-accuracy) + + def infer(self, data_iter, num_batches=None, ex_data_iter=None): + """ + Inference interface. + """ + self.logger.info('Generation starts ...') + infer_save_file = os.path.join(self.save_dir, + f'infer_{self.epoch}.result.json') + + # Inference + batch_cnt = 0 + pred, true = [], [] + outputs, labels = [], [] + begin_time = time.time() + + with torch.no_grad(): + if self.example: + for _, (batch, batch_size) in tqdm( + ex_data_iter, desc='Building train memory.'): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + result = self.model.infer(inputs=batch) + result = { + name: result[name].cpu().detach().numpy() + for name in result + } + outputs.append(torch.from_numpy(result['features'])) + labels += batch['intent_label'].tolist() + + mem = torch.cat(outputs, dim=0) + mem = mem.cuda() if self.func_model.use_gpu else mem + labels = torch.LongTensor(labels).unsqueeze(0) + labels = labels.cuda() if self.func_model.use_gpu else labels + self.logger.info(f'Memory size: {mem.size()}') + + for _, (batch, batch_size) in tqdm(data_iter, total=num_batches): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + result = self.model.infer(inputs=batch) + result = { + name: result[name].cpu().detach().numpy() + for name in result + } + + if self.example: + features = torch.from_numpy(result['features']) + features = features.cuda( + ) if self.func_model.use_gpu else features + probs = torch.softmax(features.mm(mem.t()), dim=-1) + intent_probs = torch.zeros( + probs.size(0), self.func_model.num_intent) + intent_probs = intent_probs.cuda( + ) if self.func_model.use_gpu else intent_probs + intent_probs = intent_probs.scatter_add( + -1, labels.repeat(probs.size(0), 1), probs) + intent_probs = intent_probs.cpu().detach().numpy() + else: + intent_probs = result['intent_probs'] + + if self.can_norm: + pred += [intent_probs] + true += batch['intent_label'].cpu().detach().tolist() + else: + pred += np.argmax(intent_probs, axis=1).tolist() + true += batch['intent_label'].cpu().detach().tolist() + + batch_cnt += 1 + if batch_cnt == num_batches: + break + + if self.can_norm: + true = np.array(true) + pred = np.concatenate(pred, axis=0) + acc_original, acc_final, message = self.can_normalization( + y_pred=pred, y_true=true, ex_data_iter=ex_data_iter) + accuracy = max(acc_original, acc_final) + infer_results = { + 'accuracy': accuracy, + 'pred_labels': pred.tolist(), + 'message': message + } + metrics_message = f'Accuracy: {accuracy} {message}' + else: + accuracy = sum(p == t for p, t in zip(pred, true)) / len(pred) + infer_results = {'accuracy': accuracy, 'pred_labels': pred} + metrics_message = f'Accuracy: {accuracy}' + + self.logger.info(f'Saved inference results to {infer_save_file}') + with open(infer_save_file, 'w') as fp: + json.dump(infer_results, fp, indent=2) + message_prefix = f'[Infer][{self.epoch}]' + time_cost = f'TIME-{time.time() - begin_time:.3f}' + message = ' '.join([message_prefix, metrics_message, 
time_cost]) + self.logger.info(message) + return accuracy + + def track_and_log_message(self, metrics, batch_id, batch_size, num_batches, + times, with_label): + # track metrics + batch_metrics_tracker = self.batch_metrics_tracker_label if with_label else self.batch_metrics_tracker_nolabel + token_metrics_tracker = self.token_metrics_tracker_label if with_label else self.token_metrics_tracker_nolabel + + metrics = { + k: v.cpu().detach().numpy() if isinstance(v, torch.Tensor) else v + for k, v in metrics.items() + } + mlm_num = metrics.pop('mlm_num', 0) + + batch_metrics = {k: v for k, v in metrics.items() if 'token' not in k} + token_metrics = {k: v for k, v in metrics.items() if 'token' in k} + batch_metrics_tracker.update(batch_metrics, batch_size) + token_metrics_tracker.update(token_metrics, mlm_num) + + # log message + if self.log_steps > 0 and batch_id % self.log_steps == 0: + batch_metrics_message = batch_metrics_tracker.value() + token_metrics_message = token_metrics_tracker.value() + label_prefix = 'Labeled' if with_label else 'Unlabeled' + message_prefix = f'[Train][{self.epoch}][{batch_id}/{num_batches}][{label_prefix}]' + avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}' + message = ' '.join([ + message_prefix, batch_metrics_message, token_metrics_message, + avg_time + ]) + self.logger.info(message) + + def save_and_log_message(self, + report_for_unlabeled_data, + cur_valid_metric=None): + # report message + batch_metrics_message = self.batch_metrics_tracker_label.summary() + token_metrics_message = self.token_metrics_tracker_label.summary() + message_prefix = f'[Valid][{self.epoch}][Labeled]' + message = ' '.join( + [message_prefix, batch_metrics_message, token_metrics_message]) + self.logger.info(message) + if report_for_unlabeled_data: + batch_metrics_message = self.batch_metrics_tracker_nolabel.summary( + ) + token_metrics_message = self.token_metrics_tracker_nolabel.summary( + ) + message_prefix = f'[Valid][{self.epoch}][Unlabeled]' + message = ' '.join( + [message_prefix, batch_metrics_message, token_metrics_message]) + self.logger.info(message) + + # save checkpoints + assert cur_valid_metric is not None + if self.is_decreased_valid_metric: + is_best = cur_valid_metric < self.best_valid_metric + else: + is_best = cur_valid_metric > self.best_valid_metric + if is_best: + self.best_valid_metric = cur_valid_metric + self.save(is_best) + + def balance_metrics(self, metrics, batch_size): + if self.gpu > 1: + for metric in metrics: + if metric is not None: + assert len(metric) == self.gpu + + intent_loss, mlm, token_mlm, mlm_num, kl, con = metrics + metrics = {} + + intent_loss = torch.mean(intent_loss) + metrics['intent_loss'] = intent_loss + loss = intent_loss + + if mlm is not None: + mlm_num = torch.sum(mlm_num) + token_mlm = torch.sum(mlm) * (batch_size / self.gpu) / mlm_num + mlm = torch.mean(mlm) + metrics['mlm_num'] = mlm_num + metrics['token_mlm'] = token_mlm + metrics['mlm'] = mlm + loss = loss + (token_mlm if self.func_model.token_loss else + mlm) * self.func_model.mlm_ratio + + if kl is not None: + kl = torch.mean(kl) + metrics['kl'] = kl + loss = loss + kl * self.func_model.kl_ratio + + if con is not None: + con = torch.mean(con) + metrics['con'] = con + loss = loss + con + + metrics['loss'] = loss + + assert 'loss' in metrics + return metrics['loss'], metrics + + def load(self): + """ load """ + + def _load_model_state(): + model_state_dict = torch.load( + f'{self.func_model.init_checkpoint}', + map_location=lambda storage, loc: storage) 
+
+            if 'module.' in list(model_state_dict.keys())[0]:
+                new_model_state_dict = OrderedDict()
+                for k, v in model_state_dict.items():
+                    assert k[:7] == 'module.'
+                    new_model_state_dict[k[7:]] = v
+                model_state_dict = new_model_state_dict
+
+            new_model_state_dict = OrderedDict()
+            parameters = {
+                name: param
+                for name, param in self.func_model.named_parameters()
+            }
+            for name, param in model_state_dict.items():
+                if name in parameters:
+                    if param.shape != parameters[name].shape:
+                        assert hasattr(param, 'numpy')
+                        arr = param.numpy()
+                        z = np.random.normal(
+                            scale=self.func_model.initializer_range,
+                            size=parameters[name].shape).astype('float32')
+                        if name == 'embedder.token_embedding.weight':
+                            z[-param.shape[0]:] = arr
+                            print(f'part of parameter({name}) randomly '
+                                  f'initialized from a normal distribution')
+                        else:
+                            if z.shape[0] < param.shape[0]:
+                                z = arr[:z.shape[0]]
+                                print(f'part of parameter({name}) is dropped')
+                            else:
+                                z[:param.shape[0]] = arr
+                                print(f'part of parameter({name}) randomly '
+                                      f'initialized from a normal distribution')
+                        dtype, device = param.dtype, param.device
+                        z = torch.tensor(z, dtype=dtype, device=device)
+                        new_model_state_dict[name] = z
+                    else:
+                        new_model_state_dict[name] = param
+                else:
+                    print(f'parameter({name}) is dropped')
+            model_state_dict = new_model_state_dict
+
+            for name in parameters:
+                if name not in model_state_dict:
+                    if parameters[name].requires_grad:
+                        print(f'parameter({name}) randomly initialized '
+                              f'from a normal distribution')
+                        z = np.random.normal(
+                            scale=self.func_model.initializer_range,
+                            size=parameters[name].shape).astype('float32')
+                        dtype, device = parameters[name].dtype, parameters[
+                            name].device
+                        model_state_dict[name] = torch.tensor(
+                            z, dtype=dtype, device=device)
+                    else:
+                        model_state_dict[name] = parameters[name]
+
+            self.func_model.load_state_dict(model_state_dict)
+            # NOTE: unlike the base Trainer, this override loads the
+            # checkpoint path as-is, without appending a '.model' suffix
+            self.logger.info(
+                f"Loaded model state from '{self.func_model.init_checkpoint}'"
+            )
+
+        def _load_train_state():
+            train_file = f'{self.func_model.init_checkpoint}.train'
+            if os.path.exists(train_file):
+                train_state_dict = torch.load(
+                    train_file, map_location=lambda storage, loc: storage)
+                self.epoch = train_state_dict['epoch']
+                self.best_valid_metric = train_state_dict['best_valid_metric']
+                if self.optimizer is not None and 'optimizer' in train_state_dict:
+                    self.optimizer.load_state_dict(
+                        train_state_dict['optimizer'])
+                if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict:
+                    self.lr_scheduler.load_state_dict(
+                        train_state_dict['lr_scheduler'])
+                self.logger.info(
+                    f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
+                    f'best_valid_metric={self.best_valid_metric:.3f})')
+            else:
+                self.logger.info('No train state loaded')
+
+        if self.func_model.init_checkpoint is None:
+            self.logger.info('No model state loaded')
+            return
+
+        if self.do_train:
+            _load_model_state()
+            return
+
+        if self.do_infer:
+            _load_model_state()
+            _load_train_state()
diff --git a/maas_lib/utils/constant.py b/maas_lib/utils/constant.py
index bd4f8e17..17e76309 100644
--- a/maas_lib/utils/constant.py
+++ b/maas_lib/utils/constant.py
@@ -39,6 +39,7 @@ class Tasks(object):
     conversational = 'conversational'
     text_generation = 'text-generation'
     dialog_generation = 'dialog-generation'
+    dialog_intent = 'dialog-intent'
     table_question_answering = 'table-question-answering'
     feature_extraction = 'feature-extraction'
     sentence_similarity = 'sentence-similarity'
diff --git a/maas_lib/utils/nlp/space/criterions.py b/maas_lib/utils/nlp/space/criterions.py
new file mode 100644
index 00000000..60f98457
--- /dev/null
+++ b/maas_lib/utils/nlp/space/criterions.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn.functional as F
+from torch.nn.modules.loss import _Loss
+
+
+def compute_kl_loss(p, q, filter_scores=None):
+    """Symmetric KL between two sets of logits (the R-Drop objective)."""
+    p_loss = F.kl_div(
+        F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
+    q_loss = F.kl_div(
+        F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')
+
+    # reduce over the class dimension; switch between sum and mean reduction
+    # depending on your task
+    p_loss = p_loss.sum(dim=-1)
+    q_loss = q_loss.sum(dim=-1)
+
+    # optional per-sample filtering/weighting mechanism
+    if filter_scores is not None:
+        p_loss = filter_scores * p_loss
+        q_loss = filter_scores * q_loss
+
+    p_loss = p_loss.mean()
+    q_loss = q_loss.mean()
+
+    loss = (p_loss + q_loss) / 2
+    return loss
+
+
+class CatKLLoss(_Loss):
+    """
+    CatKLLoss
+    """
+
+    def __init__(self, reduction='mean'):
+        super(CatKLLoss, self).__init__()
+        assert reduction in ['none', 'sum', 'mean']
+        self.reduction = reduction
+
+    def forward(self, log_qy, log_py):
+        """
+        KL(q||p) = sum_y q(y) * (log q(y) - log p(y))
+
+        log_qy: (batch_size, latent_size)
+        log_py: (batch_size, latent_size)
+        """
+        qy = torch.exp(log_qy)
+        kl = torch.sum(qy * (log_qy - log_py), dim=1)
+
+        if self.reduction == 'mean':
+            kl = kl.mean()
+        elif self.reduction == 'sum':
+            kl = kl.sum()
+        return kl
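A minimal usage sketch of `compute_kl_loss` as the R-Drop regularizer, matching how `IntentUnifiedTransformer._forward` duplicates the batch and treats the two halves as the p/q passes; the tensors and `kl_ratio` value below are toy stand-ins:

```python
import torch
import torch.nn.functional as F

from maas_lib.utils.nlp.space.criterions import compute_kl_loss

# two dropout-perturbed passes over the same 4 examples, 10 intent classes
logits_p = torch.randn(4, 10)
logits_q = logits_p + 0.1 * torch.randn(4, 10)
labels = torch.tensor([1, 0, 3, 2])

ce = F.cross_entropy(logits_p, labels)        # task loss
kl = compute_kl_loss(p=logits_p, q=logits_q)  # consistency loss
kl_ratio = 0.5                                # stands in for config.Model.kl_ratio
loss = ce + kl * kl_ratio
```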
diff --git a/maas_lib/utils/nlp/space/utils.py b/maas_lib/utils/nlp/space/utils.py
index df0107a1..8448e943 100644
--- a/maas_lib/utils/nlp/space/utils.py
+++ b/maas_lib/utils/nlp/space/utils.py
@@ -7,6 +7,30 @@ import numpy as np
 from . import ontology
 
 
+def max_lens(X):
+    """Maximum length at each nesting depth of a (possibly nested) list."""
+    lens = [len(X)]
+    while isinstance(X[0], list):
+        lens.append(max(map(len, X)))
+        X = [x for xs in X for x in xs]
+    return lens
+
+
+def list2np(X: list, padding: int = 0, dtype: str = 'int64') -> np.ndarray:
+    """Convert a nested list (up to depth 3) into a right-padded ndarray."""
+    shape = max_lens(X)
+    ret = np.full(shape, padding, dtype=np.int32)
+
+    if len(shape) == 1:
+        ret = np.array(X)
+    elif len(shape) == 2:
+        for i, x in enumerate(X):
+            ret[i, :len(x)] = np.array(x)
+    elif len(shape) == 3:
+        for i, xs in enumerate(X):
+            for j, x in enumerate(xs):
+                ret[i, j, :len(x)] = np.array(x)
+    return ret.astype(dtype)
+
+
 def clean_replace(s, r, t, forward=True, backward=False):
 
     def clean_replace_single(s, r, t, forward, backward, sidx=0):
diff --git a/tests/case/nlp/dialog_intent_case.py b/tests/case/nlp/dialog_intent_case.py
new file mode 100644
index 00000000..d3442bde
--- /dev/null
+++ b/tests/case/nlp/dialog_intent_case.py
@@ -0,0 +1,4 @@
+test_case = [
+    'How do I locate my card?',
+    'I still have not received my new card, I ordered over a week ago.'
+]
diff --git a/tests/pipelines/nlp/test_dialog_generation.py b/tests/pipelines/nlp/test_dialog_generation.py
index 1baee3df..413e70b5 100644
--- a/tests/pipelines/nlp/test_dialog_generation.py
+++ b/tests/pipelines/nlp/test_dialog_generation.py
@@ -19,14 +19,14 @@ class DialogGenerationTest(unittest.TestCase):
 
     def test_run(self):
 
-        modeldir = '/Users/yangliu/Desktop/space-dialog-generation'
-
-        preprocessor = DialogGenerationPreprocessor(model_dir=modeldir)
-        model = DialogGenerationModel(
-            model_dir=modeldir,
-            text_field=preprocessor.text_field,
-            config=preprocessor.config)
-        print(model.forward(None))
+        # modeldir = '/Users/yangliu/Desktop/space-dialog-generation'
+        #
+        # preprocessor = DialogGenerationPreprocessor(model_dir=modeldir)
+        # model = DialogGenerationModel(
+        #     model_dir=modeldir,
+        #     text_field=preprocessor.text_field,
+        #     config=preprocessor.config)
+        # print(model.forward(None))
         # pipeline = DialogGenerationPipeline(model=model, preprocessor=preprocessor)
         #
         # history_dialog_info = {}
@@ -39,6 +39,7 @@ class DialogGenerationTest(unittest.TestCase):
         #     result = pipeline(user_question, history=history_dialog_info)
         #     #
         #     # print('sys : {}'.format(result['pred_answer']))
+        print('test')
 
 
 if __name__ == '__main__':
diff --git a/tests/pipelines/nlp/test_dialog_intent.py b/tests/pipelines/nlp/test_dialog_intent.py
new file mode 100644
index 00000000..f94a5f67
--- /dev/null
+++ b/tests/pipelines/nlp/test_dialog_intent.py
@@ -0,0 +1,41 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from tests.case.nlp.dialog_intent_case import test_case
+
+from maas_lib.models.nlp import DialogIntentModel
+from maas_lib.pipelines import DialogIntentPipeline, pipeline
+from maas_lib.preprocessors import DialogIntentPreprocessor
+
+
+class DialogIntentTest(unittest.TestCase):
+
+    def test_run(self):
+
+        modeldir = '/Users/yangliu/Desktop/space-dialog-intent'
+
+        preprocessor = DialogIntentPreprocessor(model_dir=modeldir)
+        model = DialogIntentModel(
+            model_dir=modeldir,
+            text_field=preprocessor.text_field,
+            config=preprocessor.config)
+        print(model.forward(None))
+        # pipeline = DialogIntentPipeline(model=model, preprocessor=preprocessor)
+        #
+        # for user_question in test_case:
+        #     print('user: {}'.format(user_question))
+        #     result = pipeline(user_question)
+        #     print('intent: {}'.format(result))
+
+
+if __name__ == '__main__':
+    unittest.main()
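Finally, a quick illustration of the `list2np` padding helper added to `maas_lib/utils/nlp/space/utils.py` above; the ragged input is made up:

```python
from maas_lib.utils.nlp.space.utils import list2np

batch = [[7, 8, 9], [4, 5]]  # ragged token-id rows
print(list2np(batch))
# [[7 8 9]
#  [4 5 0]]  <- right-padded with 0 to shape max_lens(batch) == [2, 3], dtype int64
```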