| @@ -65,5 +65,7 @@ class DialogIntentModel(Model): | |||||
| """ | """ | ||||
| from numpy import array, float32 | from numpy import array, float32 | ||||
| import torch | import torch | ||||
| print('--forward--') | |||||
| result = self.trainer.forward(input) | |||||
| return {} | |||||
| return result | |||||
| @@ -3,14 +3,14 @@ from typing import Any, Dict, Optional | |||||
| from maas_lib.models.nlp import DialogIntentModel | from maas_lib.models.nlp import DialogIntentModel | ||||
| from maas_lib.preprocessors import DialogIntentPreprocessor | from maas_lib.preprocessors import DialogIntentPreprocessor | ||||
| from maas_lib.utils.constant import Tasks | from maas_lib.utils.constant import Tasks | ||||
| from ...base import Model, Tensor | |||||
| from ...base import Input, Pipeline | |||||
| from ...builder import PIPELINES | from ...builder import PIPELINES | ||||
| __all__ = ['DialogIntentPipeline'] | __all__ = ['DialogIntentPipeline'] | ||||
| @PIPELINES.register_module(Tasks.dialog_intent, module_name=r'space-intent') | @PIPELINES.register_module(Tasks.dialog_intent, module_name=r'space-intent') | ||||
| class DialogIntentPipeline(Model): | |||||
| class DialogIntentPipeline(Pipeline): | |||||
| def __init__(self, model: DialogIntentModel, | def __init__(self, model: DialogIntentModel, | ||||
| preprocessor: DialogIntentPreprocessor, **kwargs): | preprocessor: DialogIntentPreprocessor, **kwargs): | ||||
| @@ -23,9 +23,9 @@ class DialogIntentPipeline(Model): | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| self.model = model | self.model = model | ||||
| self.tokenizer = preprocessor.tokenizer | |||||
| # self.tokenizer = preprocessor.tokenizer | |||||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||||
| """process the prediction results | """process the prediction results | ||||
| Args: | Args: | ||||
| @@ -35,16 +35,4 @@ class DialogIntentPipeline(Model): | |||||
| Dict[str, str]: the prediction results | Dict[str, str]: the prediction results | ||||
| """ | """ | ||||
| vocab_size = len(self.tokenizer.vocab) | |||||
| pred_list = inputs['predictions'] | |||||
| pred_ids = pred_list[0][0].cpu().numpy().tolist() | |||||
| for j in range(len(pred_ids)): | |||||
| if pred_ids[j] >= vocab_size: | |||||
| pred_ids[j] = 100 | |||||
| pred = self.tokenizer.convert_ids_to_tokens(pred_ids) | |||||
| pred_string = ''.join(pred).replace( | |||||
| '##', | |||||
| '').split('[SEP]')[0].replace('[CLS]', | |||||
| '').replace('[SEP]', | |||||
| '').replace('[UNK]', '') | |||||
| return {'pred_string': pred_string} | |||||
| return inputs | |||||
| @@ -43,5 +43,7 @@ class DialogIntentPreprocessor(Preprocessor): | |||||
| Returns: | Returns: | ||||
| Dict[str, Any]: the preprocessed data | Dict[str, Any]: the preprocessed data | ||||
| """ | """ | ||||
| samples = self.text_field.preprocessor([data]) | |||||
| samples, _ = self.text_field.collate_fn_multi_turn(samples) | |||||
| return self.text_field.preprocessor(data) | |||||
| return samples | |||||
| @@ -506,6 +506,28 @@ class IntentTrainer(Trainer): | |||||
| self.save_and_log_message( | self.save_and_log_message( | ||||
| report_for_unlabeled_data, cur_valid_metric=-accuracy) | report_for_unlabeled_data, cur_valid_metric=-accuracy) | ||||
def forward(self, batch):
    """Run inference on one batch and return the intent predictions.

    Args:
        batch: a mapping of field name to raw value. Every value is
            converted with ``self.to_tensor`` and the result is rebuilt
            with the same mapping type before being handed to
            ``self.model.infer``.

    Returns:
        dict: ``{'pred': pred}``. When ``self.can_norm`` is truthy,
        ``pred`` is a one-element list holding the ``intent_probs``
        array as returned by the model; otherwise it is the list of
        arg-max indices of ``intent_probs`` along axis 1.
    """
    with torch.no_grad():
        # Keep the caller's mapping type (e.g. an OrderedDict stays one).
        batch = type(batch)(
            (name, self.to_tensor(value)) for name, value in batch.items())
        result = self.model.infer(inputs=batch)
        # Move every output to host memory as a numpy array.
        result = {
            name: tensor.cpu().detach().numpy()
            for name, tensor in result.items()
        }

    intent_probs = result['intent_probs']
    # Only predictions are returned, so 'intent_label' is deliberately not
    # read here: inference batches are not required to carry labels.
    if self.can_norm:
        pred = [intent_probs]
    else:
        pred = np.argmax(intent_probs, axis=1).tolist()
    return {'pred': pred}
| def infer(self, data_iter, num_batches=None, ex_data_iter=None): | def infer(self, data_iter, num_batches=None, ex_data_iter=None): | ||||
| """ | """ | ||||
| Inference interface. | Inference interface. | ||||
| @@ -4,11 +4,12 @@ import os.path as osp | |||||
| import tempfile | import tempfile | ||||
| import unittest | import unittest | ||||
| from tests.case.nlp.dialog_generation_case import test_case | |||||
| from tests.case.nlp.dialog_intent_case import test_case | |||||
| from maas_lib.models.nlp import DialogIntentModel | from maas_lib.models.nlp import DialogIntentModel | ||||
| from maas_lib.pipelines import DialogIntentPipeline, pipeline | from maas_lib.pipelines import DialogIntentPipeline, pipeline | ||||
| from maas_lib.preprocessors import DialogIntentPreprocessor | from maas_lib.preprocessors import DialogIntentPreprocessor | ||||
| from maas_lib.utils.constant import Tasks | |||||
| class DialogGenerationTest(unittest.TestCase): | class DialogGenerationTest(unittest.TestCase): | ||||
| @@ -22,19 +23,12 @@ class DialogGenerationTest(unittest.TestCase): | |||||
| model_dir=modeldir, | model_dir=modeldir, | ||||
| text_field=preprocessor.text_field, | text_field=preprocessor.text_field, | ||||
| config=preprocessor.config) | config=preprocessor.config) | ||||
| print(model.forward(None)) | |||||
| # pipeline = DialogGenerationPipeline(model=model, preprocessor=preprocessor) | |||||
| # | |||||
| # history_dialog_info = {} | |||||
| # for step, item in enumerate(test_case['sng0073']['log']): | |||||
| # user_question = item['user'] | |||||
| # print('user: {}'.format(user_question)) | |||||
| # | |||||
| # # history_dialog_info = merge(history_dialog_info, | |||||
| # # result) if step > 0 else {} | |||||
| # result = pipeline(user_question, history=history_dialog_info) | |||||
| # # | |||||
| # # print('sys : {}'.format(result['pred_answer'])) | |||||
| pipeline1 = DialogIntentPipeline( | |||||
| model=model, preprocessor=preprocessor) | |||||
| # pipeline1 = pipeline(task=Tasks.dialog_intent, model=model, preprocessor=preprocessor) | |||||
| for item in test_case: | |||||
| pipeline1(item) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||