diff --git a/maas_lib/models/nlp/space/dialog_intent_model.py b/maas_lib/models/nlp/space/dialog_intent_model.py index 226c5da8..747f6a20 100644 --- a/maas_lib/models/nlp/space/dialog_intent_model.py +++ b/maas_lib/models/nlp/space/dialog_intent_model.py @@ -65,5 +65,7 @@ class DialogIntentModel(Model): """ from numpy import array, float32 import torch + print('--forward--') + result = self.trainer.forward(input) - return {} + return result diff --git a/maas_lib/pipelines/nlp/space/dialog_intent_pipeline.py b/maas_lib/pipelines/nlp/space/dialog_intent_pipeline.py index e9d10551..99862311 100644 --- a/maas_lib/pipelines/nlp/space/dialog_intent_pipeline.py +++ b/maas_lib/pipelines/nlp/space/dialog_intent_pipeline.py @@ -3,14 +3,14 @@ from typing import Any, Dict, Optional from maas_lib.models.nlp import DialogIntentModel from maas_lib.preprocessors import DialogIntentPreprocessor from maas_lib.utils.constant import Tasks -from ...base import Model, Tensor +from ...base import Input, Pipeline from ...builder import PIPELINES __all__ = ['DialogIntentPipeline'] @PIPELINES.register_module(Tasks.dialog_intent, module_name=r'space-intent') -class DialogIntentPipeline(Model): +class DialogIntentPipeline(Pipeline): def __init__(self, model: DialogIntentModel, preprocessor: DialogIntentPreprocessor, **kwargs): @@ -23,9 +23,9 @@ class DialogIntentPipeline(Model): super().__init__(model=model, preprocessor=preprocessor, **kwargs) self.model = model - self.tokenizer = preprocessor.tokenizer + # self.tokenizer = preprocessor.tokenizer - def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: """process the prediction results Args: @@ -35,16 +35,4 @@ class DialogIntentPipeline(Model): Dict[str, str]: the prediction results """ - vocab_size = len(self.tokenizer.vocab) - pred_list = inputs['predictions'] - pred_ids = pred_list[0][0].cpu().numpy().tolist() - for j in range(len(pred_ids)): - if pred_ids[j] >= vocab_size: - pred_ids[j] = 100 - pred = self.tokenizer.convert_ids_to_tokens(pred_ids) - pred_string = ''.join(pred).replace( - '##', - '').split('[SEP]')[0].replace('[CLS]', - '').replace('[SEP]', - '').replace('[UNK]', '') - return {'pred_string': pred_string} + return inputs diff --git a/maas_lib/preprocessors/space/dialog_intent_preprocessor.py b/maas_lib/preprocessors/space/dialog_intent_preprocessor.py index b8c5d34e..8dba5075 100644 --- a/maas_lib/preprocessors/space/dialog_intent_preprocessor.py +++ b/maas_lib/preprocessors/space/dialog_intent_preprocessor.py @@ -43,5 +43,7 @@ class DialogIntentPreprocessor(Preprocessor): Returns: Dict[str, Any]: the preprocessed data """ + samples = self.text_field.preprocessor([data]) + samples, _ = self.text_field.collate_fn_multi_turn(samples) - return self.text_field.preprocessor(data) + return samples diff --git a/maas_lib/trainers/nlp/space/trainers/intent_trainer.py b/maas_lib/trainers/nlp/space/trainers/intent_trainer.py index f736a739..9db24e6d 100644 --- a/maas_lib/trainers/nlp/space/trainers/intent_trainer.py +++ b/maas_lib/trainers/nlp/space/trainers/intent_trainer.py @@ -506,6 +506,28 @@ class IntentTrainer(Trainer): self.save_and_log_message( report_for_unlabeled_data, cur_valid_metric=-accuracy) + def forward(self, batch): + outputs, labels = [], [] + pred, true = [], [] + + with torch.no_grad(): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) + result = self.model.infer(inputs=batch) + result = { + name: result[name].cpu().detach().numpy() + for name in result + } + intent_probs = result['intent_probs'] + if self.can_norm: + pred += [intent_probs] + true += batch['intent_label'].cpu().detach().tolist() + else: + pred += np.argmax(intent_probs, axis=1).tolist() + true += batch['intent_label'].cpu().detach().tolist() + + return {'pred': pred} + def infer(self, data_iter, num_batches=None, ex_data_iter=None): """ Inference interface. diff --git a/tests/pipelines/nlp/test_dialog_intent.py b/tests/pipelines/nlp/test_dialog_intent.py index f94a5f67..86e78d06 100644 --- a/tests/pipelines/nlp/test_dialog_intent.py +++ b/tests/pipelines/nlp/test_dialog_intent.py @@ -4,11 +4,12 @@ import os.path as osp import tempfile import unittest -from tests.case.nlp.dialog_generation_case import test_case +from tests.case.nlp.dialog_intent_case import test_case from maas_lib.models.nlp import DialogIntentModel from maas_lib.pipelines import DialogIntentPipeline, pipeline from maas_lib.preprocessors import DialogIntentPreprocessor +from maas_lib.utils.constant import Tasks class DialogGenerationTest(unittest.TestCase): @@ -22,19 +23,12 @@ class DialogGenerationTest(unittest.TestCase): model_dir=modeldir, text_field=preprocessor.text_field, config=preprocessor.config) - print(model.forward(None)) - # pipeline = DialogGenerationPipeline(model=model, preprocessor=preprocessor) - # - # history_dialog_info = {} - # for step, item in enumerate(test_case['sng0073']['log']): - # user_question = item['user'] - # print('user: {}'.format(user_question)) - # - # # history_dialog_info = merge(history_dialog_info, - # # result) if step > 0 else {} - # result = pipeline(user_question, history=history_dialog_info) - # # - # # print('sys : {}'.format(result['pred_answer'])) + pipeline1 = DialogIntentPipeline( + model=model, preprocessor=preprocessor) + # pipeline1 = pipeline(task=Tasks.dialog_intent, model=model, preprocessor=preprocessor) + + for item in test_case: + pipeline1(item) if __name__ == '__main__':